# foreground-prototypes-based-few-shot-learning/dataloaders/datasets.py
import torch
from torch.utils.data import Dataset
import torchvision.transforms as deftfx
import glob
import os
import SimpleITK as sitk
import random
import numpy as np
from . import image_transforms as myit
from .specifics import *


class TestDataset(Dataset):
    """Evaluation dataset: one random support volume plus the remaining query volumes.

    ``__init__`` globs ``<data_root>/images/image*``, sorts the paths by the
    integer between the last ``_`` and ``.nii.gz`` (e.g. ``image_12.nii.gz``
    -> 12), keeps only the first ``len // factor`` of them, and removes one
    randomly chosen volume from the query list to serve as the support volume.

    Callers must assign ``self.label`` (the class value to segment) before
    indexing; ``__getitem__`` binarizes each query label volume against it.
    """

    def __init__(self, args):
        self.factor = 60  # keep only the first len(volumes) // factor sorted volumes

        # reading the paths
        image_dirs = glob.glob(os.path.join(args.data_root, 'images/image*'))
        image_dirs = sorted(image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
        self.image_dirs = image_dirs[:len(image_dirs) // self.factor]
        if not self.image_dirs:
            raise ValueError(f'No test volumes selected from {args.data_root!r} '
                             f'(need at least {self.factor} images/image* files).')

        # self.FOLD = get_folds(args.dataset)
        # self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx in self.FOLD[args.fold]]

        # Split into support/query. randrange excludes its upper bound, so the
        # index is always valid (randint(0, n) could return n -> IndexError).
        self.support_index = random.randrange(len(self.image_dirs))
        self.support_dir = self.image_dirs[self.support_index]
        self.image_dirs = [path for i, path in enumerate(self.image_dirs)
                           if i != self.support_index]  # remove support
        self.label = None  # target class value; must be set by the caller before indexing

    def __len__(self):
        return len(self.image_dirs)

    def __getitem__(self, idx):
        """Return ``{'id', 'image', 'label'}`` for query volume ``idx``.

        ``image`` is the z-score normalized volume replicated to 3 channels on
        axis 1; ``label`` is 1 where the label volume equals ``self.label``.

        Raises:
            ValueError: ``self.label`` has not been set.
        """
        if self.label is None:
            raise ValueError('Set TestDataset.label to the target class before indexing!')

        img_path = self.image_dirs[idx]
        img, lbl = self._load_volume(img_path, self.label)

        sample = {'id': img_path}
        sample['image'] = torch.from_numpy(img)
        sample['label'] = torch.from_numpy(lbl)
        return sample

    @staticmethod
    def _label_path(img_path):
        """Map ``.../image_<k>...`` to the path of its ``label_<k>...`` file.

        The label is looked up next to the image first. If it is not there,
        the sibling ``labels`` directory is used, which is the layout
        ``TrainDataset`` reads labels from.
        """
        img_dir, fname = os.path.split(img_path)
        # Operate on the file name only, so 'image_' inside data_root is harmless.
        lbl_name = 'label_' + fname.split('image_')[-1]
        same_dir = os.path.join(img_dir, lbl_name)
        if os.path.exists(same_dir):
            return same_dir
        return os.path.join(os.path.dirname(img_dir), 'labels', lbl_name)

    def _load_volume(self, img_path, label):
        """Read an image volume and its label volume, binarized against ``label``.

        Returns ``(img, lbl)``: ``img`` is z-score normalized and stacked 3x on
        axis 1; ``lbl`` is an integer array that is 1 where the label equals
        ``label`` and 0 elsewhere.
        """
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        img = (img - img.mean()) / img.std()
        img = np.stack(3 * [img], axis=1)

        lbl = sitk.GetArrayFromImage(sitk.ReadImage(self._label_path(img_path)))
        # lbl[lbl == 200] = 1
        # lbl[lbl == 500] = 2
        # lbl[lbl == 600] = 3
        lbl = 1 * (lbl == label)
        return img, lbl

    def get_support_index(self, n_shot, C):
        """Return ``n_shot`` evenly spread integer indices into ``range(C)``.

        Index ``i`` is ``int(C * (2*i + 1) / (2*n_shot))``, i.e. the centre of
        the i-th of ``n_shot`` equal parts (``C // 2`` when ``n_shot == 1``).
        """
        if n_shot == 1:
            pcts = [0.5]
        else:
            half_part = 1 / (n_shot * 2)
            part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1)
            pcts = [half_part + part_interval * ii for ii in range(n_shot)]

        return (np.array(pcts) * C).astype('int')

    def getSupport(self, label=None, all_slices=True, N=None):
        """Load the support volume binarized against ``label``.

        Args:
            label: class value to segment (required).
            all_slices: return every slice if True; otherwise only ``N`` slices
                spread evenly over the slices that contain ``label``.
            N: number of labeled slices to return when ``all_slices`` is False.

        Returns:
            dict with torch tensors ``'image'`` and ``'label'``.

        Raises:
            ValueError: ``label`` or (when needed) ``N`` is missing, or the
                support volume has no slice containing ``label``.
        """
        if label is None:
            raise ValueError('Need to specify label class!')
        if not all_slices and N is None:
            raise ValueError('Need to specify number of labeled slices!')

        img, lbl = self._load_volume(self.support_dir, label)

        sample = {}
        if all_slices:
            sample['image'] = torch.from_numpy(img)
            sample['label'] = torch.from_numpy(lbl)
        else:
            # select N labeled slices
            idx = lbl.sum(axis=(1, 2)) > 0
            n_labeled = idx.sum()
            if n_labeled == 0:
                raise ValueError(f'Support volume {self.support_dir!r} has no slice '
                                 f'containing label {label}!')
            idx_ = self.get_support_index(N, n_labeled)

            sample['image'] = torch.from_numpy(img[idx][idx_])
            sample['label'] = torch.from_numpy(lbl[idx][idx_])

        return sample


class TrainDataset(Dataset):
    """Episodic training dataset drawing support/query slices from 3-D volumes.

    The ``idx`` passed to ``__getitem__`` is ignored. Each call samples a
    random patient, a random foreground class of that patient, and a window of
    ``n_shot * n_way + n_query`` consecutive slices among the slices that
    contain that class. The first ``n_shot * n_way`` window slices form the
    support set and the rest form the query set. All support slices come from
    the same sampled class, so ``n_way`` only scales the number of support
    slices.

    Args:
        args: namespace with ``n_shot``, ``n_way``, ``n_query``, ``n_sv``,
            ``max_iterations`` and ``data_root``. Volumes are globbed from
            ``<data_root>/images/image*`` and ``<data_root>/labels/label*``.
            File names must end in ``_<int>.nii.gz``; the integer is used for
            sorting and for pairing images with labels.
    """

    # Upper bound on class draws per episode. Without it, a volume in which no
    # class/window satisfies the constraints would make __getitem__ loop forever.
    max_attempts = 10000

    def __init__(self, args):
        self.n_shot = args.n_shot
        self.n_way = args.n_way
        self.n_query = args.n_query
        self.n_sv = args.n_sv  # stored only; not used by this class
        self.max_iter = args.max_iterations
        self.read = True  # read all volumes into memory up front
        self.train_sampling = 'neighbors'  # stored only; not used by this class
        self.min_size = 200  # min. foreground pixels in the window's first slice
        self.factor = 20  # keep only the first len(volumes) // factor sorted volumes

        self.image_dirs = self._sorted_subset(args.data_root, 'images/image*', self.factor)
        self.label_dirs = self._sorted_subset(args.data_root, 'labels/label*', self.factor)

        if not self.image_dirs:
            raise ValueError(f'No training volumes selected from {args.data_root!r} '
                             f'(need at least {self.factor} images/image* files).')
        # __getitem__ pairs image_dirs[i] with label_dirs[i]; make sure each
        # image has the label with the same numeric id at that position.
        image_ids = [self._volume_id(p) for p in self.image_dirs]
        label_ids = [self._volume_id(p) for p in self.label_dirs[:len(self.image_dirs)]]
        if image_ids != label_ids:
            raise ValueError(f'Image/label volumes are misaligned: image ids {image_ids} '
                             f'vs label ids {label_ids}.')

        # remove test fold!
        # self.FOLD = get_folds(args.dataset)
        # self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx not in self.FOLD[args.fold]]
        # self.label_dirs = [elem for idx, elem in enumerate(self.label_dirs) if idx not in self.FOLD[args.fold]]

        # read images
        if self.read:
            self.images = {}
            self.labels = {}
            for image_dir, label_dir in zip(self.image_dirs, self.label_dirs):
                self.images[image_dir] = sitk.GetArrayFromImage(sitk.ReadImage(image_dir))
                self.labels[label_dir] = sitk.GetArrayFromImage(sitk.ReadImage(label_dir))

    @staticmethod
    def _volume_id(path):
        """Return the integer between the last '_' and '.nii.gz' of ``path``."""
        return int(path.split('_')[-1].split('.nii.gz')[0])

    @classmethod
    def _sorted_subset(cls, data_root, pattern, factor):
        """Glob ``data_root/pattern``, sort by volume id, keep the first len // factor."""
        paths = sorted(glob.glob(os.path.join(data_root, pattern)), key=cls._volume_id)
        return paths[:len(paths) // factor]

    def __len__(self):
        return self.max_iter

    # Image transforms according to Ouyang et al.
    # (Both are currently unused: their calls in __getitem__ are commented out.)
    def gamma_tansform(self, img):
        """Apply a random gamma curve, gamma ~ U[0.5, 1.5), over the intensity range of ``img``."""
        gamma_range = (0.5, 1.5)
        gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0]
        cmin = img.min()
        irange = (img.max() - cmin + 1e-5)

        img = img - cmin + 1e-5
        img = irange * np.power(img * 1.0 / irange, gamma)
        img = img + cmin

        return img

    def geom_transform(self, img, mask):
        """Jointly warp image slices and masks with a random affine + elastic transform.

        Each slice's 3 image channels and its mask are concatenated, warped
        together by the project's ``image_transforms`` (myit) objects, then
        split back. Masks are rounded with ``np.rint``. ``img``/``mask`` are
        modified in place and also returned. Presumably handles the support
        layout (rank > 4, e.g. (1, S, 3, H, W) with mask (1, S, H, W)) and the
        query layout ((Q, 3, H, W) with mask (Q, H, W)) — TODO confirm.
        """
        affine = {'rotate': 5, 'shift': (5, 5), 'shear': 5, 'scale': (0.9, 1.2)}
        alpha = 10
        sigma = 5
        order = 3

        tfx = []
        tfx.append(myit.RandomAffine(affine.get('rotate'),
                                     affine.get('shift'),
                                     affine.get('shear'),
                                     affine.get('scale'),
                                     affine.get('scale_iso', True),
                                     order=order))
        tfx.append(myit.ElasticTransform(alpha, sigma))
        transform = deftfx.Compose(tfx)

        if len(img.shape) > 4:
            n_shot = img.shape[1]
            for shot in range(n_shot):
                cat = np.concatenate((img[0, shot], mask[:, shot])).transpose(1, 2, 0)
                cat = transform(cat).transpose(2, 0, 1)
                img[0, shot] = cat[:3, :, :]
                mask[:, shot] = np.rint(cat[3:, :, :])

        else:
            for q in range(img.shape[0]):
                cat = np.concatenate((img[q], mask[q][None])).transpose(1, 2, 0)
                cat = transform(cat).transpose(2, 0, 1)
                img[q] = cat[:3, :, :]
                mask[q] = np.rint(cat[3:, :, :].squeeze())

        return img, mask

    def _sample_window(self, label, n_total):
        """Pick a class and ``n_total`` successive slices containing it.

        The window's slices are consecutive within the subset of slices that
        contain the class. The first window slice must contain at least
        ``self.min_size`` class pixels.

        Returns:
            ``(sli_idx, label_slices, sample)``: boolean per-slice mask of the
            chosen class, the binarized label slices selected by it, and the
            window indices into those slices.

        Raises:
            ValueError: ``label`` has no nonzero (foreground) value.
            RuntimeError: no valid class/window found in ``self.max_attempts`` draws.
        """
        classes = [c for c in np.unique(label) if c != 0]
        if not classes:
            raise ValueError('Label volume contains no foreground class to sample.')

        slice_masks = {}  # class -> per-slice presence mask, cached for this call
        attempts = 0
        size = 0
        while size < self.min_size:
            n_slices = n_total - 1
            while n_slices < n_total:
                attempts += 1
                if attempts > self.max_attempts:
                    raise RuntimeError(
                        f'No class with >= {n_total} slices and a window whose first slice '
                        f'has >= {self.min_size} foreground pixels found in '
                        f'{self.max_attempts} attempts.')
                cls_idx = random.choice(classes)
                if cls_idx not in slice_masks:
                    slice_masks[cls_idx] = np.sum(label == cls_idx, axis=(1, 2)) > 0
                sli_idx = slice_masks[cls_idx]
                n_slices = np.sum(sli_idx)

            label_slices = 1 * (label[sli_idx] == cls_idx)

            # sample a window of successive slices
            start = random.choice(np.arange(n_slices - n_total + 1))
            sample = np.arange(start, start + n_total)

            size = np.sum(label_slices[sample[0]])

        return sli_idx, label_slices, sample

    def __getitem__(self, idx):
        """Draw one random episode (``idx`` is ignored).

        Returns a dict of numpy arrays. For (D, H, W) volumes the shapes are:
        ``support_images`` (1, n_shot*n_way, 3, H, W), ``support_fg_labels``
        (1, n_shot*n_way, H, W), ``query_images`` (n_query, 3, H, W), and
        ``query_labels`` (n_query, H, W). Labels are 0/1 masks of the sampled
        class.

        Raises:
            ValueError: the sampled label volume has no foreground class.
            RuntimeError: no valid class/window found within ``max_attempts``.
        """
        # sample patient idx
        pat_idx = random.choice(range(len(self.image_dirs)))

        if self.read:
            # get image/label volume from dictionary
            img = self.images[self.image_dirs[pat_idx]]
            label = self.labels[self.label_dirs[pat_idx]]
        else:
            # read image/label volume into memory
            img = sitk.GetArrayFromImage(sitk.ReadImage(self.image_dirs[pat_idx]))
            label = sitk.GetArrayFromImage(sitk.ReadImage(self.label_dirs[pat_idx]))

        # z-score normalize (builds a new array; cached volumes stay untouched)
        img = (img - img.mean()) / img.std()

        n_sup = self.n_shot * self.n_way
        n_total = n_sup + self.n_query
        sli_idx, label_slices, sample = self._sample_window(label, n_total)
        img_slices = img[sli_idx]

        # invert order
        if np.random.random(1) > 0.5:
            sample = sample[::-1]  # successive slices (inverted)

        sup_lbl = label_slices[sample[:n_sup]][None,]  # 1 x n_sup x H x W
        qry_lbl = label_slices[sample[n_sup:]]  # n_query x H x W

        sup_img = img_slices[sample[:n_sup]][None,]  # 1 x n_sup x H x W
        sup_img = np.stack((sup_img, sup_img, sup_img), axis=2)  # 1 x n_sup x 3 x H x W
        qry_img = img_slices[sample[n_sup:]]  # n_query x H x W
        qry_img = np.stack((qry_img, qry_img, qry_img), axis=1)  # n_query x 3 x H x W

        # gamma transform
        # if np.random.random(1) > 0.5:
        #     qry_img = self.gamma_tansform(qry_img)
        # else:
        #     sup_img = self.gamma_tansform(sup_img)

        # # geom transform
        # if np.random.random(1) > 0.5:
        #     qry_img, qry_lbl = self.geom_transform(qry_img, qry_lbl)
        # else:
        #     sup_img, sup_lbl = self.geom_transform(sup_img, sup_lbl)

        return {'support_images': sup_img,
                'support_fg_labels': sup_lbl,
                'query_images': qry_img,
                'query_labels': qry_lbl}