import glob
import os
import random

import numpy as np
import SimpleITK as sitk
import torch
import torchvision.transforms as deftfx
from torch.utils.data import Dataset

from . import image_transforms as myit
from .specifics import *


class TestDataset(Dataset):
    """Evaluation dataset for few-shot segmentation.

    Picks one volume at random as the support and serves every remaining
    volume as a query.

    NOTE(review): ``self.label`` starts as ``None`` and must be set by the
    caller (to the foreground class id) before ``__getitem__`` is used,
    otherwise the returned mask is all zeros.
    """

    def __init__(self, args):
        # Collect image volumes and sort them by their numeric suffix
        # (paths look like ".../images/image_<idx>.nii.gz").
        self.image_dirs = glob.glob(os.path.join(args.data_root, 'images/image*'))
        self.image_dirs = sorted(
            self.image_dirs,
            key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
        # Keep only the first 1/60 of the volumes -- presumably a debugging /
        # speed-up subsample; TODO(review): confirm this truncation is intended.
        self.image_dirs = self.image_dirs[:len(self.image_dirs) // 60]

        # Fold-based split kept for reference (disabled in this version):
        # self.FOLD = get_folds(args.dataset)
        # self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs)
        #                    if idx in self.FOLD[args.fold]]

        # Split into support/query: one random support volume, rest are queries.
        # BUGFIX: random.randint(0, n) is inclusive on BOTH ends, so the old
        # code could return n == len(self.image_dirs) and crash with an
        # IndexError on the lookup below.  randrange's upper bound is
        # exclusive, which is what we want here.
        self.support_index = random.randrange(len(self.image_dirs))
        self.support_dir = self.image_dirs[self.support_index]
        # Remove the support volume from the query list.
        self.image_dirs = [self.image_dirs[i]
                           for i in range(len(self.image_dirs))
                           if i != self.support_index]
        # Foreground class id; must be assigned externally before iteration.
        self.label = None

    def __len__(self):
        return len(self.image_dirs)

    def __getitem__(self, idx):
        """Return one query volume as {'id', 'image', 'label'}.

        'image' is z-score normalized and replicated to 3 channels
        (assumes the volume is (D, H, W) -> result (D, 3, H, W) -- TODO
        confirm); 'label' is a binary mask for ``self.label``.
        """
        img_path = self.image_dirs[idx]
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        # Z-score normalize, then stack to 3 identical channels.
        img = (img - img.mean()) / img.std()
        img = np.stack(3 * [img], axis=1)

        # Derive the label path from the image path:
        # ".../image_<idx>.nii.gz" -> ".../label_<idx>.nii.gz".
        # NOTE(review): this keeps the label in the same directory as the
        # image -- verify against the on-disk layout.
        lbl = sitk.GetArrayFromImage(
            sitk.ReadImage(img_path.split('image_')[0] + 'label_'
                           + img_path.split('image_')[-1]))
        # Raw-value remapping kept for reference (disabled):
        # lbl[lbl == 200] = 1
        # lbl[lbl == 500] = 2
        # lbl[lbl == 600] = 3

        # Binarize: 1 where the voxel belongs to the selected class.
        lbl = 1 * (lbl == self.label)

        sample = {'id': img_path}
        sample['image'] = torch.from_numpy(img)
        sample['label'] = torch.from_numpy(lbl)
        return sample

    def get_support_index(self, n_shot, C):
        """Return ``n_shot`` slice indices spread evenly across ``C`` slices.

        For n_shot == 1 this is the middle slice; otherwise the indices are
        centered within n_shot equal partitions of [0, C).
        """
        if n_shot == 1:
            pcts = [0.5]
        else:
            half_part = 1 / (n_shot * 2)
            part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1)
            pcts = [half_part + part_interval * ii for ii in range(n_shot)]
        return (np.array(pcts) * C).astype('int')

    def getSupport(self, label=None, all_slices=True, N=None):
        """Load the support volume for class ``label``.

        Args:
            label: foreground class id (required).
            all_slices: if True, return the whole volume; otherwise return
                only ``N`` evenly spaced slices among those containing the
                class.
            N: number of labeled slices to keep when ``all_slices`` is False.

        Returns:
            dict with 'image' and 'label' tensors.

        Raises:
            ValueError: if ``label`` is None, or ``all_slices`` is False and
                ``N`` is None.
        """
        if label is None:
            raise ValueError('Need to specify label class!')

        img_path = self.support_dir
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        img = (img - img.mean()) / img.std()
        img = np.stack(3 * [img], axis=1)

        lbl = sitk.GetArrayFromImage(
            sitk.ReadImage(img_path.split('image_')[0] + 'label_'
                           + img_path.split('image_')[-1]))
        # lbl[lbl == 200] = 1
        # lbl[lbl == 500] = 2
        # lbl[lbl == 600] = 3
        lbl = 1 * (lbl == label)

        sample = {}
        if all_slices:
            sample['image'] = torch.from_numpy(img)
            sample['label'] = torch.from_numpy(lbl)
        else:
            # Select N labeled slices, evenly spaced among the slices that
            # actually contain the class.
            if N is None:
                raise ValueError('Need to specify number of labeled slices!')
            idx = lbl.sum(axis=(1, 2)) > 0          # slices containing the class
            idx_ = self.get_support_index(N, idx.sum())
            sample['image'] = torch.from_numpy(img[idx][idx_])
            sample['label'] = torch.from_numpy(lbl[idx][idx_])
        return sample


class TrainDataset(Dataset):
    """Episodic training dataset.

    Each __getitem__ call samples one patient, one supervoxel class, and a
    run of successive slices that are split into support and query sets.
    """

    def __init__(self, args):
        self.n_shot = args.n_shot
        self.n_way = args.n_way
        self.n_query = args.n_query
        self.n_sv = args.n_sv
        self.max_iter = args.max_iterations
        # If True, all volumes are loaded into memory up-front.
        self.read = True
        self.train_sampling = 'neighbors'
        # Minimum foreground size (in voxels, on the first sampled slice)
        # for an episode to be accepted.
        self.min_size = 200
        # Keep only 1/factor of the dataset -- TODO(review): confirm this
        # subsampling is intended (TestDataset uses a different factor, 60).
        self.factor = 20

        # Images and labels are sorted by numeric suffix so that
        # image_dirs[i] and label_dirs[i] belong to the same patient.
        self.image_dirs = glob.glob(os.path.join(args.data_root, 'images/image*'))
        self.image_dirs = sorted(
            self.image_dirs,
            key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
        self.image_dirs = self.image_dirs[:len(self.image_dirs) // self.factor]

        self.label_dirs = glob.glob(os.path.join(args.data_root, 'labels/label*'))
        self.label_dirs = sorted(
            self.label_dirs,
            key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
        self.label_dirs = self.label_dirs[:len(self.label_dirs) // self.factor]

        # Test-fold removal kept for reference (disabled in this version):
        # self.FOLD = get_folds(args.dataset)
        # self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs)
        #                    if idx not in self.FOLD[args.fold]]
        # self.label_dirs = [elem for idx, elem in enumerate(self.label_dirs)
        #                    if idx not in self.FOLD[args.fold]]

        # Optionally pre-load every volume into memory, keyed by path.
        if self.read:
            self.images = {}
            self.labels = {}
            for image_dir, label_dir in zip(self.image_dirs, self.label_dirs):
                self.images[image_dir] = sitk.GetArrayFromImage(sitk.ReadImage(image_dir))
                self.labels[label_dir] = sitk.GetArrayFromImage(sitk.ReadImage(label_dir))

    def __len__(self):
        # Length is the number of training episodes, not of volumes.
        return self.max_iter

    """
    Image transforms according to Ouyang et al.
    """

    def gamma_tansform(self, img):
        """Random gamma intensity transform (gamma in [0.5, 1.5]).

        NOTE(review): the misspelling ('tansform') is kept because external
        callers may reference this name.
        """
        gamma_range = (0.5, 1.5)
        gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0]
        cmin = img.min()
        irange = (img.max() - cmin + 1e-5)

        img = img - cmin + 1e-5
        img = irange * np.power(img * 1.0 / irange, gamma)
        img = img + cmin
        return img

    def geom_transform(self, img, mask):
        """Random affine + elastic deformation applied jointly to image and mask.

        Handles two layouts: support tensors with an extra leading way axis
        (ndim > 4) and query tensors. Masks are re-binarized with rint after
        interpolation.
        """
        affine = {'rotate': 5, 'shift': (5, 5), 'shear': 5, 'scale': (0.9, 1.2)}
        alpha = 10
        sigma = 5
        order = 3

        tfx = []
        tfx.append(myit.RandomAffine(affine.get('rotate'),
                                     affine.get('shift'),
                                     affine.get('shear'),
                                     affine.get('scale'),
                                     affine.get('scale_iso', True),
                                     order=order))
        tfx.append(myit.ElasticTransform(alpha, sigma))
        transform = deftfx.Compose(tfx)

        if len(img.shape) > 4:
            # Support layout: transform each shot, channels and masks together.
            n_shot = img.shape[1]
            for shot in range(n_shot):
                cat = np.concatenate((img[0, shot], mask[:, shot])).transpose(1, 2, 0)
                cat = transform(cat).transpose(2, 0, 1)
                img[0, shot] = cat[:3, :, :]
                mask[:, shot] = np.rint(cat[3:, :, :])
        else:
            # Query layout: transform each query slice with its mask.
            for q in range(img.shape[0]):
                cat = np.concatenate((img[q], mask[q][None])).transpose(1, 2, 0)
                cat = transform(cat).transpose(2, 0, 1)
                img[q] = cat[:3, :, :]
                mask[q] = np.rint(cat[3:, :, :].squeeze())

        return img, mask

    def __getitem__(self, idx):
        """Sample one training episode (support/query images and masks)."""
        # Sample a patient at random; idx is ignored on purpose (episodic).
        pat_idx = random.choice(range(len(self.image_dirs)))

        if self.read:
            # Get image/label volume from the in-memory cache.
            img = self.images[self.image_dirs[pat_idx]]
            label = self.labels[self.label_dirs[pat_idx]]
        else:
            # Read image/supervoxel volume from disk.
            img = sitk.GetArrayFromImage(sitk.ReadImage(self.image_dirs[pat_idx]))
            label = sitk.GetArrayFromImage(sitk.ReadImage(self.label_dirs[pat_idx]))

        # Normalize (rebinds img, so the cached volume is untouched).
        img = (img - img.mean()) / img.std()

        # Sample class(es) (supervoxel).  NOTE(review): assumes 0 (background)
        # is always present in the label volume -- remove() raises otherwise.
        unique = list(np.unique(label))
        unique.remove(0)

        # Re-sample until the first support slice has enough foreground.
        # NOTE(review): this can loop indefinitely if no class ever yields a
        # region of at least min_size voxels.
        size = 0
        while size < self.min_size:
            # Inner loop: re-sample the class until it spans enough slices
            # for n_shot * n_way supports plus n_query queries.
            n_slices = (self.n_shot * self.n_way) + self.n_query - 1
            while n_slices < ((self.n_shot * self.n_way) + self.n_query):
                cls_idx = random.choice(unique)
                # Extract slices containing the sampled class.
                sli_idx = np.sum(label == cls_idx, axis=(1, 2)) > 0
                n_slices = np.sum(sli_idx)

            img_slices = img[sli_idx]
            label_slices = 1 * (label[sli_idx] == cls_idx)

            # Sample a run of successive support + query slices.
            i = random.choice(
                np.arange(n_slices - ((self.n_shot * self.n_way) + self.n_query) + 1))
            sample = np.arange(i, i + (self.n_shot * self.n_way) + self.n_query)
            size = np.sum(label_slices[sample[0]])

        # Invert slice order half of the time (direction augmentation).
        if np.random.random(1) > 0.5:
            sample = sample[::-1]

        sup_lbl = label_slices[sample[:self.n_shot * self.n_way]][None,]  # n_way x (n_shot * C) x H x W
        qry_lbl = label_slices[sample[self.n_shot * self.n_way:]]         # n_qry x C x H x W

        sup_img = img_slices[sample[:self.n_shot * self.n_way]][None,]    # n_way x (n_shot * C) x H x W
        sup_img = np.stack((sup_img, sup_img, sup_img), axis=2)
        qry_img = img_slices[sample[self.n_shot * self.n_way:]]           # n_qry x C x H x W
        qry_img = np.stack((qry_img, qry_img, qry_img), axis=1)

        # Augmentations kept for reference (disabled in this version):
        # gamma transform
        # if np.random.random(1) > 0.5:
        #     qry_img = self.gamma_tansform(qry_img)
        # else:
        #     sup_img = self.gamma_tansform(sup_img)
        # geom transform
        # if np.random.random(1) > 0.5:
        #     qry_img, qry_lbl = self.geom_transform(qry_img, qry_lbl)
        # else:
        #     sup_img, sup_lbl = self.geom_transform(sup_img, sup_lbl)

        sample = {'support_images': sup_img,
                  'support_fg_labels': sup_lbl,
                  'query_images': qry_img,
                  'query_labels': qry_lbl}

        return sample