# FOT-OOD / torch_datasets / configs.py
import torchvision.transforms as transforms
import torch
import torchvision.transforms.functional as TF
import torch.optim as optim
from model import (
    ResNet18, ResNet50, DenseNet121, VGG11, ViT_B_16, EfficientNetB4,
    initialize_bert_based_model, initialize_bert_transform
)
from wilds.datasets.fmow_dataset import FMoWDataset
from wilds.datasets.rxrx1_dataset import RxRx1Dataset
from wilds.datasets.amazon_dataset import AmazonDataset
from wilds.datasets.civilcomments_dataset import CivilCommentsDataset
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from collections import Counter
import random


def _empirical_label_distribution(labels, n_class):
    """Return the fraction of ``labels`` equal to each class index in ``range(n_class)``.

    Classes that never occur get probability 0.0, so the result always has
    exactly ``n_class`` entries aligned with the class indices.

    Raises:
        ValueError: if ``labels`` is empty (the distribution is undefined).
    """
    total = len(labels)
    if total == 0:
        raise ValueError('cannot estimate a label distribution from an empty split')
    counts = Counter(labels)
    return [counts[c] / total for c in range(n_class)]


def get_expected_label_distribution(dataset):
    """Return the expected class-label distribution for ``dataset``.

    For the WILDS datasets (FMoW, Camelyon17, RxRx1, Amazon, CivilComments) the
    distribution is the empirical label frequency of a held-out split read from
    ``./data``: 'id_val' for FMoW/Camelyon17/Amazon, 'id_test' for RxRx1 and
    'val' for CivilComments. All of these except Amazon are constructed with
    ``download=True``. Every other supported dataset gets a uniform distribution
    over its classes.

    Returns:
        A list whose ``i``-th entry is the probability of label ``i``. For WILDS
        datasets it has ``get_n_classes(dataset)`` entries; a class absent from
        the split gets probability 0.0 instead of being dropped, which would
        otherwise shorten and misalign the list.

    Raises:
        KeyError: if ``dataset`` is not a known dataset name.
        ValueError: if the selected WILDS split contains no labels.
    """
    if dataset == 'FMoW':
        full_set = FMoWDataset(download=True, root_dir='./data', use_ood_val=True)
        split = 'id_val'
    elif dataset == 'Camelyon17':
        full_set = Camelyon17Dataset(download=True, root_dir='./data')
        split = 'id_val'
    elif dataset == 'RxRx1':
        full_set = RxRx1Dataset(download=True, root_dir='./data')
        split = 'id_test'
    elif dataset == 'Amazon':
        full_set = AmazonDataset(download=False, root_dir='./data')
        split = 'id_val'
    elif dataset == 'CivilComments':
        full_set = CivilCommentsDataset(download=True, root_dir='./data')
        split = 'val'
    else:
        # Datasets without a loaded split are assumed to be class-balanced.
        uniform = {
            'CIFAR-10': [1 / 10] * 10,
            'CIFAR-100': [1 / 100] * 100,
            'ImageNet': [1 / 1000] * 1000,
            'Tiny-ImageNet': [1 / 200] * 200,
            'Living-17': [1 / 17] * 17,
            'Nonliving-26': [1 / 26] * 26,
            'Entity-13': [1 / 13] * 13,
            'Entity-30': [1 / 30] * 30,
        }
        return uniform[dataset]

    val_set = full_set.get_subset(split, transform=None)
    return _empirical_label_distribution(val_set.y_array.tolist(), get_n_classes(dataset))


def sample_val_label_dist(dist, n_class, sample_size):
    """Draw ``sample_size`` class labels that follow ``dist`` as closely as possible.

    Each class ``c`` first receives ``int(dist[c] * sample_size)`` copies (a
    deterministic floor allocation). The remaining slots are filled by sampling
    classes with replacement, weighted by ``dist``, using the global ``random``
    module. The deterministic labels come first, grouped by class in ascending
    order, followed by the random remainder; the result is not shuffled.

    Args:
        dist: sequence of ``n_class`` non-negative class weights, expected to
            sum to 1.
        n_class: number of classes; labels are drawn from ``range(n_class)``.
        sample_size: requested number of labels.

    Returns:
        A 1-D ``torch.long`` tensor of labels. Its length is ``sample_size``
        when ``dist`` sums to 1. An empty request also yields a ``long`` tensor.
    """
    # A flat comprehension avoids the quadratic ``sum(list_of_lists, [])``.
    labels = [c for c in range(n_class) for _ in range(int(dist[c] * sample_size))]

    # Flooring leaves a few slots unassigned; fill them by weighted sampling.
    remainder = sample_size - len(labels)
    labels += random.choices(list(range(n_class)), weights=dist, k=remainder)

    return torch.as_tensor(labels, dtype=torch.long)



def sample_label_dist(ds, n_class, sample_size):
    """Draw ``sample_size`` labels following the expected label distribution of ``ds``.

    The distribution comes from ``get_expected_label_distribution(ds)``; the
    sampling scheme (floor allocation plus a weighted random remainder) is the
    one implemented by ``sample_val_label_dist``.

    Args:
        ds: dataset name understood by ``get_expected_label_distribution``.
        n_class: number of classes; labels are drawn from ``range(n_class)``.
        sample_size: requested number of labels.

    Returns:
        A 1-D tensor of integer labels.
    """
    # Compute the distribution once: for WILDS datasets every call
    # instantiates (and possibly downloads) the dataset.
    dist = get_expected_label_distribution(ds)
    return sample_val_label_dist(dist, n_class, sample_size)


def get_n_classes(dataset):
    """Return the number of target classes for the dataset named ``dataset``.

    Raises:
        KeyError: if ``dataset`` is not a known dataset name.
    """
    known_sizes = (
        # Image-classification benchmarks
        ('CIFAR-10', 10),
        ('CIFAR-100', 100),
        ('ImageNet', 1000),
        ('Tiny-ImageNet', 200),
        ('Living-17', 17),
        ('Nonliving-26', 26),
        ('Entity-13', 13),
        ('Entity-30', 30),
        # WILDS image datasets
        ('Camelyon17', 2),
        ('FMoW', 62),
        ('RxRx1', 1139),
        # WILDS text datasets
        ('Amazon', 5),
        ('CivilComments', 2),
    )
    return dict(known_sizes)[dataset]


def _rxrx1_transform(split):
    """Return the RxRx1 transform; augmentation is added only when ``split == 'train'``."""

    def standardize(x: torch.Tensor) -> torch.Tensor:
        # Standardize each channel with this image's own statistics over dims
        # 1 and 2 (the spatial dims of the (C, H, W) tensor from ToTensor).
        # A zero std is replaced by 1 to avoid dividing by zero on constant channels.
        mean = x.mean(dim=(1, 2))
        std = x.std(dim=(1, 2))
        std[std == 0.] = 1.
        return TF.normalize(x, mean, std)

    angles = [0, 90, 180, 270]

    def random_rotation(x):
        # Rotate by a uniformly chosen multiple of 90 degrees. This runs before
        # ToTensor, so ``x`` is the raw dataset image (presumably PIL — TODO confirm).
        angle = angles[torch.randint(low=0, high=len(angles), size=(1,))]
        if angle > 0:
            x = TF.rotate(x, angle)
        return x

    steps = [transforms.ToTensor(), transforms.Lambda(standardize)]
    if split == 'train':
        steps = [transforms.Lambda(random_rotation), transforms.RandomHorizontalFlip()] + steps
    return transforms.Compose(steps)


def get_transforms(dataset, split, pretrained):
    """Return the input transform for ``dataset``.

    Args:
        dataset: dataset name (e.g. 'CIFAR-10', 'ImageNet', 'FMoW', 'RxRx1', 'Amazon').
        split: split name. It is only consulted for RxRx1, where 'train'
            enables random rotation and flipping.
        pretrained: for CIFAR-10/CIFAR-100/Tiny-ImageNet, selects ImageNet
            normalization statistics (True) or CIFAR statistics (False).

    Returns:
        A torchvision ``Compose`` for image datasets, or the object returned by
        ``initialize_bert_transform`` for the text datasets (Amazon, CivilComments).

    Raises:
        ValueError: if ``dataset`` is not recognized (previously this surfaced
            as an ``UnboundLocalError``).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if dataset in ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']:
        # NOTE(review): RandomHorizontalFlip is applied for every `split`, so
        # evaluation images are randomly flipped too — confirm this is intended.
        if pretrained:
            normalize = transforms.Normalize(imagenet_mean, imagenet_std)
        else:
            normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return transforms.Compose([
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    if dataset == 'ImageNet':
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])

    if dataset in ['Living-17', 'Nonliving-26', 'Entity-13', 'Entity-30']:
        # Same resize/crop as ImageNet, but with dataset-specific statistics.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.4717, 0.4499, 0.3837], [0.2600, 0.2516, 0.2575]),
        ])

    if dataset in ['Camelyon17', 'FMoW']:
        # No resizing: tensor conversion plus ImageNet normalization only.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])

    if dataset == 'RxRx1':
        return _rxrx1_transform(split)

    if dataset == 'Amazon':
        return initialize_bert_transform('distilbert-base-uncased', max_token_length=512)

    if dataset == 'CivilComments':
        return initialize_bert_transform('distilbert-base-uncased', max_token_length=300)

    raise ValueError(f'Unknown dataset: {dataset!r}')


def get_optimizer(dsname, net, lr, pretrained):
    """Build the optimizer used to train ``net`` on the dataset ``dsname``.

    Args:
        dsname: dataset name.
        net: model whose parameters are optimized.
        lr: learning rate. It is ignored for 'Camelyon17', which always uses 1e-3.
        pretrained: only consulted for CIFAR-10/CIFAR-100/Tiny-ImageNet, where
            True disables weight decay.

    Returns:
        A ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``dsname`` is not recognized, rather than silently
            returning ``None``.
    """
    if dsname in ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']:
        weight_decay = 0 if pretrained else 5e-4
        return optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

    if dsname in ['ImageNet', 'FMoW']:
        return optim.Adam(net.parameters(), lr=lr)

    if dsname == 'Camelyon17':
        # NOTE(review): the `lr` argument is ignored here; 1e-3 is hard-coded.
        return optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

    if dsname in ['Living-17', 'Nonliving-26', 'Entity-13', 'Entity-30']:
        return optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-2)

    if dsname == 'RxRx1':
        return optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)

    if dsname in ['Amazon', 'CivilComments']:
        # Parameters whose name contains 'bias' or 'LayerNorm.weight' are
        # excluded from weight decay; everything else decays with 1e-2.
        no_decay = ('bias', 'LayerNorm.weight')
        decay_params, no_decay_params = [], []
        for name, param in net.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        param_groups = [
            {'params': decay_params, 'weight_decay': 1e-2},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        return optim.AdamW(param_groups, lr=lr)

    raise ValueError(f'Unknown dataset: {dsname!r}')


def get_lr_scheduler(dsname, opt, pretrained, T_max=-1):
    """Build the learning-rate scheduler for optimizer ``opt`` on dataset ``dsname``.

    Args:
        dsname: dataset name.
        opt: the optimizer to schedule.
        pretrained: only consulted for CIFAR-10/CIFAR-100/Tiny-ImageNet, where
            True selects cosine annealing instead of step decay.
        T_max: horizon passed as ``T_max`` to ``CosineAnnealingLR`` (pretrained
            CIFAR-style datasets) and as ``total_steps`` to ``OneCycleLR``
            (RxRx1). It is unused otherwise. The default of -1 is not a valid
            horizon for those two schedulers, so callers must supply it there.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: if ``dsname`` is not recognized, rather than silently
            returning ``None``.
    """
    if dsname in ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']:
        if pretrained:
            return optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T_max)
        return optim.lr_scheduler.MultiStepLR(opt, milestones=[100, 200], gamma=0.1)

    if dsname in ['ImageNet', 'Camelyon17']:
        # gamma=1 never changes the learning rate: effectively a constant schedule.
        return optim.lr_scheduler.MultiStepLR(opt, milestones=[100], gamma=1)

    if dsname in ['Living-17', 'Nonliving-26']:
        return optim.lr_scheduler.MultiStepLR(opt, milestones=[150, 300], gamma=0.1)

    if dsname in ['Entity-13', 'Entity-30']:
        return optim.lr_scheduler.MultiStepLR(opt, milestones=[100, 200], gamma=0.1)

    if dsname == 'FMoW':
        return optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.96)

    if dsname == 'RxRx1':
        return optim.lr_scheduler.OneCycleLR(
            opt, max_lr=1e-4, div_factor=1e12, pct_start=0.11, final_div_factor=1e12,
            cycle_momentum=False, base_momentum=0, max_momentum=0, total_steps=T_max
        )

    # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases.
    if dsname == 'Amazon':
        return optim.lr_scheduler.PolynomialLR(opt, total_iters=3, power=1, verbose=True)

    if dsname == 'CivilComments':
        return optim.lr_scheduler.PolynomialLR(opt, total_iters=5, power=1, verbose=True)

    raise ValueError(f'Unknown dataset: {dsname!r}')


def get_models(arch, n_class, model_seed, pretrained):
    """Instantiate the model named ``arch`` with ``n_class`` output classes.

    Vision architectures receive ``num_classes``, ``seed`` and ``pretrained``.
    The 'distilbert-base-uncased' text model is built from ``n_class`` alone;
    ``model_seed`` and ``pretrained`` are not forwarded to it.

    Raises:
        ValueError: if ``arch`` is not a supported architecture name.
    """
    if arch == 'distilbert-base-uncased':
        return initialize_bert_based_model(n_class)

    # Each factory looks up its constructor only when it is actually selected.
    vision_factories = {
        'resnet18': lambda **kw: ResNet18(**kw),
        'resnet50': lambda **kw: ResNet50(**kw),
        'efficientnet_b4': lambda **kw: EfficientNetB4(**kw),
        'densenet121': lambda **kw: DenseNet121(**kw),
        'vit_b_16': lambda **kw: ViT_B_16(**kw),
        'vgg11': lambda **kw: VGG11(**kw),
    }
    factory = vision_factories.get(arch)
    if factory is None:
        raise ValueError('incorrect model name')
    return factory(num_classes=n_class, seed=model_seed, pretrained=pretrained)