Source code for tensorcv.predicts.simple
# File: predictions.py
# Author: Qian Ge <geqian1001@gmail.com>
import tensorflow as tf
from .base import Predictor
__all__ = ['SimpleFeedPredictor']
def assert_type(v, tp):
assert isinstance(v, tp),\
"Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
[docs]class SimpleFeedPredictor(Predictor):
""" predictor with feed input """
# set_is_training
def __init__(self, config):
super(SimpleFeedPredictor, self).__init__(config)
# TODO change len_input to other
placeholders = self._model.get_prediction_placeholder()
if not isinstance(placeholders, list):
placeholders = [placeholders]
self._plhs = placeholders
# self.placeholder = self._model.get_random_vec_placeholder()
# assert self.len_input <= len(self.placeholder)
# self.placeholder = self.placeholder[0:self.len_input]
def _predict_step(self):
while self._input.epochs_completed < 1:
try:
cur_batch = self._input.next_batch()
except AttributeError:
cur_batch = self._input.next_batch()
feed = dict(zip(self._plhs, cur_batch))
self.hooked_sess.run(fetches=[], feed_dict=feed)
self._input.reset_epochs_completed(0)
def _after_prediction(self):
self._input.after_reading()