# Source code for tensorcv.callbacks.debug

# File: debug.py
# Author: Qian Ge <geqian1001@gmail.com>

import tensorflow as tf

from .base import Callback
from ..utils.common import get_tensors_by_names

# Names exported by `from ... import *`; assert_type is intentionally not listed.
__all__ = ['CheckScalar']

def assert_type(v, tp):
    """Assert that ``v`` is an instance of ``tp``.

    Args:
        v: The object to check.
        tp: A type, or a tuple of types, as accepted by ``isinstance``.

    Raises:
        AssertionError: If ``v`` is not an instance of ``tp``. This is a
            plain ``assert``, so the whole check is skipped under
            ``python -O``.
    """
    # The message is built only when the check fails (assert semantics).
    assert isinstance(v, tp), "".join(
        ["Expect ", str(tp), ", but ", str(v.__class__), " is given!"])

class CheckScalar(Callback):
    """Print the values of scalar tensors periodically during training.

    Attributes:
        _tensors: Before ``_setup_graph`` runs, the list of tensor names;
            afterwards, whatever ``get_tensors_by_names`` returns for those
            names. This is what ``_before_run`` requests as fetches.
        _names: The list of tensor names, used as labels when printing.
        _periodic: Values are fetched and printed on steps where
            ``global_step % _periodic == 0``.
    """

    def __init__(self, tensors, periodic=1):
        """Create a CheckScalar callback.

        Args:
            tensors (str or list[str] or tuple[str]): A tensor name, or a
                list/tuple of tensor names, whose values will be printed.
            periodic (int): Fetch and print when
                ``global_step % periodic == 0``. Defaults to 1 (every step).

        Raises:
            ValueError: If ``periodic`` is not positive.
        """
        if isinstance(tensors, (list, tuple)):
            # Copy so that later mutation of the caller's sequence does not
            # change which tensors this callback fetches or how they are labeled.
            tensors = list(tensors)
        else:
            tensors = [tensors]
        if periodic <= 0:
            # Fail early instead of raising ZeroDivisionError from
            # ``global_step % periodic`` in the middle of training.
            raise ValueError(
                "periodic must be positive, got {}".format(periodic))
        self._tensors = tensors
        self._names = tensors
        self._periodic = periodic

    def _setup_graph(self):
        # Replace the stored names with the result of get_tensors_by_names;
        # _names keeps the original strings for labeling the printout.
        self._tensors = get_tensors_by_names(self._tensors)

    def _before_run(self, _):
        # NOTE(review): self.global_step is not defined in this class;
        # presumably provided by the Callback base class -- confirm.
        if self.global_step % self._periodic == 0:
            return tf.train.SessionRunArgs(fetches=self._tensors)
        # Request no extra fetches on the other steps.
        return None

    def _after_run(self, _, val):
        # val.results is expected to be None on steps where _before_run
        # requested nothing; only print when values were actually fetched.
        if val.results is not None:
            print([name + ': ' + str(v)
                   for name, v in zip(self._names, val.results)])