Source code for tensorcv.dataflow.dataset.CIFAR

# File: CIFAR.py
# Author: Qian Ge <geqian1001@gmail.com>

import os
import pickle

import numpy as np 

from ..base import RNGDataFlow

__all__ = ['CIFAR']

## TODO Add batch size
class CIFAR(RNGDataFlow):
    def __init__(self, data_dir='', shuffle=True, normalize=None):
        self.num_channels = 3
        self.im_size = [32, 32]

        assert os.path.isdir(data_dir)
        self.data_dir = data_dir

        self.shuffle = shuffle
        self._normalize = normalize

        self.setup(epoch_val=0, batch_size=1)

        self._file_list = [os.path.join(data_dir, 'data_batch_' + str(batch_id))
                           for batch_id in range(1, 6)]
        # self._load_files()
        self._num_image = self.size()
        self._image_id = 0
        self._batch_file_id = -1
        self._image = []
        self._next_batch_file()

    def _next_batch_file(self):
        # Advance to the next data_batch_* file; wrapping past the last file
        # counts as one completed epoch.
        if self._batch_file_id >= len(self._file_list) - 1:
            self._batch_file_id = 0
            self._epochs_completed += 1
        else:
            self._batch_file_id += 1
        self._image = np.array(unpickle(self._file_list[self._batch_file_id]))
        # TODO to be modified
        if self._normalize == 'tanh':
            # Scale pixel values from [0, 255] to roughly [-1, 1].
            self._image = (self._image * 1. - 128) / 128.0
        if self.shuffle:
            self._shuffle_files()

    def _shuffle_files(self):
        # Shuffle the images of the currently loaded batch file in place.
        idxs = np.arange(len(self._image))
        self.rng.shuffle(idxs)
        self._image = self._image[idxs]
    def size(self):
        try:
            return self.data_size
        except AttributeError:
            # First call: count images across all batch files and cache the total.
            data_size = 0
            for k in range(len(self._file_list)):
                tmp_image = unpickle(self._file_list[k])
                data_size += len(tmp_image)
            self.data_size = data_size
            return self.data_size
    def next_batch(self):
        # TODO assume batch_size smaller than images in one file
        assert self._batch_size <= self.size(), \
            "batch_size {} cannot be larger than data size {}".\
            format(self._batch_size, self.size())

        start = self._image_id
        self._image_id += self._batch_size
        end = self._image_id
        batch_files = np.array(self._image[start:end])

        # If the next batch would run past the end of the current file, move on
        # to the next batch file and start again from its beginning.
        if self._image_id + self._batch_size > len(self._image):
            self._next_batch_file()
            self._image_id = 0
            if self.shuffle:
                self._shuffle_files()

        return [batch_files]
def unpickle(file):
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    image = data_dict[b'data']
    # Each row of b'data' holds 3072 uint8 values: the first 1024 are the red
    # channel, the next 1024 the green, and the last 1024 the blue, each a
    # 32 x 32 plane in row-major order.
    r = image[:, :32 * 32].reshape(-1, 32, 32)
    g = image[:, 32 * 32:2 * 32 * 32].reshape(-1, 32, 32)
    b = image[:, 2 * 32 * 32:].reshape(-1, 32, 32)
    image = np.stack((r, g, b), axis=-1)
    return image


if __name__ == '__main__':
    a = CIFAR('D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\cifar-10-python.tar\\')
    t = a.next_batch()[0]
    print(t)
    print(t.shape)
    print(a.size())

    # print(a.next_batch()[0])
    # print(a.next_batch()[0])
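
# A minimal usage sketch (the directory path and batch size below are
# placeholders, not part of this module): point the dataflow at the extracted
# CIFAR-10 python batch files, choose a batch size via setup(), and draw
# batches until one full pass over the five files has completed.
#
#     cifar = CIFAR(data_dir='/path/to/cifar-10-batches-py',
#                   shuffle=True, normalize='tanh')
#     cifar.setup(epoch_val=0, batch_size=64)
#     while cifar._epochs_completed < 1:
#         images = cifar.next_batch()[0]  # numpy array of shape [64, 32, 32, 3]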