# Source code for tensorcv.train.base
from abc import abstractmethod
import weakref
import os
import tensorflow as tf
from .config import TrainConfig
from ..callbacks.base import Callback
from ..callbacks.group import Callbacks
from ..utils.sesscreate import ReuseSessionCreator
from ..callbacks.monitors import TrainingMonitor, Monitors
# Public API of this module: only the base trainer is exported.
__all__ = ["Trainer"]
def assert_type(v, tp):
    """Check that ``v`` is an instance of ``tp``.

    Args:
        v: value to check.
        tp: a type, or tuple of types, as accepted by ``isinstance``.

    Raises:
        AssertionError: if ``v`` is not an instance of ``tp``. Raised
            explicitly instead of via an ``assert`` statement so the check
            is not stripped when Python runs with ``-O``.
    """
    if not isinstance(v, tp):
        raise AssertionError(
            "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!")
class Trainer(object):
    """Base class for trainers.

    A trainer takes the model and dataflow from a ``TrainConfig``, collects
    callbacks and monitors, builds the graph, creates the session and runs
    the training loop.

    The default ``_run_step`` runs ``self.train_op``, which this base class
    never sets. A subclass must therefore define it (for example in
    ``_setup``) or override ``_run_step``.
    """

    def __init__(self, config):
        """Bind the trainer to a training configuration.

        Args:
            config (TrainConfig): supplies the model, dataflow, callbacks,
                monitors, session creator and checkpoint settings.

        Raises:
            AssertionError: (via ``assert_type``) if ``config`` is not a
                ``TrainConfig``.
        """
        assert_type(config, TrainConfig)
        self._is_load = config.is_load
        self.config = config
        self.model = config.model
        # The model receives a weak proxy of the trainer rather than a strong
        # reference (presumably to avoid a model <-> trainer cycle).
        self.model.ex_init_model(config.dataflow, weakref.proxy(self))
        self.dataflow = config.dataflow
        self._global_step = 0
        # Plain lists until setup() wraps them in Callbacks / Monitors.
        self._callbacks = []
        self.monitors = []
        self.default_dirs = config.default_dirs

    @property
    def epochs_completed(self):
        """Number of completed epochs, as reported by the dataflow."""
        return self.dataflow.epochs_completed

    @property
    def get_global_step(self):
        """Number of training steps run so far by ``main_loop``.

        Note: despite the ``get_`` prefix this is a property; read it as
        ``trainer.get_global_step`` without calling it.
        """
        return self._global_step

    def register_callback(self, cb):
        """Add a callback to be run during training.

        Must be called before ``setup()``, which freezes the callback list.

        Raises:
            AssertionError: if ``cb`` is not a ``Callback`` or the callbacks
                have already been set up.
        """
        assert_type(cb, Callback)
        # Explicit raise (not ``assert``) so the guard survives ``python -O``.
        if isinstance(self._callbacks, Callbacks):
            raise AssertionError("callbacks have been setup")
        self._callbacks.append(cb)

    def register_monitor(self, monitor):
        """Add a training monitor; it is also registered as a callback.

        Must be called before ``setup()``, which freezes the monitor list.

        Raises:
            AssertionError: if ``monitor`` is not a ``TrainingMonitor`` or
                the monitors (or callbacks) have already been set up.
        """
        assert_type(monitor, TrainingMonitor)
        if isinstance(self.monitors, Monitors):
            raise AssertionError("monitors have been setup")
        self.monitors.append(monitor)
        self.register_callback(monitor)

    def _create_session(self):
        """Create ``self.sess`` and the hooked ``self.hooked_sess``.

        ``self.sess`` comes from the configured session creator.
        ``self.hooked_sess`` is a ``tf.train.MonitoredSession`` built with
        ``ReuseSessionCreator(self.sess)``, presumably so both share one
        underlying session, plus the hooks collected from the callbacks.

        Requires ``self._callbacks`` to already be a ``Callbacks`` object,
        because a plain list has no ``get_hooks``. If ``config.is_load`` is
        set, variables are restored from ``model_dir/model_name`` after the
        monitored session has been created.
        """
        hooks = self._callbacks.get_hooks()
        self.sess = self.config.session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
        if self._is_load:
            load_model_path = os.path.join(self.config.model_dir,
                                           self.config.model_name)
            saver = tf.train.Saver()
            saver.restore(self.sess, load_model_path)

    def main_loop(self):
        """Run training steps until the epoch budget is used up.

        The loop is bracketed by the ``before_train`` / ``after_train``
        callbacks. Each iteration:
          * bumps the global step and prints progress;
          * puts the model in training mode;
          * runs one step;
          * fires the per-step callbacks.
        """
        with self.sess.as_default():
            self._callbacks.before_train()
            # NOTE(review): ``<=`` keeps looping while epochs_completed ==
            # max_epoch; if the dataflow counts from 0 this trains
            # max_epoch + 1 epochs -- confirm this is intended.
            while self.epochs_completed <= self.config.max_epoch:
                self._global_step += 1
                print('Epoch: {}. Step: {}'.format(
                    self.epochs_completed, self._global_step))
                # TODO: per-epoch callbacks (before_epoch / after_epoch) are
                # not invoked yet; training-mode handling is to be modified.
                # Training mode is re-set before every step, presumably
                # because step callbacks may switch the model out of it.
                self.model.set_is_training(True)
                self._run_step()
                self._callbacks.trigger_step()
            self._callbacks.after_train()

    def train(self):
        """Build everything with ``setup()``, then run ``main_loop()``."""
        self.setup()
        self.main_loop()

    # NOTE: Trainer does not use ABCMeta, so @abstractmethod is not enforced
    # here; the default implementation below is directly callable.
    @abstractmethod
    def _run_step(self):
        """Run a single training step.

        The default runs ``self.train_op`` through the hooked session, fed
        with ``self.model.get_graph_feed()``. ``self.train_op`` is not set by
        this base class.
        """
        model_feed = self.model.get_graph_feed()
        self.hooked_sess.run(self.train_op, feed_dict=model_feed)

    def setup(self):
        """Prepare for training: graph, callbacks, monitors and session.

        Steps, in order:
          1. Build the graph via ``setup_graph()``.
          2. Register the callbacks and monitors listed in the config.
          3. Freeze them into ``Callbacks`` / ``Monitors``; later
             ``register_*`` calls fail.
          4. Let the callbacks add their graph parts.
          5. Create the sessions.
          6. Finalize the session's graph.
        """
        # Build the model graph (plus any subclass-specific ops).
        self.setup_graph()
        # Collect callbacks and monitors from the config.
        for cb in self.config.callbacks:
            self.register_callback(cb)
        for monitor in self.config.monitors:
            self.register_monitor(monitor)
        # Freeze the registries and let callbacks add their graph parts.
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))
        self.monitors = Monitors(self.monitors)
        # Create the session; finalizing afterwards makes the graph read-only.
        self._create_session()
        self.sess.graph.finalize()

    def setup_graph(self):
        """Create the model graph, run the ``_setup`` hook, add summaries."""
        self.model.create_graph()
        self._setup()
        self.model.setup_summary()

    def _setup(self):
        """Hook for subclasses to add ops once the model graph exists.

        Called between ``model.create_graph()`` and ``model.setup_summary()``.
        The base implementation does nothing.
        """