# foreground-prototypes-based-few-shot-learning / main_test.py
#!/usr/bin/env python

import argparse
import random
import numpy as np

import SimpleITK as sitk
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import DataLoader

from models.fewshot import FewShotSeg
from dataloaders.datasets import TestDataset
from dataloaders.specifics import *
from utils import *


def _str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``type=bool`` cannot be used for this purpose: ``bool('False')`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}.'.format(value))


def parse_arguments(argv=None):
    """Build the command-line parser for the test script and parse arguments.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--save_dir', type=str, required=True)
    parser.add_argument('--pretrained_root', type=str, required=True)
    parser.add_argument('--fold', type=int, required=True)
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--n_shot', default=1, type=int)
    # Boolean flags take an explicit value (e.g. `--EP1 True`); _str2bool makes
    # `--EP1 False` really mean False instead of bool('False') == True.
    parser.add_argument('--all_slices', default=True, type=_str2bool)
    parser.add_argument('--EP1', default=False, type=_str2bool)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--workers', default=0, type=int)

    return parser.parse_args(argv)


def main():
    """Evaluate a pretrained few-shot segmentation model class by class.

    Reads CLI options via ``parse_arguments()`` and loads the checkpoint at
    ``--pretrained_root`` into a ``DataParallel``-wrapped ``FewShotSeg`` on
    the GPU. It then runs ``infer`` once for every class returned by
    ``get_label_names(args.dataset)`` except ``'BG'``, and logs per-class
    mean and max metrics.
    """
    args = parse_arguments()

    # Deterministic setting for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        # Seed numpy too. NOTE(review): confirm whether the dataloaders draw
        # from np.random; seeding it is harmless either way.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # cudnn autotuning may select different algorithms from run to run.
        cudnn.benchmark = False

    # infer() writes predictions into this directory with torch.save, and
    # set_logger also receives it; make sure it exists on a fresh run.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    # Set up logging.
    logger = set_logger(args.save_dir, 'train.log')
    logger.info(args)

    # Setup the path to save (read by infer()).
    args.save = args.save_dir

    # Init model and load state_dict.
    model = FewShotSeg(use_coco_init=False)
    model = nn.DataParallel(model.cuda())
    model.load_state_dict(torch.load(args.pretrained_root, map_location="cpu"))

    # Data loader.
    test_dataset = TestDataset(args)
    query_loader = DataLoader(test_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)

    # Inference.
    logger.info('  Start inference ... Note: EP1 is ' + str(args.EP1))
    logger.info('  Support: ' + str(test_dataset.support_dir[len(args.data_dir):]))
    logger.info('  Query: ' +
                str([elem[len(args.data_dir):] for elem in test_dataset.image_dirs]))

    # Get unique labels (classes).
    labels = get_label_names(args.dataset)

    # Per-class metric summaries, keyed by class name.
    class_dice = {}
    class_iou = {}
    max_dice = {}
    max_iou = {}
    class_accuracy = {}
    class_precision = {}
    max_precision = {}

    for label_val, label_name in labels.items():

        # Skip BG class. Compare by value: `is` tests object identity and only
        # matched when both strings happened to be the same interned object.
        if label_name == 'BG':
            continue

        logger.info('  *------------------Class: {}--------------------*'.format(label_name))
        logger.info('  *--------------------------------------------------*')

        # Get support sample + mask for current class.
        support_sample = test_dataset.getSupport(label=label_val, all_slices=args.all_slices, N=args.n_shot)
        test_dataset.label = label_val

        # Infer.
        with torch.no_grad():
            scores = infer(model, query_loader, support_sample, args, logger, label_name)

        # Log class-wise results
        class_dice[label_name] = torch.tensor(scores.patient_dice).mean().item()
        class_iou[label_name] = torch.tensor(scores.patient_iou).mean().item()
        class_accuracy[label_name] = torch.tensor(scores.accuracy).mean().item()
        class_precision[label_name] = torch.tensor(scores.precision).mean().item()

        max_dice[label_name] = torch.tensor(scores.patient_dice).max().item()
        max_iou[label_name] = torch.tensor(scores.patient_iou).max().item()
        max_precision[label_name] = torch.tensor(scores.precision).max().item()

    # Log final results.
    logger.info('  *-----------------Final results--------------------*')
    logger.info('  *--------------------------------------------------*')
    logger.info('  Mean IoU: {}'.format(class_iou))
    logger.info('  Mean Dice: {}'.format(class_dice))
    logger.info('  Mean Accuracy : {}'.format(class_accuracy))
    logger.info('  Mean Precision : {}'.format(class_precision))
    logger.info('  Max Dice : {}'.format(max_dice))
    logger.info('  Max IoU : {}'.format(max_iou))
    logger.info('  Max Precision : {}'.format(max_precision))
    logger.info('  *--------------------------------------------------*')


def infer(model, query_loader, support_sample, args, logger, label_name):
    """Segment every query volume for one class, score it and save predictions.

    Args:
        model: segmentation network, called as
            ``model(support_images, support_masks, query_images, train=False)``.
            It must return a 3-tuple whose first element holds per-class
            scores along dim 1 (reduced here with ``argmax(dim=1)``).
        query_loader: iterable of dicts with ``'image'``, ``'label'`` and
            ``'id'`` entries; only ``sample['id'][0]`` is used, so a batch
            size of 1 is assumed.
        support_sample: dict with ``'image'`` and ``'label'`` tensors whose
            first dimension indexes the support shots.
        args: parsed CLI namespace; reads ``EP1``, ``n_shot``, ``data_dir``
            and ``save``.
        logger: receives one result line per query volume.
        label_name: class name; only referenced by the commented-out NIfTI
            export below.

    Returns:
        The ``Scores`` object filled via ``scores.record`` for every volume.

    Side effects:
        Writes each prediction to ``args.save/image_<id>_<i>.pt``.
    """
    # Test mode.
    model.eval()

    # Unpack support data.
    n_support = support_sample['image'].shape[0]
    support_image = [support_sample['image'][[s]].float().cuda() for s in range(n_support)]  # n_shot x 3 x H x W
    support_fg_mask = [support_sample['label'][[s]].float().cuda() for s in range(n_support)]  # n_shot x H x W

    # Loop through query volumes.
    scores = Scores()
    for i, sample in enumerate(query_loader):

        # Unpack query data. Use a distinct index name so it is not confused
        # with the volume counter `i` used for the output file name.
        query_image = [sample['image'][b].float().cuda() for b in range(sample['image'].shape[0])]  # [C x 3 x H x W]
        query_label = sample['label'].long()  # C x H x W
        query_id = sample['id'][0].split('image_')[1][:-len('.nii.gz')]

        # Compute output.
        if args.EP1:
            # Match support slice and query sub-chunk.
            # Allocate as long so EP1 predictions share the int64 dtype that
            # argmax produces in the EP2 branch (saved files and scoring then
            # see the same dtype regardless of mode).
            query_pred = torch.zeros(query_label.shape[-3:], dtype=torch.long)
            C_q = sample['image'].shape[1]
            idx_ = np.linspace(0, C_q, args.n_shot + 1).astype('int')
            for sub_chunk in range(args.n_shot):
                start, stop = idx_[sub_chunk], idx_[sub_chunk + 1]
                support_image_s = [support_image[sub_chunk]]  # 1 x 3 x H x W
                support_fg_mask_s = [support_fg_mask[sub_chunk]]  # 1 x H x W
                query_image_s = query_image[0][start:stop]  # C' x 3 x H x W
                query_pred_s, _, _ = model([support_image_s], [support_fg_mask_s], [query_image_s], train=False)  # C x 2 x H x W
                query_pred[start:stop] = query_pred_s.argmax(dim=1).cpu()  # C x H x W

        else:  # EP 2
            query_pred, _, _ = model([support_image], [support_fg_mask], query_image, train=False)  # C x 2 x H x W
            query_pred = query_pred.argmax(dim=1).cpu()  # C x H x W

            # Optional debug export of the prediction as NIfTI:
            # np_arr = query_pred.cpu().detach().numpy()
            # np_arr = 1 - np_arr
            # testimg = sitk.GetImageFromArray(np_arr)
            # sitk.WriteImage(testimg, 'image_' + query_id + '_' + label_name + '.nii.gz')

        # Record scores.
        scores.record(query_pred, query_label)

        # Log.
        logger.info('    Tested query volume: ' + sample['id'][0][len(args.data_dir):]
                    + '. Dice score:  ' + str(scores.patient_dice[-1].item()) + '. Accuracy:  ' + str(scores.accuracy[-1].item())
                    + '. Precision: ' + str(scores.precision[-1].item()))

        # Save predictions.
        file_name = 'image_' + query_id + '_' + str(i) + '.pt'
        torch.save(query_pred, os.path.join(args.save, file_name))

    return scores


# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()