Source code for tensorcv.callbacks.base

import scipy.misc
import os
from abc import ABCMeta

import numpy as np
import tensorflow as tf

__all__ = ['Callback', 'ProxyCallback']

def assert_type(v, tp):
    assert isinstance(v, tp),\
     "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"

[docs]class Callback(object): """ base class for callbacks """
[docs] def setup_graph(self, trainer): self.trainer = trainer self._setup_graph()
@property def global_step(self): return self.trainer.get_global_step @property def epochs_completed(self): return self.trainer.epochs_completed def _setup_graph(self): pass
[docs] def before_run(self, rct): fetch = self._before_run(rct) if fetch is None: return None assert_type(fetch, tf.train.SessionRunArgs) return fetch
def _before_run(self, rct): return None
[docs] def after_run(self, rct, val): self._after_run(rct, val)
def _after_run(self, rct, val): pass
[docs] def before_train(self): self._before_train()
def _before_train(self): pass
[docs] def before_inference(self): self._before_inference()
def _before_inference(self): pass
[docs] def after_train(self): self._after_train()
def _after_train(self): pass
[docs] def before_epoch(self): self._before_epoch()
def _before_epoch(self): pass
[docs] def after_epoch(self): self._after_epoch()
def _after_epoch(self): pass
[docs] def trigger_epoch(self): self._trigger_epoch()
def _trigger_epoch(self): self.trigger()
[docs] def trigger_step(self): self._trigger_step()
def _trigger_step(self): pass
[docs] def trigger(self): self._trigger()
def _trigger(self): pass
# def before_run(self):
[docs]class ProxyCallback(Callback): def __init__(self, cb): assert_type(cb, Callback) self.cb = cb def __str__(self): return "Proxy-" + str(self.cb) def _before_train(self): self.cb.before_train() def _before_inference(self): self.cb.before_inference() def _setup_graph(self): with tf.name_scope(None): self.cb.setup_graph(self.trainer) def _trigger_epoch(self): self.cb.trigger_epoch() def _trigger(self): self.cb.trigger() def _trigger_step(self): self.cb.trigger_step() def _after_train(self): self.cb.after_train() def _before_epoch(self): self.cb.before_epoch() def _after_epoch(self): self.cb.after_epoch() def _before_run(self, crt): self.cb.before_run(crt) def _after_run(self, crt, val): self.cb.after_run(crt, val)