# FOT-OOD / torch_datasets / tiny_imagenet.py
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import os, glob


class TinyImageNet(Dataset):
    """Tiny ImageNet dataset that indexes image paths up front and decodes lazily.

    Files read beneath ``data_path``:

    * ``wnids.txt`` -- one class id per line; the position of a class id
      among the non-blank lines is its integer label.
    * ``train/<class id>/<subdir>/*.JPEG`` -- used when ``split='train'``;
      the label is taken from the ``<class id>`` directory name.
    * ``val/images/*.JPEG`` and ``val/val_annotations.txt`` -- used when
      ``split='test'``; each annotation line is tab-separated with the image
      file name in the first column and its class id in the second.

    Args:
        data_path: Root directory of the dataset.
        split: Either ``'train'`` or ``'test'``.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``.
        KeyError: If a class id or image name has no entry in the
            corresponding lookup table.
    """

    _SPLITS = ('train', 'test')

    def __init__(self, data_path, split='train', transform=None):
        # Fail early with a clear message; otherwise an unknown split would
        # leave ``filenames``/``targets`` unset and crash later on.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split)
            )

        self.id_dict = self._read_wnids(os.path.join(data_path, 'wnids.txt'))
        self.split = split

        if split == 'train':
            self.filenames = sorted(glob.glob(os.path.join(data_path, "train/*/*/*.JPEG")))
            # .../train/<class id>/<subdir>/<file>: class id is third from the end.
            self.targets = [
                self.id_dict[self._parts(img_path)[-3]] for img_path in self.filenames
            ]
        else:
            self.filenames = sorted(glob.glob(os.path.join(data_path, "val/images/*.JPEG")))
            self.cls_dic = {}
            with open(os.path.join(data_path, 'val/val_annotations.txt'), 'r') as fh:
                for line in fh:
                    line = line.rstrip('\r\n')
                    if not line.strip():
                        continue  # tolerate blank lines instead of raising IndexError
                    fields = line.split('\t')
                    self.cls_dic[fields[0]] = self.id_dict[fields[1]]

            self.targets = [
                self.cls_dic[self._parts(img_path)[-1]] for img_path in self.filenames
            ]

        self.transform = transform
        self.data = [
            {'img_dir': filename, 'target': target}
            for filename, target in zip(self.filenames, self.targets)
        ]

    @staticmethod
    def _read_wnids(path):
        """Return a ``{class id: label}`` mapping from the non-blank lines of ``path``."""
        with open(path, 'r') as fh:
            wnids = [line.strip() for line in fh]
        return {wnid: i for i, wnid in enumerate(w for w in wnids if w)}

    @staticmethod
    def _parts(img_path):
        """Split a path on ``/`` after normalising backslashes (Windows globs)."""
        return img_path.replace("\\", "/").split("/")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``; the image is RGB, transformed if set."""
        entry = self.data[idx]
        # convert() yields an independent image, so the file handle opened by
        # Image.open can be released immediately.
        with Image.open(entry['img_dir']) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, entry['target']


class TinyImageNetCorrupted(Dataset):
    """Corrupted Tiny ImageNet images for one corruption type and severity.

    Files read beneath ``corrupted_path``:

    * ``wnids.txt`` -- one class id per line; the position of a class id
      among the non-blank lines is its integer label.
    * ``<corruption_type>/<corruption_severity>/<class id>/*.JPEG`` -- the
      images; the label is taken from the ``<class id>`` directory name.

    Images whose directory name is not listed in ``wnids.txt`` are skipped,
    so ``filenames``, ``targets`` and ``data`` always have the same length
    and stay aligned.

    Args:
        corrupted_path: Root directory of the corrupted dataset.
        corruption_type: Name of the corruption sub-directory.
        corruption_severity: Severity sub-directory; converted with ``str``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, corrupted_path, corruption_type, corruption_severity, transform=None):
        self.transform = transform
        self.id_dict = self._read_wnids(os.path.join(corrupted_path, 'wnids.txt'))

        pattern = os.path.join(
            corrupted_path, corruption_type, str(corruption_severity), "*/*.JPEG"
        )
        self.filenames = []
        self.targets = []
        for img_path in sorted(glob.glob(pattern)):
            # .../<class id>/<file>: class id is the parent directory name.
            wnid = img_path.replace("\\", "/").split('/')[-2]
            # Filter path and label together; filtering only the labels would
            # shift every later label onto the wrong image.
            if wnid in self.id_dict:
                self.filenames.append(img_path)
                self.targets.append(self.id_dict[wnid])

        self.data = [
            {'img_dir': filename, 'target': target}
            for filename, target in zip(self.filenames, self.targets)
        ]

    @staticmethod
    def _read_wnids(path):
        """Return a ``{class id: label}`` mapping from the non-blank lines of ``path``."""
        with open(path, 'r') as fh:
            wnids = [line.strip() for line in fh]
        return {wnid: i for i, wnid in enumerate(w for w in wnids if w)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``; the image is RGB, transformed if set."""
        entry = self.data[idx]
        # convert() yields an independent image, so the file handle opened by
        # Image.open can be released immediately.
        with Image.open(entry['img_dir']) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, entry['target']