# Source code for tensorcv.train.simple
from abc import abstractmethod
import tensorflow as tf
from .config import TrainConfig, GANTrainConfig
from .base import Trainer
from ..callbacks.inputs import FeedInput
from ..callbacks.group import Callbacks
from ..callbacks.hooks import Callback2Hook
from ..models.base import BaseModel, GANBaseModel
from ..utils.sesscreate import ReuseSessionCreator
# Names exported by ``from <this module> import *``.
# NOTE(review): GANFeedTrainer is defined in this module but is not listed
# here; confirm whether it is meant to be part of the public API.
__all__ = ["SimpleFeedTrainer"]
def assert_type(v, tp):
    """Assert that ``v`` is an instance of ``tp``.

    Args:
        v: the object to check.
        tp: a class, or a tuple of classes, as accepted by ``isinstance``.

    Raises:
        AssertionError: if ``v`` is not an instance of ``tp``. The check is a
            plain ``assert`` statement, so it is skipped under ``python -O``.
    """
    assert isinstance(v, tp), "Expect {}, but {} is given!".format(
        tp, v.__class__)
class SimpleFeedTrainer(Trainer):
    """Trainer that drives a single optimizer.

    Input data from ``self.dataflow`` is fed to the model's training
    placeholders through a ``FeedInput`` callback. The model's gradients are
    applied by the model's own optimizer to build ``self.train_op``.
    """

    def __init__(self, config):
        """Validate the model type and initialise the base ``Trainer``.

        Args:
            config: training configuration; ``config.model`` must be a
                ``BaseModel``.

        Raises:
            AssertionError: if ``config.model`` is not a ``BaseModel``.
        """
        assert_type(config.model, BaseModel)
        super(SimpleFeedTrainer, self).__init__(config)

    def _setup(self):
        """Register the feeding callback and build the ``train`` op."""
        # TODO: to be modified.
        feed_cb = FeedInput(self.dataflow, self.model.get_train_placeholder())
        # Registered as a regular config callback, which mutates
        # self.config.callbacks in place.
        self.config.callbacks.append(feed_cb)

        gradients = self.model.get_grads()
        optimizer = self.model.get_optimizer()
        self.train_op = optimizer.apply_gradients(gradients, name='train')
class GANFeedTrainer(Trainer):
    """Feed-based trainer for GANs with separate discriminator/generator ops.

    Each training step runs one discriminator update. It then runs
    ``gen_steps_per_dis_step`` generator updates.

    Input data from ``self.dataflow`` is fed only into the discriminator's
    hooked session, through a ``FeedInput`` hook. The generator's hooked
    session receives only ``model.get_graph_feed()``.
    """

    # Number of generator updates performed after every discriminator update.
    # Override on a subclass or instance to change the update ratio.
    gen_steps_per_dis_step = 2

    def __init__(self, config):
        """Validate ``config`` and initialise the base ``Trainer``.

        Args:
            config (GANTrainConfig): training configuration. ``_create_session``
                reads its ``dis_callbacks``, ``gen_callbacks`` and
                ``session_creator`` attributes.

        Raises:
            AssertionError: if ``config`` is not a ``GANTrainConfig``.
        """
        assert_type(config, GANTrainConfig)
        # NOTE: config.model is not type-checked against GANBaseModel here.
        super(GANFeedTrainer, self).__init__(config)

    def _setup(self):
        """Build the input-feeding hook and both training ops."""
        # TODO: to be modified.
        # FeedInput only implements before_run, so it is enough to wrap it as
        # a session hook instead of registering it as a config callback.
        feed_cb = FeedInput(self.dataflow, self.model.get_train_placeholder())
        self.feed_input_hook = [Callback2Hook(feed_cb)]

        dis_grads = self.model.get_discriminator_grads()
        dis_opt = self.model.get_discriminator_optimizer()
        self.dis_train_op = dis_opt.apply_gradients(
            dis_grads, name='discriminator_train')

        gen_grads = self.model.get_generator_grads()
        gen_opt = self.model.get_generator_optimizer()
        self.gen_train_op = gen_opt.apply_gradients(
            gen_grads, name='generator_train')

    def _create_session(self):
        """Create one raw session and two hooked sessions built on top of it.

        ``self.dis_hooked_sess`` runs the discriminator callbacks' hooks plus
        the input-feeding hook. ``self.gen_hooked_sess`` runs only the
        generator callbacks' hooks. Both are built with
        ``ReuseSessionCreator(self.sess)``.
        """
        self._dis_callbacks = Callbacks(list(self.config.dis_callbacks))
        self._gen_callbacks = Callbacks(list(self.config.gen_callbacks))
        dis_hooks = self._dis_callbacks.get_hooks()
        gen_hooks = self._gen_callbacks.get_hooks()

        self.sess = self.config.session_creator.create_session()
        self.dis_hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=dis_hooks + self.feed_input_hook)
        self.gen_hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=gen_hooks)

    def _run_step(self):
        """Run one discriminator update, then the generator updates."""
        model_feed = self.model.get_graph_feed()
        self.dis_hooked_sess.run(self.dis_train_op, feed_dict=model_feed)
        for _ in range(self.gen_steps_per_dis_step):
            # A fresh feed is requested before every generator update.
            model_feed = self.model.get_graph_feed()
            self.gen_hooked_sess.run(self.gen_train_op, feed_dict=model_feed)