# Source code for tensorcv.dataflow.dataset.MNIST
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: MNIST.py
# Author: Qian Ge <geqian1001@gmail.com>
import os
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from ..base import RNGDataFlow
__all__ = ['MNIST', 'MNISTLabel']
def get_mnist_im_label(name, mnist_data):
    """Pick the images and labels of one split out of a loaded dataset.

    Args:
        name: ``'train'`` selects ``mnist_data.train`` and ``'val'`` selects
            ``mnist_data.validation``; every other value selects
            ``mnist_data.test``.
        mnist_data: Object exposing ``train``, ``validation`` and ``test``
            attributes, each with ``images`` and ``labels``.

    Returns:
        tuple: ``(images, labels)`` of the selected split.
    """
    if name == 'train':
        split = mnist_data.train
    elif name == 'val':
        split = mnist_data.validation
    else:
        # Anything that is not 'train' or 'val' falls through to the test set.
        split = mnist_data.test
    return split.images, split.labels
# TODO: read the MNIST data without depending on tensorflow
class MNIST(RNGDataFlow):
    """Dataflow that serves MNIST images in batches.

    The requested split is read through ``input_data.read_data_sets`` and kept
    in memory as a single array with a trailing channel axis. ``next_batch``
    hands out consecutive slices of that array.

    Args:
        name (str): Split to load; one of ``'train'``, ``'val'``, ``'test'``.
        data_dir (str): Existing directory handed to ``read_data_sets``. The
            default ``''`` does not pass the directory check, so a real path
            has to be supplied.
        shuffle (bool): If True, the samples are permuted once after loading
            and again each time an epoch is completed.
        normalize (str): ``'tanh'`` maps every image through ``x * 2 - 1``;
            any other value leaves the pixel values as loaded.
    """

    def __init__(self, name, data_dir='', shuffle=True, normalize=None):
        self.num_channels = 1
        self.im_size = [28, 28]

        assert os.path.isdir(data_dir), \
            "data_dir {!r} is not an existing directory".format(data_dir)
        self.data_dir = data_dir

        self.shuffle = shuffle
        self._normalize = normalize

        assert name in ['train', 'test', 'val'], \
            "name must be 'train', 'test' or 'val', got {!r}".format(name)

        # NOTE(review): setup() comes from RNGDataFlow; it presumably
        # initialises the batch size, epoch counter and rng used below --
        # confirm against the base class.
        self.setup(epoch_val=0, batch_size=1)

        self._load_files(name)
        self._num_image = self.size()
        self._image_id = 0

    def _load_files(self, name):
        """Load the split ``name`` into ``im_list`` and ``label_list``."""
        mnist_data = input_data.read_data_sets(self.data_dir, one_hot=False)
        mnist_images, mnist_labels = get_mnist_im_label(name, mnist_data)

        # Work on copies so the arrays owned by the dataset object are never
        # aliased by this dataflow.
        images = np.array(mnist_images)
        # TODO to be modified
        if self._normalize == 'tanh':
            images = images * 2. - 1.

        # Reshape the whole split at once instead of image by image.
        sample_shape = list(self.im_size) + [self.num_channels]
        self.im_list = np.reshape(images, [len(images)] + sample_shape)
        self.label_list = np.array(mnist_labels)

        if self.shuffle:
            self._suffle_files()

    def _suffle_files(self):
        """Apply one random permutation to both images and labels.

        The (misspelled) name is kept because subclasses call it.
        """
        idxs = np.arange(self.size())
        self.rng.shuffle(idxs)
        # Indexing with the same permutation keeps image/label pairs aligned.
        self.im_list = self.im_list[idxs]
        self.label_list = self.label_list[idxs]

    def size(self):
        """Return the number of loaded samples."""
        return self.im_list.shape[0]

    def next_batch(self):
        """Return the next batch of images.

        Returns:
            list: One-element list holding the slice of ``im_list`` that
            starts at the current cursor and spans ``_batch_size`` samples.

        Raises:
            AssertionError: If the batch size exceeds the number of samples.
        """
        assert self._batch_size <= self.size(), \
            "batch_size {} cannot be larger than data size {}".\
            format(self._batch_size, self.size())

        start = self._image_id
        end = start + self._batch_size
        self._image_id = end
        batch_files = self.im_list[start:end]

        # When the samples left cannot fill another full batch, the epoch is
        # closed here and the cursor rewinds; leftover samples are skipped.
        if self._image_id + self._batch_size > self._num_image:
            self._epochs_completed += 1
            self._image_id = 0
            if self.shuffle:
                self._suffle_files()
        return [batch_files]
class MNISTLabel(MNIST):
    """MNIST dataflow whose batches carry the labels alongside the images."""

    def next_batch(self):
        """Return the next batch of images together with their labels.

        Returns:
            list: ``[batch_im, batch_label]``, both cut from the same index
            range of ``im_list`` and ``label_list``.

        Raises:
            AssertionError: If the batch size exceeds the number of samples.
        """
        # Slice the labels before delegating: the parent call moves the cursor
        # and may reshuffle at the end of an epoch. A reshuffle rebinds
        # ``label_list`` to a new array, so the slice taken here still matches
        # the images the parent returns.
        start = self._image_id
        batch_label = self.label_list[start:start + self._batch_size]

        # Cursor, epoch counting and shuffling live in MNIST.next_batch only.
        batch_im, = super(MNISTLabel, self).next_batch()
        return [batch_im, batch_label]
if __name__ == '__main__':
    # Manual smoke test: load the validation split from a local folder and
    # print a single batch.
    dataflow = MNISTLabel(
        'val',
        'D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\MNIST_data\\')
    batch = dataflow.next_batch()
    print(batch)