# Source code for tensorcv.train.config

import scipy.misc
import os
import numpy as np

from ..dataflow.base import DataFlow
from ..models.base import ModelDes, GANBaseModel
from ..utils.default import get_default_session_config
from ..utils.sesscreate import NewSessionCreator
from ..callbacks.monitors import TFSummaryWriter
from ..callbacks.summary import TrainSummary
from ..utils.common import check_dir

# Names exported by ``from <this module> import *``: the two configuration
# classes defined below.
__all__ = ['TrainConfig', 'GANTrainConfig']

def assert_type(v, tp):
    """Assert that ``v`` is an instance of ``tp``.

    Args:
        v: The object to check.
        tp: A type (or tuple of types) accepted by ``isinstance``.

    Raises:
        AssertionError: If ``v`` is not an instance of ``tp``. The message
            names the expected type and the class of the object given.
            Being a plain ``assert``, the check is skipped when Python runs
            with assertions disabled (``-O``).
    """
    assert isinstance(v, tp), "".join(
        ["Expect ", str(tp), ", but ", str(v.__class__), " is given!"])

class TrainConfig(object):
    """Validated bundle of the settings used for one training run.

    The constructor type-checks its inputs, forwards the batch size to both
    the dataflow and the model, and assembles the list of callbacks.

    Args:
        dataflow (DataFlow): Data source. Must not be ``None``.
        model (ModelDes): Model description. Must not be ``None``.
        callbacks: A single callback or a list of callbacks. ``None`` (the
            default) means "no callbacks". A list argument is copied, so the
            caller's list is never modified by this class.
        session_creator: Must be ``None``; a ``NewSessionCreator`` built from
            ``get_default_session_config()`` is always used.
        monitors (TFSummaryWriter): A single summary writer. It is stored
            wrapped in a list as ``self.monitors``.
        batch_size (int): Must be > 0. Passed to ``dataflow.set_batch_size``
            and ``model.set_batch_size``.
        max_epoch (int): Must be > 0.
        summary_periodic (int): When an ``int``, a ``TrainSummary`` callback
            keyed on ``model.default_collection`` with this period is
            appended to the callbacks.
        is_load (bool): When true, ``model_name`` and ``default_dirs`` are
            required and ``self.model_name`` / ``self.model_dir`` are set.
        model_name: Required (not ``None``) when ``is_load`` is true.
        default_dirs: Object exposing a ``model_dir`` attribute; stored as
            ``self.default_dirs``.

    Raises:
        AssertionError: On a failed type or value check.
        AttributeError: If ``is_load`` is true and ``default_dirs`` has no
            ``model_dir`` attribute.
        ValueError: If ``session_creator`` is not ``None``.
    """

    def __init__(self, dataflow=None, model=None, callbacks=None,
                 session_creator=None, monitors=None,
                 batch_size=1, max_epoch=100,
                 summary_periodic=None, is_load=False,
                 model_name=None, default_dirs=None):
        self.default_dirs = default_dirs

        # Only a single TFSummaryWriter is supported at this point. (The
        # previous code attached a message via ``assert_type(...), "msg"``,
        # which merely built and discarded a tuple; assert_type already
        # reports the expected and the given type.)
        assert_type(monitors, TFSummaryWriter)
        if not isinstance(monitors, list):
            monitors = [monitors]
        self.monitors = monitors

        assert dataflow is not None, "dataflow cannot be None!"
        assert_type(dataflow, DataFlow)
        self.dataflow = dataflow

        assert model is not None, "model cannot be None!"
        assert_type(model, ModelDes)
        self.model = model

        assert batch_size > 0 and max_epoch > 0
        self.dataflow.set_batch_size(batch_size)
        self.model.set_batch_size(batch_size)
        self.batch_size = batch_size
        self.max_epoch = max_epoch

        self.is_load = is_load
        # NOTE(review): the nesting below was reconstructed from a source
        # whose indentation was lost; confirm that model_name / model_dir
        # are meant to be set only when is_load is true.
        if is_load:
            assert model_name is not None,\
                '[TrainConfig]: model_name cannot be None when is_load is True!'
            self.model_name = model_name
            # Guard only the attribute lookup so that an AttributeError
            # raised inside check_dir is not misreported as a missing
            # model_dir.
            try:
                model_dir = default_dirs.model_dir
            except AttributeError:
                raise AttributeError('model_dir is not set!')
            self.model_dir = os.path.join(model_dir)
            check_dir(self.model_dir)

        # Always work on a private list: appending to the argument itself
        # would leak callbacks into the caller's list, or into a default
        # list shared by every instance.
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, list):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]
        self._callbacks = callbacks

        # TODO model.default_collection only in BaseModel class
        if isinstance(summary_periodic, int):
            self._callbacks.append(
                TrainSummary(key=model.default_collection,
                             periodic=summary_periodic))

        if session_creator is not None:
            raise ValueError(
                'custormer session creator is not allowed at this point!')
        self.session_creator = NewSessionCreator(
            config=get_default_session_config())

    @property
    def callbacks(self):
        """list: The callbacks collected by the constructor."""
        return self._callbacks
class GANTrainConfig(TrainConfig):
    """Training configuration for GAN models.

    Keeps separate callback lists for the discriminator and the generator
    and passes their concatenation (discriminator first) to ``TrainConfig``
    as the overall callback list.

    Args:
        dataflow: Forwarded to ``TrainConfig``.
        model (GANBaseModel): GAN model; must provide ``d_collection`` and
            ``g_collection`` when the matching summary period is given.
        discriminator_callbacks: A callback or list of callbacks for the
            discriminator. ``None`` (the default) means none. Lists are
            copied, so the caller's list is never modified.
        generator_callbacks: Same as above, for the generator.
        session_creator: Forwarded to ``TrainConfig``.
        monitors: Forwarded to ``TrainConfig``.
        batch_size (int): Forwarded to ``TrainConfig``.
        max_epoch (int): Forwarded to ``TrainConfig``.
        summary_d_periodic (int): When an ``int``, a ``TrainSummary`` keyed
            on ``model.d_collection`` is appended to the discriminator
            callbacks.
        summary_g_periodic (int): When an ``int``, a ``TrainSummary`` keyed
            on ``model.g_collection`` is appended to the generator
            callbacks.
        default_dirs: Forwarded to ``TrainConfig``.

    Raises:
        AssertionError: If ``model`` is not a ``GANBaseModel``; further
            errors may come from ``TrainConfig.__init__``.
    """

    def __init__(self, dataflow=None, model=None,
                 discriminator_callbacks=None, generator_callbacks=None,
                 session_creator=None, monitors=None,
                 batch_size=1, max_epoch=100,
                 summary_d_periodic=None, summary_g_periodic=None,
                 default_dirs=None):
        assert_type(model, GANBaseModel)

        self._dis_callbacks = self._copy_as_list(discriminator_callbacks)
        self._gen_callbacks = self._copy_as_list(generator_callbacks)

        if isinstance(summary_d_periodic, int):
            self._dis_callbacks.append(
                TrainSummary(key=model.d_collection,
                             periodic=summary_d_periodic))
        # The generator summary belongs with the generator callbacks; it
        # used to be appended to the discriminator list.
        if isinstance(summary_g_periodic, int):
            self._gen_callbacks.append(
                TrainSummary(key=model.g_collection,
                             periodic=summary_g_periodic))

        # ``+`` builds a new list, so the base class gets its own object.
        callbacks = self._dis_callbacks + self._gen_callbacks

        super(GANTrainConfig, self).__init__(
            dataflow=dataflow, model=model,
            callbacks=callbacks,
            session_creator=session_creator,
            monitors=monitors,
            batch_size=batch_size,
            max_epoch=max_epoch,
            default_dirs=default_dirs)

    @staticmethod
    def _copy_as_list(callbacks):
        """Return ``callbacks`` as a fresh list (``None`` -> empty list)."""
        if callbacks is None:
            return []
        if isinstance(callbacks, list):
            return list(callbacks)
        return [callbacks]

    @property
    def dis_callbacks(self):
        """list: Callbacks associated with the discriminator."""
        return self._dis_callbacks

    @property
    def gen_callbacks(self):
        """list: Callbacks associated with the generator."""
        return self._gen_callbacks