Source code for tensorcv.predicts.base

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Qian Ge <geqian1001@gmail.com>

import os

import tensorflow as tf

from .config import PridectConfig
from ..utils.sesscreate import ReuseSessionCreator
from ..utils.common import assert_type
from ..callbacks.hooks import Prediction2Hook

# Names exported by ``from tensorcv.predicts.base import *``.
__all__ = ['Predictor']


class Predictor(object):
    """Base class for predictors; drives every configured prediction.

    Constructing a predictor does the following:

    * builds the model graph in inference mode;
    * attaches each prediction object to a monitored session through a
      ``Prediction2Hook``;
    * restores pre-trained parameters from a checkpoint.

    Subclasses supply the actual prediction loop in ``_predict_step`` and
    may add post-processing in ``_after_prediction``.

    Attributes:
        _config (PridectConfig): the configuration this predictor was built
            from.
        _model: ``config.model``; ``set_is_training(False)`` and
            ``create_graph()`` are called on it during construction.
        _input: ``config.dataflow``; ``before_read_setup()`` is called on it
            at the start of ``run_predict``.
        _result_dir: ``config.result_dir``; passed to every prediction's
            ``setup``.
        _restore_vars: ``config.restore_vars``; if not None it is given to
            ``tf.train.Saver`` as its variable list, otherwise a default
            ``tf.train.Saver()`` is used.
        sess (tf.Session): the session returned by
            ``config.session_creator.create_session()``.
        hooked_sess (tf.train.MonitoredSession): a single monitored session
            built on top of ``sess`` and carrying one ``Prediction2Hook``
            per prediction.
    """

    def __init__(self, config):
        """Build the graph, wire up the predictions and restore weights.

        Steps, in order:

        1. Validate ``config``.
        2. Put the model in inference mode and create its graph.
        3. Call ``setup(result_dir)`` on every prediction.
        4. Create the raw session, then one ``tf.train.MonitoredSession``
           around it holding one ``Prediction2Hook`` per prediction.
        5. Restore parameters from
           ``os.path.join(config.model_dir, config.model_name)``.

        Args:
            config (PridectConfig): the config used for this predictor.
        """
        assert_type(config, PridectConfig)
        self._config = config
        self._model = config.model
        self._input = config.dataflow
        self._result_dir = config.result_dir

        # TODO to be modified (original author's note; intended change
        # was not recorded).
        model = self._model
        model.set_is_training(False)
        model.create_graph()

        self._restore_vars = config.restore_vars

        # Every prediction is told where to write its output.
        for prediction in config.predictions:
            prediction.setup(self._result_dir)
        prediction_hooks = [Prediction2Hook(prediction)
                            for prediction in config.predictions]

        self.sess = config.session_creator.create_session()
        # The monitored session is handed the session created above.
        # Presumably this makes it wrap ``self.sess`` instead of opening a
        # new one; confirm against utils.sesscreate.ReuseSessionCreator.
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=prediction_hooks)

        # Restore pre-trained parameters into the raw session.
        checkpoint_path = os.path.join(config.model_dir, config.model_name)
        saver = (tf.train.Saver() if self._restore_vars is None
                 else tf.train.Saver(self._restore_vars))
        saver.restore(self.sess, checkpoint_path)

    def run_predict(self):
        """Run all predictions, then the post-prediction processing.

        With ``self.sess`` installed as the default session, this method:

        * prepares the input dataflow;
        * runs ``_predict_step``;
        * calls ``after_finish_predict()`` on every prediction.

        ``after_prediction`` is invoked once that block has completed.
        """
        with self.sess.as_default():
            self._input.before_read_setup()
            self._predict_step()
            for prediction in self._config.predictions:
                prediction.after_finish_predict()
        self.after_prediction()

    def _predict_step(self):
        """Run the predictions themselves; meant to be overridden.

        The base implementation does nothing.
        """
        return None

    def after_prediction(self):
        """Public entry point called once all predictions have finished.

        Simply delegates to ``_after_prediction``.
        """
        self._after_prediction()

    def _after_prediction(self):
        """Post-prediction processing; a no-op in the base class."""
        return None