Source code for tensorcv.predicts.predictions

# File: predictions.py
# Author: Qian Ge <geqian1001@gmail.com>

import os
import scipy.io
import scipy.misc

import tensorflow as tf
import numpy as np

from ..utils.common import get_tensors_by_names
from ..utils.viz import *

__all__ = ['PredictionImage', 'PredictionScalar', 'PredictionMat', 'PredictionMeanScalar', 'PredictionOverlay']

def assert_type(v, tp):
    assert isinstance(v, tp), \
    "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"

class PredictionBase(object):
    """ base class for prediction 

    Attributes:
        _predictions
        _prefix_list
        _global_ind
        _save_dir
    """
    def __init__(self, prediction_tensors, save_prefix):
        """ init prediction object

        Get tensors to be predicted and the prefix for saving 
        each tensors

        Args:
            prediction_tensors : list[string] A tensor name or list of tensor names
            save_prefix: list[string] A string or list of strings
            Length of prediction_tensors and save_prefix have 
            to be the same
        """
        if not isinstance(prediction_tensors, list):
            prediction_tensors = [prediction_tensors]
        if not isinstance(save_prefix, list):
            save_prefix = [save_prefix]
        assert len(prediction_tensors) == len(save_prefix), \
        'Length of prediction_tensors {} and save_prefix {} has to be the same'.\
        format(len(prediction_tensors), len(save_prefix))

        self._predictions = prediction_tensors
        self._prefix_list = save_prefix
        self._global_ind = 0

    def setup(self, result_dir):
        assert os.path.isdir(result_dir)
        self._save_dir = result_dir

        self._predictions = get_tensors_by_names(self._predictions)

    def get_predictions(self):
        return self._predictions

    def after_prediction(self, results):
        """ process after predition
            default to save predictions
        """
        self._save_prediction(results)

    def _save_prediction(self, results):
        pass

    def after_finish_predict(self):
        """ process after all prediction steps """
        self._after_finish_predict()

    def _after_finish_predict(self):
        pass

[docs]class PredictionImage(PredictionBase): """ Predict image output and save as files. Images are saved every batch. Each batch result can be save in one image or individule images. """
[docs] def __init__(self, prediction_image_tensors, save_prefix, merge_im=False, tanh=False, color=False): """ Args: prediction_image_tensors (list): a list of tensor names save_prefix (list): a list of file prefix for saving each tensor in prediction_image_tensors merge_im (bool): merge output of one batch or not """ self._merge = merge_im self._tanh = tanh self._color = color super(PredictionImage, self).__init__(prediction_tensors=prediction_image_tensors, save_prefix=save_prefix)
def _save_prediction(self, results): for re, prefix in zip(results, self._prefix_list): cur_global_ind = self._global_ind if self._merge and re.shape[0] > 1: grid_size = self._get_grid_size(re.shape[0]) save_path = os.path.join(self._save_dir, str(cur_global_ind) + '_' + prefix + '.png') save_merge_images(np.squeeze(re), [grid_size, grid_size], save_path, tanh=self._tanh, color=self._color) cur_global_ind += 1 else: for im in re: save_path = os.path.join(self._save_dir, str(cur_global_ind) + '_' + prefix + '.png') if self._color: im = intensity_to_rgb(np.squeeze(im), normalize=True) scipy.misc.imsave(save_path, np.squeeze(im)) cur_global_ind += 1 self._global_ind = cur_global_ind def _get_grid_size(self, batch_size): try: return self._grid_size except AttributeError: self._grid_size = np.ceil(batch_size**0.5).astype(int) return self._grid_size
[docs]class PredictionOverlay(PredictionImage): def __init__(self, prediction_image_tensors, save_prefix, merge_im=False, tanh=False, color=False): if not isinstance(prediction_image_tensors, list): prediction_image_tensors = [prediction_image_tensors] assert len(prediction_image_tensors) == 2,\ '[PredictionOverlay] requires two image tensors but the input len = {}.'.\ format(len(prediction_image_tensors)) super(PredictionOverlay, self).__init__(prediction_image_tensors, save_prefix, merge_im=merge_im, tanh=tanh, color=color) self._overlay_prefix = '{}_{}'.format(self._prefix_list[0], self._prefix_list[1]) def _save_prediction(self, results): cur_global_ind = self._global_ind if self._merge and results[0].shape[0] > 1: overlay_im_list = [] for im_1, im_2 in zip(results[0], results[1]): overlay_im = image_overlay(im_1, im_2, color=self._color) overlay_im_list.append(overlay_im) grid_size = self._get_grid_size(results[0].shape[0]) save_path = os.path.join(self._save_dir, str(cur_global_ind) + '_' + self._overlay_prefix + '.png') save_merge_images(np.squeeze(overlay_im_list), [grid_size, grid_size], save_path, tanh=self._tanh, color=False) cur_global_ind += 1 else: for im_1, im_2 in zip(results[0], results[1]): overlay_im = image_overlay(im_1, im_2, color=self._color) save_path = os.path.join(self._save_dir, str(cur_global_ind) + '_' + self._overlay_prefix + '.png') scipy.misc.imsave(save_path, np.squeeze(overlay_im)) cur_global_ind += 1 self._global_ind = cur_global_ind
[docs]class PredictionScalar(PredictionBase):
[docs] def __init__(self, prediction_scalar_tensors, print_prefix): """ Args: prediction_scalar_tensors (list): a list of tensor names print_prefix (list): a list of name prefix for printing each tensor in prediction_scalar_tensors """ super(PredictionScalar, self).__init__(prediction_tensors=prediction_scalar_tensors, save_prefix=print_prefix)
def _save_prediction(self, results): for re, prefix in zip(results, self._prefix_list): print('{} = {}'.format(prefix, re))
[docs]class PredictionMeanScalar(PredictionScalar): def __init__(self, prediction_scalar_tensors, print_prefix): super(PredictionMeanScalar, self).__init__(prediction_scalar_tensors=prediction_scalar_tensors, print_prefix=print_prefix) self.scalar_list = [[] for i in range(0, len(self._predictions))] def _save_prediction(self, results): cnt = 0 for re, prefix in zip(results, self._prefix_list): print('{} = {}'.format(prefix, re)) self.scalar_list[cnt].append(re) cnt += 1 def _after_finish_predict(self): for i, prefix in enumerate(self._prefix_list): print('Overall {} = {}'.format(prefix, np.mean(self.scalar_list[i])))
[docs]class PredictionMat(PredictionBase): def _save_prediction(self, results): save_path = os.path.join(self._save_dir, str(self._global_ind) + '_' + 'batch_test' + '.mat') scipy.io.savemat(save_path, {name: np.squeeze(val) for name, val in zip(self._prefix_list, results)}) self._global_ind += 1