Source code for tensorcv.dataflow.randoms
# File: randoms.py
# Author: Qian Ge <geqian1001@gmail.com>
import numpy as np
from .base import DataFlow
from ..utils.utils import get_rng
__all__ = ['RandomVec']
[docs]class RandomVec(DataFlow):
""" random vector input """
def __init__(self,
len_vec=100):
self.setup(epoch_val=0, batch_size=1)
self._len_vec = len_vec
[docs] def next_batch(self):
self._epochs_completed += 1
return [np.random.normal(size=(self._batch_size, self._len_vec))]
[docs] def size(self):
return self._batch_size
[docs] def reset_state(self):
self._reset_state()
def _reset_state(self):
self.rng = get_rng(self)
if __name__ == '__main__':
vec = RandomVec()
print(vec.next_batch())
print(vec.next_batch())