# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Forked from:
# https://github.com/NVlabs/NVAE/blob/master/scripts/precompute_fid_statistics.py
#
# The license for the original version of this file can be
# found in this directory (/thirdparty/fid/LICENSE_pytorch_fid). The modifications
# to this file are subject to the same Apache License.
# ---------------------------------------------------------------
"""Precompute Inception (FID) statistics for the train and eval splits of a dataset."""

import os
import warnings

import torch
import click

from thirdparty.fid.fid_score import compute_statistics_of_generator, save_statistics
from thirdparty.fid.inception import InceptionV3
from dataset import select_dataset

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Dataset names accepted on the command line; each is passed to ``select_dataset``.
SUPPORTED_DATASETS = ['cifar10', 'celeba64', 'mnist', 'kmnist', 'fashion', 'imagenet32', 'imagenet64']


@click.command()
@click.option('--dataset', '-d', default='fashion', type=click.Choice(SUPPORTED_DATASETS))
@click.option('--batch_size', '-bs', default=64)
@click.option('--max_samples', '-s', default=50000)
@click.option('--device', '-de', default='cuda:0')
@click.option('--dims', '-di', default=2048)
# Validated by click rather than ``assert`` so the check survives ``python -O``
# and a missing option yields a usage error instead of a TypeError.
@click.option('--data_dir', '-dp', required=True,
              type=click.Path(exists=True, file_okay=False))
def main(dataset, data_dir, batch_size, max_samples, device, dims):
    """Compute FID statistics for DATASET stored under DATA_DIR."""
    precompute_fid_scores(data_dir, dataset, device=device, batch_size=batch_size,
                          dims=dims, max_samples=max_samples)


def precompute_fid_scores(data_dir, dataset, device='cuda:0', batch_size=32, dims=2048,
                          max_samples=50000, num_workers=8):
    """Compute and save Inception statistics for the train and eval splits of ``dataset``.

    Results are written to ``<data_dir>/fid-stats/<dataset>-train.npz`` and
    ``<data_dir>/fid-stats/<dataset>-eval.npz``. The ``fid-stats`` directory is
    also passed to ``InceptionV3`` as ``model_dir`` (presumably where its weights
    are cached -- confirm in ``thirdparty.fid.inception``).

    Args:
        data_dir: Dataset root directory; ``fid-stats`` is created inside it.
        dataset: Dataset name understood by ``select_dataset``.
        device: Torch device the Inception model is moved to.
        batch_size: DataLoader batch size, also forwarded to
            ``compute_statistics_of_generator``.
        dims: Inception feature size; must be a key of
            ``InceptionV3.BLOCK_INDEX_BY_DIM`` (a KeyError is raised otherwise).
        max_samples: Sample cap forwarded to ``compute_statistics_of_generator``.
        num_workers: Number of DataLoader worker processes per split.
    """
    fid_dir = os.path.join(data_dir, 'fid-stats')
    if not os.path.isdir(fid_dir):
        # exist_ok avoids crashing if the directory appears between the check and the create.
        os.makedirs(fid_dir, exist_ok=True)
        print("Making new directory at :{}".format(fid_dir))

    train_data, val_data = select_dataset(dataset, data_dir)
    # shuffle=True presumably so that a max_samples cap draws a random subset -- TODO confirm.
    # The train loader drops a trailing partial batch; the eval loader keeps it.
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers, drop_last=True)
    valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers, drop_last=False)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx], model_dir=fid_dir).to(device)

    _save_split_statistics(train_queue, model, fid_dir, dataset, 'train',
                           batch_size, dims, device, max_samples)
    _save_split_statistics(valid_queue, model, fid_dir, dataset, 'eval',
                           batch_size, dims, device, max_samples)


def _save_split_statistics(queue, model, fid_dir, dataset, split,
                           batch_size, dims, device, max_samples):
    """Compute statistics over ``queue`` and save them to ``<fid_dir>/<dataset>-<split>.npz``."""
    m, s, a = compute_statistics_of_generator(queue, model, batch_size, dims, device, max_samples)
    file_path = os.path.join(fid_dir, '{}-{}.npz'.format(dataset, split))
    print('saving fid stats ({}) at {}'.format(split, file_path))
    save_statistics(file_path, m, s, a)


if __name__ == '__main__':
    main()