Source code for tensorcv.callbacks.hooks

import tensorflow as tf

from .base import Callback
from .inferencer import InferencerBase
from ..predicts.predictions import PredictionBase

# Names exported by a star-import of this module; the assert_type helper
# defined in this file is not listed here.
__all__ = ['Callback2Hook', 'Infer2Hook', 'Prediction2Hook']

def assert_type(v, tp):
    """Check that ``v`` is an instance of ``tp``.

    Args:
        v: Object to check.
        tp: A type, or tuple of types, accepted by ``isinstance``.

    Raises:
        AssertionError: If ``v`` is not an instance of ``tp``. The message
            names the expected type and the class of the object received.
    """
    if isinstance(v, tp):
        return
    # Raised explicitly instead of through the ``assert`` statement so the
    # check is not stripped when Python runs with optimizations (-O).
    raise AssertionError(
        "Expect {}, but {} is given!".format(tp, v.__class__))

class Callback2Hook(tf.train.SessionRunHook):
    """Adapter that exposes a callback object as a ``SessionRunHook``.

    Both hook methods simply forward to the same-named methods of the
    wrapped callback.
    """

    def __init__(self, cb):
        """Keep a reference to the callback that will receive the calls.

        Args:
            cb: Object providing ``before_run(rct)`` and
                ``after_run(rct, val)``.
        """
        self.cb = cb

    def before_run(self, rct):
        """Forward to ``cb.before_run`` and return whatever it returns."""
        requested = self.cb.before_run(rct)
        return requested

    def after_run(self, rct, val):
        """Forward to ``cb.after_run``; its result is discarded."""
        self.cb.after_run(rct, val)
class Infer2Hook(tf.train.SessionRunHook):
    """``SessionRunHook`` driven by an ``InferencerBase`` instance.

    ``before_run`` requests the inferencer's ``put_fetch()`` result as
    fetches; ``after_run`` passes the run values to ``get_fetch``.
    """

    def __init__(self, inferencer):
        """Validate and store the inferencer.

        Args:
            inferencer: Checked against ``InferencerBase`` via
                ``assert_type`` before being stored.
        """
        # to be modified
        assert_type(inferencer, InferencerBase)
        self.inferencer = inferencer

    def before_run(self, rct):
        """Build the run arguments from the inferencer's fetches."""
        wanted = self.inferencer.put_fetch()
        return tf.train.SessionRunArgs(fetches=wanted)

    def after_run(self, rct, val):
        """Pass the whole ``val`` object to ``inferencer.get_fetch``."""
        self.inferencer.get_fetch(val)
class Prediction2Hook(tf.train.SessionRunHook):
    """``SessionRunHook`` driven by a ``PredictionBase`` instance.

    ``before_run`` requests the prediction's ``get_predictions()`` result
    as fetches; ``after_run`` passes ``val.results`` to
    ``after_prediction``.
    """

    def __init__(self, prediction):
        """Validate and store the prediction object.

        Args:
            prediction: Checked against ``PredictionBase`` via
                ``assert_type`` before being stored.
        """
        assert_type(prediction, PredictionBase)
        self.prediction = prediction

    def before_run(self, rct):
        """Build the run arguments from the prediction's fetches."""
        wanted = self.prediction.get_predictions()
        return tf.train.SessionRunArgs(fetches=wanted)

    def after_run(self, rct, val):
        """Pass only the ``results`` attribute of ``val`` to the prediction."""
        fetched = val.results
        self.prediction.after_prediction(fetched)