Source code for tensorcv.callbacks.inputs

# File: input.py
# Author: Qian Ge <geqian1001@gmail.com>

import scipy.misc
import os

import tensorflow as tf

from .base import Callback
from ..dataflow.base import DataFlow

__all__ = ['FeedInput']

def assert_type(v, tp):
    assert isinstance(v, tp), \
    "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"

[docs]class FeedInput(Callback): """ input using feed """ def __init__(self, dataflow, placeholders): assert_type(dataflow, DataFlow) self.dataflow = dataflow if not isinstance(placeholders, list): print(type(placeholders)) placeholders = [placeholders] self.placeholders = placeholders # def _setup_graph(self): # pass def _setup_graph(self): pass # self.dataflow._setup(num_epoch=self.trainer.config.max_epoch) def _before_train(self): self.dataflow.before_read_setup() def _before_inference(self): self._before_train() def _before_run(self, _): cur_batch = self.dataflow.next_batch() # assert len(cur_batch) == len(self.placeholders), \ # "[FeedInput] lenght of input {} is not equal to length of placeholders {}"\ # .format(len(cur_batch), len(self.placeholders)) feed = dict(zip(self.placeholders, cur_batch)) return tf.train.SessionRunArgs(fetches=[], feed_dict=feed) def _after_train(self): self.dataflow.after_reading()