# Custom 3D image data generator for segmentation
import glob
import os

import keras
import numpy as np
import skimage.io  # explicit submodule import; `import skimage` alone may not expose `skimage.io`
from imgaug import augmenters as iaa


class DataGenerator(keras.utils.Sequence):
    """Generate batches of 3D images and one-hot label volumes for Keras.

    Built on ``keras.utils.Sequence`` so that each sample is used at most once
    per epoch: ``__len__`` fixes the number of batches per epoch and
    ``on_epoch_end`` rebuilds (and optionally reshuffles) the sample order.

    Images and labels are paired by file name: for each ID the input is read
    from ``os.path.join(im_path, ID)`` and the label volume from
    ``os.path.join(label_path, ID)``.
    """

    def __init__(self, list_IDs, im_path, label_path, batch_size=4,
                 dim=(128, 128, 128), n_classes=4, shuffle=True,
                 augment=False):
        """Store the configuration and build the initial sample order.

        Args:
            list_IDs: File names (relative to ``im_path`` / ``label_path``)
                of the samples to serve.
            im_path: Directory containing the input images.
            label_path: Directory containing the label images.
            batch_size: Number of samples per batch.
            dim: Spatial shape each image/label file is expected to decode to.
            n_classes: Number of label classes; labels are one-hot encoded
                into this many channels.
            shuffle: Whether to shuffle the sample order every epoch.
            augment: Whether to apply random flips (see ``_augment``).
        """
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.im_path = im_path
        self.label_path = label_path
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augment = augment
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        Samples that would only form a partial final batch are not served in
        this epoch.
        """
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, Y)`` tuple."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Rebuild the sample order after each epoch, shuffling if enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load one batch of images and one-hot encoded labels.

        Args:
            list_IDs_temp: The ``batch_size`` file names making up the batch.

        Returns:
            ``X`` with shape ``(batch_size, *dim, 1)`` and ``Y`` with shape
            ``(batch_size, *dim, n_classes)``, both float32.
        """
        X = np.empty((self.batch_size, *self.dim), dtype=np.float32)
        Y = np.empty((self.batch_size, *self.dim, self.n_classes),
                     dtype=np.float32)

        for i, ID in enumerate(list_IDs_temp):
            # Use this instance's directories rather than module-level globals.
            # NOTE(review): assumes every file decodes to an array of shape
            # `dim` and that label values are integers in [0, n_classes).
            X[i] = skimage.io.imread(os.path.join(self.im_path, ID))
            label = skimage.io.imread(os.path.join(self.label_path, ID))
            Y[i] = keras.utils.to_categorical(label, num_classes=self.n_classes)

        # Previously `augment=True` skipped loading entirely and crashed on
        # return; now data is always loaded and augmentation is applied on top.
        if self.augment:
            X, Y = self._augment(X, Y)

        # Append a single-channel axis for the network input.
        return X[..., np.newaxis], Y

    def _augment(self, X, Y):
        """Randomly flip each sample along its spatial axes, in place.

        Each spatial axis of each sample is flipped with probability 0.5.
        The same flip is applied to the image and its label volume so they
        stay voxel-aligned.

        Returns:
            The (modified) ``X`` and ``Y`` arrays.
        """
        n_spatial = len(self.dim)
        for i in range(X.shape[0]):
            axes = tuple(ax for ax in range(n_spatial)
                         if np.random.rand() < 0.5)
            if axes:
                # .copy() so a reversed view of X[i] is never written onto itself.
                X[i] = np.flip(X[i], axis=axes).copy()
                Y[i] = np.flip(Y[i], axis=axes).copy()
        return X, Y


im_path = "some/path/for/the/images/"
label_path = "some/path/for/the/label_images/"

# Reuse the path variables so the generator and the glob below cannot drift
# apart (os.path.join makes the trailing slash irrelevant).
params = {'dim': (128, 128, 128),
          'batch_size': 4,
          'im_path': im_path,
          'label_path': label_path,
          'n_classes': 4,
          'shuffle': True,
          'augment': True}

images = glob.glob(os.path.join(im_path, "*.tif"))
# os.path.basename is portable across path separators, unlike split("/").
images_IDs = [os.path.basename(name) for name in images]

partition = {'train': images_IDs}
training_generator = DataGenerator(partition['train'], **params)