# Source code for tensorcv.dataflow.matlab

# File: matlab.py
# Author: Qian Ge <geqian1001@gmail.com>

import os
from scipy.io import loadmat

import numpy as np 

from .base import RNGDataFlow
from .common import *

__all__ = ['MatlabData']

class MatlabData(RNGDataFlow):
    """Dataflow that serves batches of arrays read from .mat files.

    Every file ending in ``.mat`` directly inside ``data_dir`` is one sample.
    For each sample the variables named in ``mat_name_list`` are read; the
    first one is treated as the image (reshaped to ``[H, W, num_channels]``),
    every following one as a 2-D map such as a mask (reshaped to ``[H, W]``).

    Args:
        data_dir (str): Existing directory containing the .mat files.
        mat_name_list (str or list): Variable name(s) to read from each file.
            A non-list value is wrapped into a one-element list.
        mat_type_list (list): dtype for each entry of ``mat_name_list``.
            Defaults to ``'float'`` for every variable.
        shuffle (bool): Shuffle the file list on load and after every epoch.
        normalize (str): If ``'tanh'``, the first array of every batch is
            passed through ``tanh_normalization``. Any other value disables
            normalization.

    Raises:
        AssertionError: If ``data_dir`` is not a directory, ``mat_name_list``
            is None, the two lists differ in length, or no .mat file is found.
    """

    def __init__(self, data_dir='', mat_name_list=None, mat_type_list=None,
                 shuffle=True, normalize=None):
        # setup() is inherited from RNGDataFlow; presumably it initialises
        # the epoch counter and self._batch_size -- confirm against the base.
        self.setup(epoch_val=0, batch_size=1)

        self.shuffle = shuffle
        self._normalize = normalize

        assert os.path.isdir(data_dir), \
            '{} is not a directory'.format(data_dir)
        self.data_dir = data_dir

        assert mat_name_list is not None, 'mat_name_list cannot be None'
        if not isinstance(mat_name_list, list):
            mat_name_list = [mat_name_list]
        self._mat_name_list = mat_name_list
        if mat_type_list is None:
            mat_type_list = ['float'] * len(self._mat_name_list)
        assert len(self._mat_name_list) == len(mat_type_list),\
            'Length of mat_name_list and mat_type_list has to be the same!'
        self._mat_type_list = mat_type_list

        self._load_file_list()
        self._get_im_size()
        self._num_image = self.size()
        self._image_id = 0

    def _get_im_size(self):
        """Record ``im_size`` and ``num_channels`` from the first file.

        Must run after ``_load_file_list``. All images are assumed to have
        the same size as the first one.
        """
        mat = loadmat(self.file_list[0])
        first_im = load_image_from_mat(mat, self._mat_name_list[0],
                                       self._mat_type_list[0])
        # A 2-D array is a single-channel image.
        self.num_channels = 1 if first_im.ndim < 3 else first_im.shape[2]
        self.im_size = [first_im.shape[0], first_im.shape[1]]

    def _load_file_list(self):
        """Collect the paths of all .mat files in ``data_dir``."""
        # os.listdir returns entries in arbitrary order; sort so that the
        # unshuffled order is deterministic across runs and filesystems.
        mat_paths = sorted(
            os.path.join(self.data_dir, fname)
            for fname in os.listdir(self.data_dir)
            if fname.endswith(".mat"))
        assert mat_paths, \
            'No .mat files found in {}'.format(self.data_dir)
        self.file_list = np.array(mat_paths)

        if self.shuffle:
            # Method name (including its spelling) is defined by the base.
            self._suffle_file_list()

    def next_batch(self):
        """Return the next batch as a list with one array per variable.

        The first array has shape ``[batch, H, W, num_channels]``, the
        others ``[batch, H, W]``. When the files left cannot fill another
        full batch, the epoch counter is incremented, reading restarts from
        the first file and the list is reshuffled if ``shuffle`` is set;
        left-over files are therefore skipped for that epoch.
        """
        assert self._batch_size <= self.size(), \
            "batch_size cannot be larger than data size"

        start = self._image_id
        self._image_id += self._batch_size
        end = self._image_id
        batch_file_path = self.file_list[start:end]

        if self._image_id + self._batch_size > self._num_image:
            self._epochs_completed += 1
            self._image_id = 0
            if self.shuffle:
                self._suffle_file_list()
        return self._load_data(batch_file_path)

    def _load_data(self, batch_file_path):
        """Load every variable of every file in ``batch_file_path``.

        Returns:
            list: One numpy array per entry of ``mat_name_list``, stacked
            along a new leading batch axis.
        """
        # TODO deal with num_channels
        input_data = [[] for _ in self._mat_name_list]
        var_specs = list(zip(self._mat_name_list, self._mat_type_list))

        for file_path in batch_file_path:
            mat = loadmat(file_path)
            for k, (var_name, var_type) in enumerate(var_specs):
                cur_data = load_image_from_mat(mat, var_name, var_type)
                height, width = cur_data.shape[0], cur_data.shape[1]
                if k == 0:
                    # The image keeps an explicit channel axis.
                    sample_shape = [height, width, self.num_channels]
                else:
                    # Remaining variables are stored as 2-D maps.
                    sample_shape = [height, width]
                input_data[k].append(np.reshape(cur_data, sample_shape))

        input_data = [np.array(data) for data in input_data]

        if self._normalize == 'tanh':
            # The value range is computed once, lazily, from the first image
            # ever loaded. An explicit check is used instead of catching
            # AttributeError so that errors raised inside tanh_normalization
            # are not silently retried.
            if not hasattr(self, '_half_in_val'):
                self._input_val_range(input_data[0][0])
            # tanh_normalization is expected to come from the `.common`
            # star import.
            input_data[0] = tanh_normalization(input_data[0],
                                               self._half_in_val)

        return input_data

    def _input_val_range(self, in_mat):
        """Cache the value range of ``in_mat`` used for normalization."""
        # TODO to be modified
        # input_val_range is expected to come from the `.common` star import.
        self._max_in_val, self._half_in_val = input_val_range(in_mat)

    def size(self):
        """Return the number of .mat files in the dataflow."""
        return len(self.file_list)
def load_image_from_mat(matfile, name, datatype):
    """Return variable ``name`` of a loaded .mat file cast to ``datatype``.

    Args:
        matfile: Mapping from variable name to array, e.g. the dictionary
            returned by ``scipy.io.loadmat``.
        name (str): Name of the variable to read.
        datatype: Anything accepted by ``ndarray.astype``.

    Returns:
        The array converted with ``astype(datatype)``.

    Raises:
        KeyError: If ``name`` is not a variable of ``matfile``.
    """
    try:
        raw = matfile[name]
    except KeyError:
        # Keep the exception type but say what was asked for and what exists.
        raise KeyError('Variable {!r} not found in .mat file; available: {}'
                       .format(name, sorted(str(key) for key in matfile)))
    return raw.astype(datatype)


if __name__ == '__main__':
    import sys

    # Small smoke test. The directory may be given as the first command-line
    # argument; the original author's path remains the default.
    demo_dir = (sys.argv[1] if len(sys.argv) > 1 else
                'D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_GAN_ORIGINAL_64\\')
    a = MatlabData(data_dir=demo_dir,
                   mat_name_list=['level1Edge'],
                   normalize='tanh')
    print(a.next_batch()[0].shape)
    print(a.next_batch()[0][:, 30:40, 30:40, :])
    print(np.amax(a.next_batch()[0]))