# --------------------------------------------------------------- # Copyright (c) ___________________ 2023. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # --------------------------------------------------------------- import os import csv import glob import pickle import numpy as np import lmdb from PIL import Image import torch import torch.utils.data as _data from torchvision import datasets import torchvision.transforms as transforms from torch.distributed.elastic.utils.data import ElasticDistributedSampler from torch.utils.data import DataLoader def init_dataloader(data_name, data_path, batch_size, num_workers, use_gpu=True, distributed=True): train_dataset, val_dataset = select_dataset(data_name, data_path) train_sampler = ElasticDistributedSampler(train_dataset) if distributed else None train_loader = DataLoader( train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=use_gpu, sampler=train_sampler, drop_last=True ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_gpu, drop_last=True ) return train_loader, val_loader def select_dataset(name, path): base_transform = transforms.ToTensor() if name == 'mnist': train_ds = datasets.MNIST(path, train=True, download=True, transform=base_transform) val_ds = datasets.MNIST(path, train=False, download=True, transform=base_transform) elif name == 'cifar10': train_ds = datasets.CIFAR10(path, train=True, download=True, transform=base_transform) val_ds = datasets.CIFAR10(path, train=False, download=True, transform=base_transform) elif name == 'fashion': train_ds = datasets.FashionMNIST(path, train=True, download=True, transform=base_transform) val_ds = datasets.FashionMNIST(path, train=False, download=True, transform=base_transform) elif name == 'kmnist': train_ds = datasets.KMNIST(path, train=True, download=True, transform=base_transform) val_ds = datasets.KMNIST(path, train=False, download=True, transform=base_transform) elif name in ['imagenet8','imagenet16','imagenet32','imagenet64']: train_lmdb = os.path.join(path, "{}_{}.lmdb".format(name, 'train')) val_lmdb = os.path.join(path, "{}_{}.lmdb".format(name, 'val')) assert os.path.exists(train_lmdb), train_lmdb assert os.path.exists(val_lmdb), val_lmdb train_ds = ImageNetLMDB(train_lmdb, transform=base_transform) val_ds = ImageNetLMDB(val_lmdb, transform=base_transform) elif name.startswith('celeba'): size = int(name.replace('celeba','')) train_transform, val_transform = _data_transforms_celeba64(size) train_ds = CelebADataset(path, split='train', transform=train_transform) val_ds = CelebADataset(path, split='val', transform=val_transform) else: raise ValueError("Dataset not implemented: {}".format(name)) return train_ds, val_ds def unpickle(file): with open(file, 'rb') as fo: dict = pickle.load(fo) return dict def pickle_save(item, fname): with open(fname, "wb") as f: pickle.dump(item, f) def load_databatch(root, split, split_name, img_size, bidx=None): data_file = os.path.join(root, split_name) if split == 'train': data_file += str(bidx) d = unpickle(data_file) x = d['data'] y = d['labels'] return x, y def dumps_data(obj): """ Serialize an object. Returns: Implementation-dependent bytes-like object """ return pickle.dumps(obj) def raw_reader(path): with open(path, 'rb') as f: bin_data = f.read() return bin_data def loads_data(buf): """ Args: buf: the output of `dumps`. """ return pickle.loads(buf) def create_lmdb_db(dataset, root, split, img_size, write_frequency=5000): root = os.path.expanduser(root) ds_path = os.path.join(root, dataset) lmdb_path = os.path.join(root, "{}_{}.lmdb".format(dataset, split)) isdir = os.path.isdir(lmdb_path) print("Generate LMDB to %s" % lmdb_path) db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False, meminit=False, map_async=True) split_name = 'train_data_batch_' if split =='train' else 'val_data' batch_count = 10 if split =='train' else 1 idx = 0 txn = db.begin(write=True) for bidx in range(batch_count): print("Creating lmdb batch={} split={} ".format(bidx, split)) X, Y = load_databatch(ds_path, split, split_name, img_size, bidx=bidx+1) for x, y in zip(X, Y): txn.put(u'{}'.format(idx).encode('ascii'), dumps_data((x, y))) if idx % write_frequency == 0: txn.commit() txn = db.begin(write=True) idx += 1 # finish iterating through dataset txn.commit() keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] with db.begin(write=True) as txn: txn.put(b'__keys__', dumps_data(keys)) txn.put(b'__len__', dumps_data(len(keys))) print("Flushing database ...") db.sync() db.close() class ImageNetLMDB(_data.Dataset): def __init__(self, db_path, transform=None, target_transform=None): self.db_path = db_path self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path), readonly=True, lock=False, readahead=False, meminit=False) with self.env.begin(write=False) as txn: self.length = loads_data(txn.get(b'__len__')) self.keys = loads_data(txn.get(b'__keys__')) self.transform = transform self.target_transform = target_transform def __len__(self): return self.length - 1 def __getitem__(self, index): with self.env.begin(write=False, buffers=True) as txn: byteflow = txn.get(self.keys[index]) unpacked = loads_data(byteflow) imgbuf = unpacked[0] img = np.asarray(imgbuf, dtype=np.uint8) # assume data is RGB size = int(np.sqrt(len(img) / 3)) img = np.reshape(img, (3, size, size)) img = np.transpose(img, axes=[1,2,0]) img = Image.fromarray(img, mode='RGB') target = unpacked[1] if self.transform is not None: img = self.transform(img) return img, target class ImageNetDS(object): def __init__(self, root, split, img_size, folder_split=10000, **kwargs): root = self.root = os.path.expanduser(root) self.root = root self.split = split self.img_size = img_size self.folder_split = folder_split self.total_len = self.count_data() #self.dir_list = glob.glob(root + '/*') def count_data(self): all_folders = glob.glob(self.root + '/*') total_count = 0 for folder in all_folders: total_count += len(glob.glob(folder + '/*')) return total_count def __len__(self): return self.total_len def __getitem__(self, idx): folder_idx = idx // self.folder_split folder_path = os.path.join(self.root, str(folder_idx)) img_path = os.path.join(folder_path, '{}.pkl'.format(str(idx))) img_dict = unpickle(img_path) x = torch.tensor(img_dict['x'], dtype=torch.float32) y = torch.tensor(img_dict['y'], dtype=torch.int32) x = torch.reshape(x, [3, self.img_size, self.img_size]) y -= 1 x /= 255 return x, y class CelebADataset(object): def __init__(self, root, split, transform, **kwargs): self.root = os.path.expanduser(root) self.root = os.path.join(self.root, 'celeba') self.split_idx = {'train': 0, 'val': 1, 'test': 2}[split] self.split = split self.transform = transform img_dir = os.path.join(self.root, 'img_align_celeba') self.img_dir = os.path.join(img_dir, 'img_align_celeba') img_count = len(glob.glob(self.img_dir + '/*')) assert img_count > 0, "No images found @ {}".format(self.img_dir) self.partition_map = self.read_partition() self.attr_map = self.read_attributes() self.attr_keys = [ '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young' ] def read_partition(self): partition_path = os.path.join(self.root, 'list_eval_partition.csv') partition = csv.DictReader(open(partition_path)) split_map = {} for row in partition: p = int(row['partition']) if p not in split_map: split_map[p] = [] else: split_map[p].append(row['image_id']) return split_map def read_attributes(self): attr_path = os.path.join(self.root, 'list_attr_celeba.csv') attrs = csv.DictReader(open(attr_path)) attrs = {k['image_id']: k for k in attrs} return attrs def __len__(self): return len(self.partition_map[self.split_idx]) def attr_binary(self, attr): attr_binary = [int(attr[k]) if int(attr[k]) == 1 else 0 for k in self.attr_keys] return torch.tensor(attr_binary, dtype=torch.float32) def __getitem__(self, idx): img_name = self.partition_map[self.split_idx][idx] img_path = os.path.join(self.img_dir, img_name) img = Image.open(img_path).convert("RGB") attr = self.attr_map[img_name] img = self.transform(img) attr_binary = self.attr_binary(attr) return img, attr_binary class CropCelebA64(object): """ This class applies cropping for CelebA64. This is a simplified implementation of: https://github.com/andersbll/autoencoding_beyond_pixels/blob/master/dataset/celeba.py """ def __call__(self, pic): new_pic = pic.crop((15, 40, 178 - 15, 218 - 30)) return new_pic def __repr__(self): return self.__class__.__name__ + '()' def _data_transforms_celeba64(size): train_transform = transforms.Compose([ CropCelebA64(), transforms.Resize(size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) valid_transform = transforms.Compose([ CropCelebA64(), transforms.Resize(size), transforms.ToTensor(), ]) return train_transform, valid_transform