# Source code for tensorcv.callbacks.hooks
import tensorflow as tf
from .base import Callback
from .inferencer import InferencerBase
from ..predicts.predictions import PredictionBase
# Public API of this module: the three adapters that wrap tensorcv objects
# (callbacks, inferencers, predictions) as tf.train.SessionRunHook instances.
__all__ = ['Callback2Hook', 'Infer2Hook', 'Prediction2Hook']
def assert_type(v, tp):
    """Raise ``AssertionError`` unless ``v`` is an instance of ``tp``.

    Args:
        v: Object to check.
        tp: A class, or a tuple of classes, as accepted by ``isinstance``.

    Raises:
        AssertionError: If ``v`` is not an instance of ``tp``. The error is
            raised explicitly rather than via an ``assert`` statement so the
            check is not stripped when Python runs with ``-O``.
    """
    if not isinstance(v, tp):
        raise AssertionError(
            "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!")
class Callback2Hook(tf.train.SessionRunHook):
    """Adapt a tensorcv ``Callback`` to the ``tf.train.SessionRunHook`` API.

    ``before_run`` and ``after_run`` are forwarded, with the same arguments,
    to the wrapped callback.
    """

    def __init__(self, cb):
        """Wrap ``cb``.

        Args:
            cb: ``Callback`` instance to adapt; stored as ``self.cb``.

        Raises:
            AssertionError: Via ``assert_type`` when ``cb`` is not a
                ``Callback``. This matches the validation done by the sibling
                ``Infer2Hook`` / ``Prediction2Hook`` adapters.
        """
        assert_type(cb, Callback)
        self.cb = cb

    def before_run(self, rct):
        """Return whatever ``self.cb.before_run(rct)`` returns.

        The value is handed back to the session unchanged and is not validated
        here. It presumably is a ``tf.train.SessionRunArgs`` or ``None``, as
        the SessionRunHook contract requires.
        """
        return self.cb.before_run(rct)

    def after_run(self, rct, val):
        """Forward ``rct`` and ``val`` unchanged to ``self.cb.after_run``."""
        self.cb.after_run(rct, val)
class Infer2Hook(tf.train.SessionRunHook):
    """Expose an ``InferencerBase`` as a ``tf.train.SessionRunHook``.

    Before every run, the result of ``inferencer.put_fetch()`` is requested
    as additional fetches. After the run, the complete run-values object is
    passed to ``inferencer.get_fetch``.
    """

    def __init__(self, inferencer):
        # NOTE: the original author flagged this constructor "to be
        # modified"; the intended change is not recorded here.
        assert_type(inferencer, InferencerBase)
        self.inferencer = inferencer

    def before_run(self, rct):
        """Ask the session to also evaluate ``inferencer.put_fetch()``."""
        extra_fetches = self.inferencer.put_fetch()
        return tf.train.SessionRunArgs(fetches=extra_fetches)

    def after_run(self, rct, val):
        """Give the whole ``val`` object (not ``val.results``) to ``get_fetch``."""
        self.inferencer.get_fetch(val)
class Prediction2Hook(tf.train.SessionRunHook):
    """Expose a ``PredictionBase`` as a ``tf.train.SessionRunHook``.

    The tensors returned by ``prediction.get_predictions()`` are fetched on
    every run. Their evaluated values (``val.results``) are then passed to
    ``prediction.after_prediction``.
    """

    def __init__(self, prediction):
        # Fail early (via assert_type) if given something other than a
        # PredictionBase.
        assert_type(prediction, PredictionBase)
        self.prediction = prediction

    def before_run(self, rct):
        """Request ``prediction.get_predictions()`` as this run's fetches."""
        requested = self.prediction.get_predictions()
        return tf.train.SessionRunArgs(fetches=requested)

    def after_run(self, rct, val):
        """Forward the fetched values in ``val.results`` to ``after_prediction``."""
        fetched = val.results
        self.prediction.after_prediction(fetched)