Source code for tensorcv.utils.common
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Qian Ge <geqian1001@gmail.com>
import math
import os
import tensorflow as tf
__all__ = ['apply_mask', 'apply_mask_inverse', 'get_tensors_by_names',
'deconv_size', 'match_tensor_save_name']
[docs]def apply_mask(input_matrix, mask):
"""Get partition of input_matrix using index 1 in mask.
Args:
input_matrix (Tensor): A Tensor
mask (int): A Tensor of type int32 with indices in {0, 1}. Shape
has to be the same as input_matrix.
Return:
A Tensor with elements from data with entries in mask equal to 1.
"""
return tf.dynamic_partition(input_matrix, mask, 2)[1]
[docs]def apply_mask_inverse(input_matrix, mask):
"""Get partition of input_matrix using index 0 in mask.
Args:
input_matrix (Tensor): A Tensor
mask (int): A Tensor of type int32 with indices in {0, 1}. Shape
has to be the same as input_matrix.
Return:
A Tensor with elements from data with entries in mask equal to 0.
"""
return tf.dynamic_partition(input_matrix, mask, 2)[0]
[docs]def get_tensors_by_names(names):
"""Get a list of tensors by the input name list.
Args:
names (str): A str or a list of str
Return:
A list of tensors with name in input names.
Warning:
If more than one tensor have the same name in the graph. This function
will only return the tensor with name NAME:0.
"""
if not isinstance(names, list):
names = [names]
graph = tf.get_default_graph()
tensor_list = []
# TODO assume there is no repeativie names
for name in names:
tensor_name = name + ':0'
tensor_list += graph.get_tensor_by_name(tensor_name),
return tensor_list
[docs]def deconv_size(input_height, input_width, stride=2):
"""
Compute the feature size (height and width) after filtering with
a specific stride. Mostly used for setting the shape for deconvolution.
Args:
input_height (int): height of input feature
input_width (int): width of input feature
stride (int): stride of the filter
Return:
(int, int): Height and width of feature after filtering.
"""
print('***** WARNING ********: deconv_size is moved to models.utils.py')
return int(math.ceil(float(input_height) / float(stride))),\
int(math.ceil(float(input_width) / float(stride)))
[docs]def match_tensor_save_name(tensor_names, save_names):
"""
Match tensor_names and corresponding save_names for saving the results of
the tenors. If the number of tensors is less or equal to the length
of save names, tensors will be saved using the corresponding names in
save_names. Otherwise, tensors will be saved using their own names.
Used for prediction or inference.
Args:
tensor_names (str): List of tensor names
save_names (str): List of names for saving tensors
Return:
(list, list): List of tensor names and list of names to save
the tensors.
"""
if not isinstance(tensor_names, list):
tensor_names = [tensor_names]
if save_names is None:
return tensor_names, tensor_names
elif not isinstance(save_names, list):
save_names = [save_names]
if len(save_names) < len(tensor_names):
return tensor_names, tensor_names
else:
return tensor_names, save_names
def check_dir(input_dir):
print('***** WARNING ********: check_dir is moved to utils.utils.py')
assert input_dir is not None, "dir cannot be None!"
assert os.path.isdir(input_dir), input_dir + ' does not exist!'
def assert_type(v, tp):
print('***** WARNING ********: assert_type is moved to utils.utils.py')
"""
Assert type of input v be type tp
"""
assert isinstance(v, tp),\
"Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"