# mvq/precompute_fid.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Forked from:
# https://github.com/NVlabs/NVAE/blob/master/scripts/precompute_fid_statistics.py
# 
# The license for the original version of this file can be
# found in this directory (/thirdparty/fid/LICENSE_pytorch_fid). The modifications
# to this file are subject to the same Apache License.
# ---------------------------------------------------------------


import os
import warnings
import torch
import click

from thirdparty.fid.fid_score import compute_statistics_of_generator, save_statistics
from thirdparty.fid.inception import InceptionV3
from dataset import select_dataset

# Ignore every DeprecationWarning for the lifetime of the process. This is a
# module-level side effect, so it also takes effect when this file is imported.
warnings.filterwarnings("ignore", category=DeprecationWarning)

@click.command()
@click.option('--dataset','-d', default='fashion')
@click.option('--batch_size','-bs', default=64)
@click.option('--max_samples','-s', default=50000)
@click.option('--device','-de', default='cuda:0')
@click.option('--dims','-di', default=2048)
@click.option('--data_dir','-dp', required=True)
def main(dataset, data_dir, batch_size, max_samples, device, dims):
    """Validate the command-line options and run precompute_fid_scores.

    Options:
        dataset: name of the dataset; must be one of the supported names
            listed below.
        data_dir: existing directory that is forwarded to
            precompute_fid_scores (required).
        batch_size, max_samples, device, dims: forwarded unchanged to
            precompute_fid_scores.

    Raises:
        click.BadParameter: if the dataset name is not supported or
            data_dir does not exist.
    """
    supported_datasets = ('cifar10', 'celeba64', 'mnist', 'kmnist',
                          'fashion', 'imagenet32', 'imagenet64')

    # Explicit checks instead of `assert`: assertions are stripped when
    # Python runs with -O, which would silently disable the validation.
    if dataset not in supported_datasets:
        raise click.BadParameter(
            "unknown dataset {!r}; expected one of: {}".format(
                dataset, ', '.join(supported_datasets)),
            param_hint='--dataset')

    # `--data_dir` is declared required above, so it is never None here
    # (os.path.exists(None) would raise a TypeError).
    if not os.path.exists(data_dir):
        raise click.BadParameter(
            "path {!r} does not exist".format(data_dir),
            param_hint='--data_dir')

    precompute_fid_scores(data_dir, dataset, device=device, batch_size=batch_size, dims=dims, max_samples=max_samples)

def precompute_fid_scores(data_dir, dataset, device='cuda:0', batch_size=32, dims=2048, max_samples=50000):
    """Compute and save FID statistics for the train and validation splits.

    The statistics are written to ``<data_dir>/fid-stats/<dataset>-train.npz``
    and ``<data_dir>/fid-stats/<dataset>-eval.npz``.

    Args:
        data_dir: root data directory; passed to ``select_dataset`` and used
            as the parent of the ``fid-stats`` output directory.
        dataset: dataset name; passed to ``select_dataset`` and used as the
            prefix of the output file names.
        device: device the InceptionV3 model is moved to.
        batch_size: batch size of both data loaders.
        dims: key into ``InceptionV3.BLOCK_INDEX_BY_DIM`` selecting the
            Inception block to use.
        max_samples: forwarded to ``compute_statistics_of_generator``
            (presumably an upper bound on the samples used -- confirm there).

    Raises:
        KeyError: if ``dims`` is not a key of ``InceptionV3.BLOCK_INDEX_BY_DIM``.
    """
    fid_dir = os.path.join(data_dir, 'fid-stats')
    if not os.path.exists(fid_dir):
        # makedirs + exist_ok: also creates missing parents and tolerates
        # another process creating the directory between check and call.
        os.makedirs(fid_dir, exist_ok=True)
        print("Making new directory at :{}".format(fid_dir))

    train_data, val_data = select_dataset(dataset, data_dir)

    # Only the train loader drops its last incomplete batch.
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size,
        shuffle=True, pin_memory=True, num_workers=8, drop_last=True)

    valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size,
        shuffle=True, pin_memory=True, num_workers=8, drop_last=False)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    # fid_dir doubles as the model_dir handed to InceptionV3
    # (presumably where its weights are stored -- confirm in InceptionV3).
    model = InceptionV3([block_idx], model_dir=fid_dir).to(device)

    # (loader, output file suffix, log message) per split, train first.
    splits = (
        (train_queue, '-train.npz', 'saving fid stats (train) at %s'),
        (valid_queue, '-eval.npz', 'saving fid stats at (eval) %s'),
    )
    for queue, suffix, message in splits:
        # The three returned values are passed straight to save_statistics.
        m, s, a = compute_statistics_of_generator(queue, model, batch_size, dims, device, max_samples)
        file_path = os.path.join(fid_dir, dataset + suffix)
        print(message % file_path)
        save_statistics(file_path, m, s, a)

# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()