# FOT-OOD / torch_datasets / imagenet.py
from torchvision.datasets import ImageFolder


# Sub-directories of ``data_dir`` used for the 'natural' subpopulation, keyed
# by ``corr_sev``. The strings are kept exactly as they were previously
# written, including the presence or absence of a trailing slash.
_NATURAL_SHIFT_DIRS = {
    0: "imagenetv2/imagenetv2-matched-frequency-format-val/",
    1: "imagenetv2/imagenetv2-threshold0.7-format-val/",
    2: "imagenetv2/imagenetv2-top-images-format-val",
    3: "imagenet-sketch/",
}


def get_imagenet_dataset(data_dir, subpopulation: str, transform, corr: str = 'clean', corr_sev: int = 0):
    """Build an ``ImageFolder`` dataset for one ImageNet shift split under ``data_dir``.

    Args:
        data_dir: Root directory that contains the ``imagenetv2``,
            ``imagenet-sketch`` and ``imagenet-c`` folders.
        subpopulation: ``'natural'`` selects one of the ImageNetV2 variants or
            ImageNet-Sketch (chosen by ``corr_sev``); ``'same'`` selects the
            ``imagenet-c/<corr>/<corr_sev>`` folder.
        transform: Passed unchanged to ``ImageFolder``.
        corr: Name of the sub-directory under ``imagenet-c``; only used when
            ``subpopulation == 'same'``.
        corr_sev: For ``'natural'``, an index in 0-3:
            0 = imagenetv2 matched-frequency, 1 = imagenetv2 threshold0.7,
            2 = imagenetv2 top-images, 3 = imagenet-sketch.
            For ``'same'``, the severity sub-directory name; must be > 0.

    Returns:
        A ``torchvision.datasets.ImageFolder`` rooted at the selected directory.

    Raises:
        ValueError: if ``subpopulation`` is neither ``'natural'`` nor
            ``'same'``, if ``corr_sev`` is not one of 0-3 for ``'natural'``,
            or if ``corr_sev`` is not > 0 for ``'same'``. Previously the first
            two cases silently returned ``None``.
    """
    if subpopulation == 'natural':
        try:
            rel_dir = _NATURAL_SHIFT_DIRS[corr_sev]
        except KeyError:
            raise ValueError(
                f"corr_sev must be one of {sorted(_NATURAL_SHIFT_DIRS)} for the "
                f"'natural' subpopulation, got {corr_sev!r}"
            ) from None
        return ImageFolder(root=f"{data_dir}/{rel_dir}", transform=transform)

    if subpopulation == 'same':
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if not (corr_sev > 0):
            raise ValueError(
                f"corr_sev should be > 0 for synthetic shifts, got {corr_sev!r}"
            )
        return ImageFolder(root=f"{data_dir}/imagenet-c/{corr}/{corr_sev}", transform=transform)

    raise ValueError(
        f"unknown subpopulation {subpopulation!r}; expected 'natural' or 'same'"
    )