Source code for tensorcv.dataflow.base

from abc import abstractmethod, ABCMeta
import numpy as np 

from ..utils.utils import get_rng

__all__ = ['DataFlow', 'RNGDataFlow']

# @six.add_metaclass(ABCMeta)
class DataFlow(object):
    """Base class for dataflow.

    Holds the bookkeeping shared by dataflows visible in this class: an
    epoch counter (``_epochs_completed``) and a batch size (``_batch_size``).
    Subclasses customise behaviour by overriding the hooks ``_setup``,
    ``_reset_state``, ``next_batch``, ``next_batch_dict`` and ``size``.

    Note:
        ``next_batch`` is decorated with ``@abstractmethod``, but this class
        does not use ``ABCMeta`` as its metaclass, so Python does not
        enforce the override and ``DataFlow`` itself can be instantiated.
    """

    # Default value so that ``epochs_completed`` can be read before
    # ``setup`` / ``reset_epochs_completed`` has been called, instead of
    # raising AttributeError.
    _epochs_completed = 0

    def before_read_setup(self, **kwargs):
        """Do nothing by default; subclasses may override.

        NOTE(review): presumably invoked before reading starts -- confirm
        against callers.
        """
        pass

    def setup(self, epoch_val, batch_size, **kwargs):
        """Initialise the dataflow.

        Sets the epoch counter, stores the batch size, resets the state and
        finally calls the ``_setup`` hook, in that order.

        Args:
            epoch_val: Value assigned to the completed-epochs counter.
            batch_size: Value stored as the batch size.
            **kwargs: Accepted for interface compatibility. They are not
                forwarded to ``_setup``.
        """
        self.reset_epochs_completed(epoch_val)
        self.set_batch_size(batch_size)
        self.reset_state()
        self._setup()

    def _setup(self, **kwargs):
        """Hook called at the end of ``setup``; does nothing by default."""
        pass

    def _get_im_size(self):
        """Return 0 in the base class; subclasses may override."""
        return 0

    @property
    def epochs_completed(self):
        """Current value of the completed-epochs counter (0 until set)."""
        return self._epochs_completed

    def reset_epochs_completed(self, val):
        """Set the completed-epochs counter to ``val``."""
        self._epochs_completed = val

    @abstractmethod
    def next_batch(self):
        """Return the next batch.

        The base implementation returns ``None``; subclasses are expected
        to override it.
        """
        return

    def next_batch_dict(self):
        """Print a reminder that this method has to be implemented.

        The base implementation only prints a message and returns ``None``.
        """
        print('Need to be implemented!')

    def set_batch_size(self, batch_size):
        """Store ``batch_size`` on the instance."""
        self._batch_size = batch_size

    def size(self):
        """Return the size of the dataflow.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def reset_state(self):
        """Reset the dataflow state by delegating to ``_reset_state``."""
        self._reset_state()

    def _reset_state(self):
        """Hook called by ``reset_state``; does nothing by default."""
        pass

    def after_reading(self):
        """Do nothing by default; subclasses may override.

        NOTE(review): presumably invoked once reading is finished --
        confirm against callers.
        """
        pass
class RNGDataFlow(DataFlow):
    """Dataflow that owns a random number generator for shuffling.

    ``_reset_state`` stores the generator returned by ``get_rng(self)`` in
    ``self.rng``. Shuffling reads ``self.size()`` and ``self.file_list``,
    neither of which is defined here, so subclasses must provide them.
    """

    def _reset_state(self):
        """Create a fresh generator and store it as ``self.rng``."""
        self.rng = get_rng(self)

    def _suffle_file_list(self):
        """Reorder ``self.file_list`` using a random permutation.

        A permutation of ``range(self.size())`` is drawn with
        ``self.rng.shuffle`` and applied to ``self.file_list``. ``self.rng``
        must already exist (it is set by ``_reset_state``).
        """
        idxs = np.arange(self.size())
        self.rng.shuffle(idxs)
        try:
            # Containers supporting index-array selection (e.g. numpy
            # arrays) are reordered exactly as before.
            self.file_list = self.file_list[idxs]
        except TypeError:
            # Plain sequences such as list/tuple cannot be indexed with an
            # index array; reorder them element by element instead of
            # failing. The result is a list.
            self.file_list = [self.file_list[i] for i in idxs]

    def suffle_data(self):
        """Shuffle the data by shuffling ``self.file_list``."""
        self._suffle_file_list()