honeyplotnet / fid_usage_example.py
fid_usage_example.py
Raw
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

'''

This file demonstrates how to run the FID metrics presented in the paper.

The `init_fid_model` automatically downloads the pre-trained weights from an AWS bucket a
and creates a new model.

The "Example Usage" code block below shows how to calculate FID with your own model. 
The inputs to the FID model need to be in a compatible format.
'''


import click
import os
from utils import load_cfg
from fid import init_fid_model, calculate_frechet_distance

@click.command()
@click.option('--config_file','-c', default='default.yaml')
def main(config_file):

    ###########################################
    # Load Configurations 
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    cfg_dir = os.path.join(cur_dir, 'config')
    cfg = load_cfg(config_file, cfg_dir)

    ###########################################
    # Initialize fid model
    fid_model, fid_stats = init_fid_model(cfg, load_path=cur_dir, device_id=0)

    mu1, sigma1, act1 = fid_stats['test']

    ###########################################
    # Example usage. 
    ###########################################

    # x_hat = chart_model.sample()
    # activations, _, _ = fid_model(x_hat)
    # mu2 = np.mean(acts, axis=0)
    # sigma2 = np.cov(acts, rowvar=False)

    # score = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    # print(score)

if __name__ == "__main__":
  main()