Source code for tensorcv.callbacks.summary

import scipy.misc
import os

import numpy as np
import tensorflow as tf

from .base import Callback

__all__ = ['TrainSummary']

class TrainSummary(Callback):
    """Periodically fetch merged training summaries and hand them to monitors.

    Args:
        key (str, list or tuple of str, or None): summary collection key(s)
            whose summaries are merged together. ``None`` (the default)
            uses the default collection of ``tf.summary.merge_all``.
        periodic (int): summaries are only requested on steps where
            ``self.global_step % periodic == 0``.
    """

    def __init__(self, key=None, periodic=1):
        self.periodic = periodic
        if key is not None and not isinstance(key, list):
            # A tuple is treated as a sequence of keys; anything else is a
            # single key.
            key = list(key) if isinstance(key, tuple) else [key]
        self._key = key

    def _setup_graph(self):
        """Build ``self.summary_list``: one merged summary op, or None."""
        if self._key is None:
            # Iterating over None would raise TypeError; fall back to
            # merge_all's default collection instead.
            candidates = [tf.summary.merge_all()]
        else:
            candidates = [tf.summary.merge_all(k) for k in self._key]
        # merge_all returns None for a collection that holds no summaries,
        # and tf.summary.merge cannot accept None entries.
        summaries = [s for s in candidates if s is not None]
        self.summary_list = tf.summary.merge(summaries) if summaries else None

    def _before_run(self, _):
        """Request the merged summary op on every ``periodic``-th step.

        Returns None (no extra fetches) when there is nothing to fetch or
        the current step is not a summary step.
        """
        if self.summary_list is None:
            return None
        # NOTE(review): self.global_step is assumed to be provided by the
        # Callback base class; it is not set in this class.
        if self.global_step % self.periodic != 0:
            return None
        return tf.train.SessionRunArgs(fetches=self.summary_list)

    def _after_run(self, _, val):
        """Forward fetched summary results, if any, to the trainer's monitors."""
        if val.results is not None:
            self.trainer.monitors.process_summary(val.results)