# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import os
import csv
import glob
import pickle
import numpy as np
import lmdb
from PIL import Image

import torch
import torch.utils.data as _data

from torchvision import datasets
import torchvision.transforms as transforms
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
from torch.utils.data import DataLoader

def init_dataloader(data_name, data_path, batch_size, num_workers, use_gpu=True, distributed=True):
    """Build train/val DataLoaders; uses an ElasticDistributedSampler when distributed."""
    train_dataset, val_dataset = select_dataset(data_name, data_path)

    train_sampler = ElasticDistributedSampler(train_dataset) if distributed else None

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),  # the distributed sampler shuffles on its own
        num_workers=num_workers,
        pin_memory=use_gpu,
        sampler=train_sampler,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_gpu,
        drop_last=True
    )

    return train_loader, val_loader
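# Usage sketch (paths and hyperparameters below are hypothetical, not part of
# this module): building single-process CIFAR-10 loaders.
#
#   train_loader, val_loader = init_dataloader(
#       'cifar10', './data', batch_size=128, num_workers=4,
#       use_gpu=True, distributed=False)
#   for images, labels in train_loader:
#       pass  # training step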


def select_dataset(name, path):

    base_transform = transforms.ToTensor()

    if name == 'mnist':
        train_ds = datasets.MNIST(path, train=True, download=True, transform=base_transform)
        val_ds = datasets.MNIST(path, train=False, download=True, transform=base_transform)

    elif name == 'cifar10':
        train_ds = datasets.CIFAR10(path, train=True, download=True, transform=base_transform)
        val_ds = datasets.CIFAR10(path, train=False, download=True, transform=base_transform)

    elif name == 'fashion':
        train_ds = datasets.FashionMNIST(path, train=True, download=True, transform=base_transform)
        val_ds = datasets.FashionMNIST(path, train=False, download=True, transform=base_transform)

    elif name == 'kmnist':
        train_ds = datasets.KMNIST(path, train=True, download=True, transform=base_transform)
        val_ds = datasets.KMNIST(path, train=False, download=True, transform=base_transform)

    elif name in ['imagenet8', 'imagenet16', 'imagenet32', 'imagenet64']:
        train_lmdb = os.path.join(path, "{}_{}.lmdb".format(name, 'train'))
        val_lmdb = os.path.join(path, "{}_{}.lmdb".format(name, 'val'))

        assert os.path.exists(train_lmdb), train_lmdb
        assert os.path.exists(val_lmdb), val_lmdb

        train_ds = ImageNetLMDB(train_lmdb, transform=base_transform)
        val_ds = ImageNetLMDB(val_lmdb, transform=base_transform)

    elif name.startswith('celeba'):
        # e.g. 'celeba64' -> image size 64
        size = int(name.replace('celeba', ''))

        train_transform, val_transform = _data_transforms_celeba64(size)
        train_ds = CelebADataset(path, split='train', transform=train_transform)
        val_ds = CelebADataset(path, split='val', transform=val_transform)

    else:
        raise ValueError("Dataset not implemented: {}".format(name))

    return train_ds, val_ds

def unpickle(file):
    with open(file, 'rb') as fo:
        data = pickle.load(fo)
    return data

def pickle_save(item, fname):
    with open(fname, "wb") as f:
        pickle.dump(item, f)

def load_databatch(root, split, split_name, img_size, bidx=None):
    """Load one pickled batch of a downsampled-ImageNet split (train batches are numbered)."""
    data_file = os.path.join(root, split_name)

    if split == 'train':
        data_file += str(bidx)

    d = unpickle(data_file)
    x = d['data']
    y = d['labels']
    return x, y

def dumps_data(obj):
    """
    Serialize an object.
    Returns:
        Implementation-dependent bytes-like object
    """
    return pickle.dumps(obj)


def raw_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    return bin_data

def loads_data(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pickle.loads(buf)


def create_lmdb_db(dataset, root, split, img_size, write_frequency=5000):

    root = os.path.expanduser(root)
    ds_path = os.path.join(root, dataset)

    lmdb_path = os.path.join(root, "{}_{}.lmdb".format(dataset, split))
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)

    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,  # 2 TB map size
                   meminit=False, map_async=True)

    # Downsampled ImageNet ships the train split as ten numbered pickle
    # batches (1-indexed) and the val split as a single file.
    split_name = 'train_data_batch_' if split == 'train' else 'val_data'
    batch_count = 10 if split == 'train' else 1
    idx = 0

    txn = db.begin(write=True)
    for bidx in range(batch_count):
        print("Creating lmdb batch={} split={} ".format(bidx, split))

        X, Y = load_databatch(ds_path, split, split_name, img_size, bidx=bidx + 1)
        for x, y in zip(X, Y):

            txn.put(u'{}'.format(idx).encode('ascii'), dumps_data((x, y)))
            # commit periodically so the write transaction stays small
            if idx % write_frequency == 0:
                txn.commit()
                txn = db.begin(write=True)

            idx += 1

    # finish iterating through dataset; idx now equals the number of records
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_data(keys))
        txn.put(b'__len__', dumps_data(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
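# Usage sketch (assumed layout): with the downsampled-ImageNet pickle batches
# unpacked under ~/data/imagenet32/, this writes ~/data/imagenet32_train.lmdb
# and ~/data/imagenet32_val.lmdb, the files select_dataset() looks for.
#
#   create_lmdb_db('imagenet32', '~/data', 'train', img_size=32)
#   create_lmdb_db('imagenet32', '~/data', 'val', img_size=32)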

class ImageNetLMDB(_data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # the __len__ record written by create_lmdb_db is the exact item count
        return self.length

    def __getitem__(self, index):

        with self.env.begin(write=False, buffers=True) as txn:
            byteflow = txn.get(self.keys[index])

            unpacked = loads_data(byteflow)
            imgbuf = unpacked[0]

            img = np.asarray(imgbuf, dtype=np.uint8)
            
            # assume data is RGB
            size = int(np.sqrt(len(img) / 3))
            img = np.reshape(img, (3, size, size))
            img = np.transpose(img, axes=[1,2,0])
            img = Image.fromarray(img, mode='RGB')
            
            target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)

        return img, target
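# Read-back sketch (the .lmdb path is hypothetical): each record unpickles to
# a (flat uint8 pixel array, label) pair and is decoded to a PIL RGB image.
#
#   ds = ImageNetLMDB('/data/imagenet32_train.lmdb',
#                     transform=transforms.ToTensor())
#   img, target = ds[0]  # img: FloatTensor [3, 32, 32], target: int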


class ImageNetDS(object):
    def __init__(self, root, split, img_size, folder_split=10000, **kwargs):
        self.root = os.path.expanduser(root)
        self.split = split
        self.img_size = img_size
        self.folder_split = folder_split  # images per numbered sub-folder
        self.total_len = self.count_data()

    def count_data(self):
        all_folders = glob.glob(self.root + '/*')
        total_count = 0
        for folder in all_folders:
            total_count += len(glob.glob(folder + '/*'))
        return total_count

    def __len__(self):
        return self.total_len
    
    def __getitem__(self, idx):
        folder_idx = idx // self.folder_split
        folder_path = os.path.join(self.root, str(folder_idx))
        img_path = os.path.join(folder_path, '{}.pkl'.format(str(idx)))
        img_dict = unpickle(img_path)
        x = torch.tensor(img_dict['x'], dtype=torch.float32)
        y = torch.tensor(img_dict['y'], dtype=torch.int32)
        x = torch.reshape(x, [3, self.img_size, self.img_size])
        y -= 1    # labels on disk are 1-indexed
        x /= 255  # scale pixels to [0, 1]
        return x, y
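# On-disk layout assumed by ImageNetDS (folder_split=10000): numbered folders
# of per-image pickles, e.g. root/0/0.pkl ... root/0/9999.pkl, root/1/10000.pkl.
#
#   ds = ImageNetDS('/data/imagenet32_pkl', split='train', img_size=32)
#   x, y = ds[0]  # x: FloatTensor [3, 32, 32] in [0, 1], y: zero-indexed label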
    

class CelebADataset(object):
    def __init__(self, root, split, transform, **kwargs):
        self.root = os.path.expanduser(root)
        self.root = os.path.join(self.root, 'celeba')
        self.split_idx = {'train': 0, 'val': 1, 'test': 2}[split]
        self.split = split
        self.transform = transform

        # the aligned-images archive unpacks to a nested
        # img_align_celeba/img_align_celeba directory
        img_dir = os.path.join(self.root, 'img_align_celeba')
        self.img_dir = os.path.join(img_dir, 'img_align_celeba')
        img_count = len(glob.glob(self.img_dir + '/*'))
        assert img_count > 0, "No images found @ {}".format(self.img_dir)
        
        self.partition_map = self.read_partition()

        self.attr_map = self.read_attributes()
        self.attr_keys = [
            '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 
            'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 
            'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 
            'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 
            'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 
            'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 
            'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 
            'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
            ]

    def read_partition(self):
        partition_path = os.path.join(self.root, 'list_eval_partition.csv')

        split_map = {}
        with open(partition_path) as f:
            for row in csv.DictReader(f):
                p = int(row['partition'])
                # setdefault so the first image of each split is kept as well
                split_map.setdefault(p, []).append(row['image_id'])
        return split_map

    def read_attributes(self):
        attr_path = os.path.join(self.root, 'list_attr_celeba.csv')
        with open(attr_path) as f:
            attrs = {row['image_id']: row for row in csv.DictReader(f)}
        return attrs

    def __len__(self):
        return len(self.partition_map[self.split_idx])
    
    def attr_binary(self, attr):
        # attribute values in the CSV are 1 / -1; map -1 to 0
        binary = [int(attr[k]) if int(attr[k]) == 1 else 0 for k in self.attr_keys]
        return torch.tensor(binary, dtype=torch.float32)

    def __getitem__(self, idx):
        img_name = self.partition_map[self.split_idx][idx]
        img_path = os.path.join(self.img_dir, img_name)
        img = Image.open(img_path).convert("RGB")
        attr = self.attr_map[img_name]
        img = self.transform(img)
        attr_binary = self.attr_binary(attr)
        return img, attr_binary

class CropCelebA64(object):
    """ This class applies cropping for CelebA64. This is a simplified implementation of:
    https://github.com/andersbll/autoencoding_beyond_pixels/blob/master/dataset/celeba.py
    """
    def __call__(self, pic):
        new_pic = pic.crop((15, 40, 178 - 15, 218 - 30))
        return new_pic

    def __repr__(self):
        return self.__class__.__name__ + '()'


def _data_transforms_celeba64(size):
    train_transform = transforms.Compose([
        CropCelebA64(),
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    valid_transform = transforms.Compose([
        CropCelebA64(),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    return train_transform, valid_transform
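
# Usage sketch (root path is hypothetical): CelebA at 64x64 with the
# transforms above; select_dataset('celeba64', root) wires up the same thing.
#
#   train_t, val_t = _data_transforms_celeba64(64)
#   train_ds = CelebADataset('~/data', split='train', transform=train_t)
#   img, attrs = train_ds[0]  # img: [3, 64, 64], attrs: 40-dim 0/1 tensor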