Source code for tensorcv.callbacks.saver

import scipy.misc
import os
import os

import numpy as np
import tensorflow as tf

from .base import Callback
from ..utils.common import check_dir

__all__ = ['ModelSaver']

[docs]class ModelSaver(Callback): def __init__(self, max_to_keep=5, keep_checkpoint_every_n_hours=0.5, periodic=1, checkpoint_dir=None, var_collections=tf.GraphKeys.GLOBAL_VARIABLES): self._periodic = periodic self._max_to_keep = max_to_keep self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours if not isinstance(var_collections, list): var_collections = [var_collections] self.var_collections = var_collections def _setup_graph(self): try: checkpoint_dir = os.path.join(self.trainer.default_dirs.checkpoint_dir) check_dir(checkpoint_dir) except AttributeError: raise AttributeError('checkpoint_dir is not set in config_path!') self._save_path = os.path.join(checkpoint_dir, 'model') self._saver = tf.train.Saver() def _trigger_step(self): if self.global_step % self._periodic == 0: self._saver.save(tf.get_default_session(), self._save_path, global_step = self.global_step)