# Source code for tensorcv.callbacks.group
import scipy.misc
import os
import numpy as np
import tensorflow as tf
from .base import Callback
from .hooks import Callback2Hook
# Public API of this module: only the Callbacks group class is exported.
__all__ = ['Callbacks']
def assert_type(v, tp):
    """Check that ``v`` is an instance of ``tp``.

    Args:
        v: Object to check.
        tp: A class, or a tuple of classes, as accepted by ``isinstance``.

    Raises:
        AssertionError: If ``v`` is not an instance of ``tp``. It is raised
            explicitly instead of via an ``assert`` statement, so the check
            still runs when Python is started with ``-O``.
    """
    if not isinstance(v, tp):
        raise AssertionError(
            "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!")
class Callbacks(Callback):
    """Group several callbacks and forward every lifecycle event to each.

    Every handler below calls the corresponding method on each grouped
    callback, in the order the callbacks were given.
    """

    def __init__(self, cbs):
        """
        Args:
            cbs: Iterable of ``Callback`` instances to group.

        Raises:
            AssertionError: If any element is not a ``Callback``.
        """
        # A one-shot iterator (e.g. a generator) would be consumed by the
        # type check below, so every later dispatch loop would see nothing.
        # Materialize only that case; re-iterable containers are stored
        # unchanged so the caller's object is kept.
        if iter(cbs) is cbs:
            cbs = list(cbs)
        for cb in cbs:
            assert_type(cb, Callback)
        self.cbs = cbs

    def _setup_graph(self):
        """Call ``setup_graph`` on each grouped callback."""
        # name_scope(None) presumably resets to the top-level scope so the
        # callbacks' ops are not nested under an enclosing scope -- confirm
        # for the TF version in use.
        with tf.name_scope(None):
            for cb in self.cbs:
                # NOTE(review): self.trainer is assumed to be set by the base
                # class before this runs; it is not assigned in this class.
                cb.setup_graph(self.trainer)

    def get_hooks(self):
        """Return a list with one ``Callback2Hook`` per grouped callback."""
        return [Callback2Hook(cb) for cb in self.cbs]

    def _before_train(self):
        """Call ``before_train`` on each grouped callback."""
        for cb in self.cbs:
            cb.before_train()

    def _before_inference(self):
        """Call ``before_inference`` on each grouped callback."""
        for cb in self.cbs:
            cb.before_inference()

    def _after_train(self):
        """Call ``after_train`` on each grouped callback."""
        for cb in self.cbs:
            cb.after_train()

    def _before_epoch(self):
        """Call ``before_epoch`` on each grouped callback."""
        for cb in self.cbs:
            cb.before_epoch()

    def _after_epoch(self):
        """Call ``after_epoch`` on each grouped callback."""
        for cb in self.cbs:
            cb.after_epoch()

    def _trigger_epoch(self):
        """Call ``trigger_epoch`` on each grouped callback."""
        for cb in self.cbs:
            cb.trigger_epoch()

    def _trigger_step(self):
        """Call ``trigger_step`` on each grouped callback."""
        for cb in self.cbs:
            cb.trigger_step()
# def trigger(self):
# self._trigger()
# def _trigger(self):
# pass
# def before_run(self):