honeyplotnet / fid / fid_utils.py
fid_utils.py
Raw
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import wget
import torch

# Base URL of the S3 bucket that hosts the pretrained chart-FID files
# (model snapshot and train/test dataset statistics) fetched by
# download_pretrained.
S3_BUCKET = 'https://s3.ap-southeast-2.amazonaws.com/decoychart/chartfid/'

def download_pretrained(path=None):
    """Download the pretrained chart-FID files into ``path`` if missing.

    Fetches the FID model snapshot and the precomputed statistics for the
    train and test datasets from ``S3_BUCKET``. Files that already exist in
    ``path`` are left untouched, so repeated calls only download what is
    missing.

    Args:
        path: Existing directory to store the files in. Despite the ``None``
            default, a directory must be supplied.

    Raises:
        ValueError: If ``path`` is ``None``.
        FileNotFoundError: If ``path`` is not an existing directory.
    """
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, and ``os.path.exists(None)`` gives an unhelpful
    # error (or silently returns False) instead of a clear message.
    if path is None:
        raise ValueError("download_pretrained() requires a destination directory")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Destination directory does not exist: {path}")

    # FID model snapshot plus statistics on the train and test datasets.
    snapshot_names = ['chartFid_snapshot.pth', 'pmc-train.npz', 'pmc-test.npz']
    for snapshot_name in snapshot_names:
        snapshot_local = os.path.join(path, snapshot_name)
        if os.path.exists(snapshot_local):
            continue
        # Build the URL with an explicit '/': os.path.join is a filesystem
        # path helper and is not meant for composing URLs.
        url = S3_BUCKET.rstrip('/') + '/' + snapshot_name
        wget.download(url, snapshot_local)

def load_fid_snapshot(model, snapshot_path, map_location='cpu'):
    """Load FID model weights from a snapshot file into ``model`` in place.

    The snapshot is expected to be a mapping whose ``'fid_model'`` entry is a
    state dict compatible with ``model``.

    Args:
        model: Object exposing ``load_state_dict`` (presumably a
            ``torch.nn.Module``); its parameters are overwritten.
        snapshot_path: Path to the snapshot file.
        map_location: Device the stored tensors are deserialized onto before
            being copied into ``model``. Defaults to ``'cpu'`` so snapshots
            saved from a GPU also load on CPU-only machines;
            ``load_state_dict`` copies the values onto the model's own
            parameters regardless of where they were deserialized.

    Raises:
        KeyError: If the loaded snapshot mapping has no ``'fid_model'`` entry.
    """
    # NOTE(review): torch.load unpickles the file, so only load snapshots
    # from a trusted source (e.g. the official files from download_pretrained).
    checkpoint = torch.load(snapshot_path, map_location=map_location)
    model.load_state_dict(checkpoint['fid_model'])