# foreground-prototypes-based-few-shot-learning / utils.py
import torch
import logging
import os
import wandb
import re


def atoi(text):
    """Return ``int(text)`` when *text* is a run of decimal digits, else *text*.

    ``str.isdecimal`` is used rather than ``str.isdigit`` because ``isdigit``
    also accepts characters such as superscripts ("²") that ``int()`` cannot
    parse, which would raise ``ValueError`` instead of returning the text.

    Args:
        text: The string fragment to convert.

    Returns:
        An ``int`` for purely decimal input; otherwise the unchanged string
        (including the empty string).
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Build a sort key that splits *text* into digit and non-digit runs.

    The string is split on runs of digits (the runs themselves are kept), and
    each piece is passed through ``atoi``.

    Args:
        text: The string to derive a key from.

    Returns:
        A list with one ``atoi`` result per piece, in original order.
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))


def init_wandb(args):
    """Start a wandb run for project ``args.wandb`` and record *args* as its config.

    Args:
        args: Object exposing a ``wandb`` attribute (used as the project name);
            the whole object is also passed to ``wandb.config.update``.

    Returns:
        The name of the active wandb run.
    """
    wandb.init(project=args.wandb)

    # The run is saved before its name is read, as in the original ordering.
    active_run = wandb.run
    active_run.save()
    name = active_run.name

    # Mirror the arguments into the run's config.
    wandb.config.update(args)

    return name


def set_logger(log_path, file_name):
    """Configure the root logger to log to ``log_path/file_name`` and the console.

    The directory is created if missing and the root logger level is set to
    INFO. Handlers installed by this function are tagged, so calling it again
    (with the same file, or at all for the console) does not attach duplicate
    handlers that would emit every message several times.

    Args:
        log_path: Directory in which the log file is created.
        file_name: Name of the log file inside *log_path*.

    Returns:
        The root ``logging.Logger``.
    """
    os.makedirs(log_path, exist_ok=True)
    path = os.path.join(log_path, file_name)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Tags of the handlers a previous call to this function already attached.
    tag_attr = '_set_logger_tag'
    installed = {getattr(handler, tag_attr, None) for handler in logger.handlers}

    # Log to .txt (one handler per distinct target file).
    file_tag = ('file', os.path.abspath(path))
    if file_tag not in installed:
        file_handler = logging.FileHandler(path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        setattr(file_handler, tag_attr, file_tag)
        logger.addHandler(file_handler)

    # Log to console (at most one handler from this function).
    console_tag = ('console',)
    if console_tag not in installed:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        setattr(stream_handler, tag_attr, console_tag)
        logger.addHandler(stream_handler)

    return logger


class AverageMeter:
    """Track the latest value and the running weighted average of a quantity.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted sum of values), ``count`` (sum of weights) and ``avg``
    (``sum / count``).
    """

    def __init__(self, name, fmt=':f'):
        """Create a meter.

        Args:
            name: Label printed by ``__str__``.
            fmt: Format spec (including the leading colon) applied to both
                ``val`` and ``avg`` when the meter is rendered as a string.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* with weight *n* and refresh the running average.

        Args:
            val: The new value.
            n: Weight of the value (e.g. number of samples it summarises).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) on a fresh meter) would
        # otherwise divide by zero; report 0, the same value reset() uses.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using the configured format."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**vars(self))


class Scores:
    """Accumulate binary-segmentation confusion counts and derived metrics.

    Each ``record`` call appends per-call dice, IoU, accuracy and precision to
    the corresponding lists and adds the call's TP/TN/FP/FN counts to running
    totals, from which the ``compute_*`` methods derive aggregate metrics.
    """

    def __init__(self):
        # Running totals over all recorded calls. They start as Python ints
        # and become tensors after the first ``record``.
        self.TP = 0
        self.TN = 0
        self.FP = 0
        self.FN = 0

        # One entry per ``record`` call.
        self.patient_dice = []
        self.patient_iou = []
        self.accuracy = []
        self.precision = []

    def record(self, preds, label):
        """Score one prediction against its label and update the totals.

        Args:
            preds: Prediction tensor; must contain fewer than three distinct
                values. Elements equal to 1 count as positive and elements
                equal to 0 as negative.
            label: Ground truth compared element-wise against 1 and 0
                (presumably a tensor of the same shape as *preds* — confirm
                against the caller).
        """
        assert len(torch.unique(preds)) < 3

        label_pos, label_neg = (label == 1), (label == 0)
        pred_pos, pred_neg = (preds == 1), (preds == 0)

        tp = torch.sum(label_pos * pred_pos)
        tn = torch.sum(label_neg * pred_neg)
        fp = torch.sum(label_neg * pred_pos)
        fn = torch.sum(label_pos * pred_neg)

        # NOTE(review): dice and IoU are 0/0 (NaN) when tp, fp and fn are all
        # zero; unlike precision they are stored unguarded. Left as-is because
        # the intended convention for that case is not visible here.
        self.patient_dice.append(2 * tp / (2 * tp + fp + fn))
        self.patient_iou.append(tp / (tp + fp + fn))
        self.accuracy.append((tp + tn) / (tp + tn + fp + fn))

        prec = tp / (tp + fp)
        if torch.isnan(prec):
            # No positive predictions: store 0 with the same dtype and device
            # as every other entry, so the list stays homogeneous.
            prec = torch.zeros_like(prec)
        self.precision.append(prec)

        self.TP += tp
        self.TN += tn
        self.FP += fp
        self.FN += fn

    # NOTE(review): the compute_* methods divide without guarding a zero
    # denominator (e.g. when nothing has been recorded yet).

    def compute_dice(self):
        """Return the dice coefficient computed from the running totals."""
        return 2 * self.TP / (2 * self.TP + self.FP + self.FN)

    def compute_iou(self):
        """Return the intersection-over-union computed from the running totals."""
        return self.TP / (self.TP + self.FP + self.FN)

    def compute_accuracy(self):
        """Return the accuracy computed from the running totals."""
        return ((self.TP + self.TN) / (self.TP + self.TN + self.FP + self.FN))

    def compute_precision(self):
        """Return the precision computed from the running totals."""
        return self.TP / (self.TP + self.FP)