# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import os
import copy

import torch
import numpy as np
import pandas as pd
import torch.distributed as dist
from torchvision.transforms.functional import to_pil_image

from thirdparty.fid.fid_score import compute_statistics_of_generator, load_statistics, calculate_frechet_distance, save_statistics
from thirdparty.fid.inception import InceptionV3
from thirdparty.prds.prd_score import compute_prd_from_embedding, prd_to_max_f_beta_pair
from thirdparty.prds.prd_score import plot as plot_prd

from utils import (
    Logger,
    Writer,
    ResultTracker,
    save_mega_subplot,
    tile_image,
    save_fid_line,
    average_tensor,
)


class BaseRunner(object):

    def __init__(self, cfg):
        self.epoch = 0
        self.global_step = 1
        self.cfg = cfg
        self.bsz = getattr(cfg.batch_size_per_gpu, cfg.cur_stage)
        self.device_id = self.local_rank()
        self.use_torch_dist = cfg.torch_dist.use
        self.display = cfg.train.intervals.display

        self.logger = Logger(self.rank(), cfg.save_dir)
        self.writer = Writer(self.rank(), cfg.save_dir)
        self.logger.info("Runner Initialized - Rank=[{}/{}]".format(self.local_rank(), self.rank()))

        self.data_name = cfg.data.name
        self.use_gan = cfg.model.gan.use
        self.gradient_accum_steps = self.cfg.train.gradient_accum_steps

        # Shape of the discrete latent sequence(s), laid out as square grids.
        if self.cfg.model.num_latent_space == 1:
            self.seq_shape1 = [self.bsz, self.cfg.model.discrete_seq_len[-1], self.cfg.model.discrete_seq_len[-1]]
        else:
            p1 = int(np.ceil(self.cfg.model.discrete_seq_len[0] ** (1 / 2)))
            p2 = int(np.ceil(self.cfg.model.discrete_seq_len[1] ** (1 / 2)))
            self.seq_shape1 = [self.bsz, p1, p1]
            self.seq_shape2 = [self.bsz, p2, p2]

        # Metrics
        self.best_score = {}
        self.transfer_cfg = self.cfg.eval.transfer_test
        self.subplots_cfg = self.cfg.eval.save_imgs.subplots
        self.fid_cfg = self.cfg.eval.fid
        self.fid_device = 'cuda:{}'.format(self.device_id)
        self.fid_dir = os.path.join(self.cfg.data_path, 'fid-stats')
        self.fid_file = os.path.join(self.fid_dir, self.cfg.data.name + '-train.npz')
        if not os.path.exists(self.fid_file):
            self.logger.info("FID file does not exist: {}".format(self.fid_file))
            raise SystemError("Please precompute FID statistics with the script >>> '''python precompute_fid.py ...'''")

        self.total_fid_samples = self.cfg.eval.fid.samples
        if self.data_name.startswith('imagenet'):
            self.total_fid_samples = 50000
        elif self.data_name.startswith('celeba'):
            self.total_fid_samples = 19961

        self.metrics = []
        self.metric_names = [
            'mse', 'dist', 'gan', 'disc', 'cb', 'cb1', 'cb2', 'nll',
            'rec', 'p', 'ce1', 'ce2', 'g', 'var', 'mhd',
        ]
        self.print_names = [
            'mse', 'dist', 'gan', 'disc', 'cb', 'cb1', 'cb2',
            'rec', 'p', 'ce1', 'ce2', 'zc_dlen', 'var', 'mhd',
        ]
        self.tracker = ResultTracker(['epoch', 'iter'], print_names=self.print_names)

        # Optimization / mixed-precision settings.
        self.lr_scheduler = None
        self.scaler = None
        self.max_grad_norm = cfg.train.max_grad_norm
        self.use_amp = False
        self.do_grad_scaling = False
        self.use_fp16 = self.cfg.fp16.use
        if self.use_fp16:
            self.use_amp = True
            self.do_grad_scaling = True
            self.amp_dtype = torch.float16 if self.cfg.fp16.use else torch.bfloat16
            self.scaler = torch.cuda.amp.GradScaler()
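    # Note: this base class only configures AMP; the actual optimization step
    # lives in the stage-specific subclasses. As an illustrative sketch (names
    # such as `model`, `optimizer`, and `batch` are placeholders, not part of
    # this class), the objects set up above are typically consumed like this
    # when fp16 is enabled:
    #
    #     with torch.cuda.amp.autocast(enabled=self.use_amp, dtype=self.amp_dtype):
    #         loss = model(batch)['loss'] / self.gradient_accum_steps
    #     if self.do_grad_scaling:
    #         self.scaler.scale(loss).backward()
    #         self.scaler.unscale_(optimizer)
    #         torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
    #         self.scaler.step(optimizer)
    #         self.scaler.update()
    #     else:
    #         loss.backward()
    #         optimizer.step()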
    def local_rank(self):
        # Local (per-node) rank; defaults to 0 when not launched via torchrun.
        r = os.environ.get("LOCAL_RANK")
        r = 0 if r is None else int(r)
        return r

    def rank(self):
        # Global rank; defaults to 0 when not launched via torchrun.
        r = os.environ.get("RANK")
        r = 0 if r is None else int(r)
        return r

    def update_writer(self, split):
        # Push the epoch-level averages of every tracked loss to the writer.
        for l in self.metric_names:
            l_name = '{}/{}_loss'.format(split, l)
            if self.tracker.get_loss('epoch', l_name):
                self.writer.add_scalar(l_name, self.tracker.get_loss('epoch', l_name), self.epoch)

    def add_logs(self, split, log=None, total_loss=None):
        # Accumulate per-step losses into both the epoch and iteration trackers.
        for t in ['epoch', 'iter']:
            if log is not None:
                for l in self.metric_names:
                    l_name = '{}/{}_loss'.format(split, l)
                    if l_name in log:
                        self.tracker.add(t, l_name, log[l_name])
            if total_loss is not None:
                self.tracker.add(t, '{}/total_loss'.format(split), total_loss)

    def save_reconst(self, x, x_hat, pred_dir, step, num_display=16):
        # Save a tiled grid of reconstructions (rank 0 only).
        if self.rank() == 0:
            x = x.cpu()
            x_hat = x_hat.cpu()
            bsz = min(x.size(0), num_display)
            n_sq = int(np.ceil(bsz ** (1 / 2)))
            h = n_sq
            w = n_sq
            # x = x[:bsz]
            # tiled_x = tile_image(x, h=h, w=w)
            # save_x = os.path.join(pred_dir, 'recon-E{}-{}-{}-original.png'.format(self.epoch, step, self.rank()))
            # pil_x = to_pil_image(tiled_x)
            # pil_x.save(save_x)
            x_hat = x_hat[:bsz]
            tiled_x_hat = tile_image(x_hat, h=h, w=w)
            save_x_hat = os.path.join(pred_dir, 'recon-E{}-{}-{}-recon.png'.format(self.epoch, step, self.rank()))
            pil_x_hat = to_pil_image(tiled_x_hat)
            pil_x_hat.save(save_x_hat)
        if self.use_torch_dist:
            dist.barrier()

    def create_subplot(self, images, outputs, pred_dir, step, num_display=25):
        # Save per-image subplots (original, center, winner, top-k samples) and
        # tiled grids of all hypotheses (rank 0 only).
        if 'xk_hat' in outputs and self.subplots_cfg.active and self.rank() == 0:
            images = images.permute(0, 2, 3, 1).cpu()
            centers = outputs['xc_hat'].permute(0, 2, 3, 1).cpu()
            win_sample = outputs['xw_hat'].permute(0, 2, 3, 1).cpu()
            topk_sample = outputs['xk_hat'].permute(0, 1, 3, 4, 2).cpu()
            titles = ['Original', 'C', 'W'] + ['{}'.format(i + 1) for i in range(self.subplots_cfg.columns - 3)]
            plots = []
            for i, (img, c, w, samp) in enumerate(zip(images, centers, win_sample, topk_sample)):
                plot = [img, c, w]
                for s in samp[:self.subplots_cfg.columns - 3]:
                    plot.append(s)
                plots.append(plot)
                if (i + 1) % self.subplots_cfg.rows == 0:
                    save_path = os.path.join(pred_dir, 'subplots-E{}-{}-{}-{}.png'.format(self.epoch, step, i, self.rank()))
                    save_mega_subplot(save_path, plots, titles)
                    plots = []

            all_sample = outputs['xd_hat'].cpu()
            centers = centers.permute(0, 3, 1, 2)
            for i, (c, samples) in enumerate(zip(centers, all_sample)):
                bsz = min(samples.size(0), num_display)
                # Largest grid side n >= 3 such that n * n hypotheses are available.
                n = max([didx for didx in range(3, int(num_display ** (1 / 2) + 1)) if didx ** 2 <= bsz])
                n_sq = int(n ** 2)
                samples[0] = c
                samples = samples[:n_sq]
                tiled_image = tile_image(samples, n)
                save_path = os.path.join(pred_dir, 'tiles-E{}-{}-{}-{}.png'.format(self.epoch, step, i, self.rank()))
                pil_image = to_pil_image(tiled_image)
                pil_image.save(save_path)
        if self.use_torch_dist:
            dist.barrier()

    def create_center_subplot(self, outputs, pred_dir):
        # Save tiled grids of samples with the center prediction in the first cell.
        if self.subplots_cfg.active and 'xd_hat' in outputs and self.rank() == 0:
            xc_hat = outputs['xc_hat'].cpu()
            xd_hat = outputs['xd_hat'].cpu()
            for idx, (c, ds) in enumerate(zip(xc_hat, xd_hat)):
                n = int(ds.size(0) ** (1 / 2))
                ds[0] = c
                tiled_image = tile_image(ds, n)
                save_path = os.path.join(pred_dir, 'eval-tiles-E{}-{}-{}.png'.format(self.epoch, idx, self.rank()))
                pil_image = to_pil_image(tiled_image)
                pil_image.save(save_path)
        if self.use_torch_dist:
            dist.barrier()

    def create_transfer_tiles(self, outputs, pred_dir, names):
        # Save transfer-test reconstructions as an (n_emb2 + 1) x (n_emb2 + 1) grid per image.
        if self.rank() == 0:
            centers = outputs['xc_hat'].cpu()
            recons = outputs['xd_hat'].cpu()
            for idx, (c, rec, n) in enumerate(zip(centers, recons, names)):
                h = self.cfg.model.vq.n_emb2 + 1
                w = h
                rec = rec[:, :w, :, :, :]
                # Replace the first sample with the center prediction.
                rec[:, 0, :, :, :] = c
                channels, height, width = rec.size(2), rec.size(3), rec.size(4)
                rec = rec.permute(2, 0, 3, 1, 4)  # channels, h, height, w, width
                rec = rec.contiguous().view(channels, h * height, w * width)
                save_path = os.path.join(pred_dir, 'seq-transfer-tiles-E{}-{}-{}-{}.png'.format(self.epoch, idx, self.rank(), n))
                pil_image = to_pil_image(rec)
                pil_image.save(save_path)
        if self.use_torch_dist:
            dist.barrier()
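    # ------------------------------------------------------------------
    # FID / precision-recall evaluation.
    #
    # run_fid() drives the evaluation: test_vae_fid() draws samples through
    # create_generator(), pushes them through an InceptionV3 feature extractor,
    # reduces the resulting statistics across ranks, and compares them against
    # the precomputed train/eval statistics stored in `self.fid_dir`.
    # ------------------------------------------------------------------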
    def run_fid(self, vae_model, dataset, hypothese_count, pred_dir, stage, seq_model=None, seq_cond_model=None):
        if not os.path.exists(self.fid_file):
            return float('inf')
        if self.use_torch_dist:
            dist.barrier()

        scores = self.test_vae_fid(vae_model, dataset, hypothese_count, seq_model=seq_model, seq_cond_model=seq_cond_model)
        precision, recall, f_beta, f_beta_inv, train_fid, eval_fid, m, s, a = scores
        self.writer.add_scalar("{}/fid_train".format(stage), train_fid, self.epoch)
        self.writer.add_scalar("{}/fid_eval".format(stage), eval_fid, self.epoch)
        self.writer.add_scalar("{}/f_beta".format(stage), f_beta, self.epoch)
        self.writer.add_scalar("{}/f_beta_inv".format(stage), f_beta_inv, self.epoch)
        self.logger.info("[eval-{}] fid: train={:.4f} eval={:.4f} fbeta: {:.4f} fbeta_inv: {:.4f}".format(
            stage, train_fid, eval_fid, f_beta, f_beta_inv))

        metrics = {
            'rank': self.rank(),
            'epoch': self.epoch,
            'precision': precision,
            'recall': recall,
            'f_beta': f_beta,
            'f_beta_inv': f_beta_inv,
            'train_fid': train_fid,
            'eval_fid': eval_fid,
        }

        # Save metrics to pickle.
        save_path = os.path.join(pred_dir, 'metrics-{}-{}.pkl'.format(stage, self.rank()))
        self.metrics.append(metrics)
        # pickle_save(self.metrics, save_path)

        # Save activations for later - activations are too big.
        save_path = os.path.join(pred_dir, 'activations-{}-{}.npz'.format(stage, self.rank()))
        # save_statistics(save_path, m=m, s=s, a=a)

        # Plot train and eval FID over epochs.
        save_path = os.path.join(pred_dir, 'fidscore_{}.png'.format(self.rank()))
        save_fid_line(self.metrics, save_path)

        # Plot the PRD curve for the current epoch, and one plot covering the last five epochs.
        save_path = os.path.join(pred_dir, 'prdplot-single-{}.png'.format(self.rank()))
        plot_prd([[precision, recall]], out_path=save_path)
        plot_input = []
        labels = []
        for m in self.metrics[-5:]:
            labels.append(m['epoch'])
            plot_input.append([m['precision'], m['recall']])
        save_path = os.path.join(pred_dir, 'prdplot-all-{}.png'.format(self.rank()))
        plot_prd(plot_input, labels=labels, out_path=save_path)
        return eval_fid

    def test_vae_fid(self, vae_model, dataset, hypothese_count, seq_model=None, seq_cond_model=None):
        dims = 2048
        device = self.fid_device
        if self.use_torch_dist:
            size = float(dist.get_world_size())
            num_sample_per_gpu = int(np.ceil(self.total_fid_samples / size))
        else:
            num_sample_per_gpu = self.total_fid_samples

        g = self.create_generator(vae_model, dataset, self.bsz, num_sample_per_gpu, hypothese_count,
                                  seq_model=seq_model, seq_cond_model=seq_cond_model)
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx], model_dir=self.fid_dir).to(device)
        m, s, a = compute_statistics_of_generator(g, model, self.bsz, dims, device, max_samples=num_sample_per_gpu)

        # Share m, s and the activations across ranks.
        if self.use_torch_dist:
            dist.barrier()
            m = torch.from_numpy(m).cuda(self.device_id)
            s = torch.from_numpy(s).cuda(self.device_id)
            a = torch.from_numpy(a).cuda(self.device_id)
            average_tensor(m)
            average_tensor(s)
            average_tensor(a)
            m = m.cpu().numpy()
            s = s.cpu().numpy()
            a = a.cpu().numpy()

        # Load the precomputed statistics of the train and eval splits.
        train_fid_path = os.path.join(self.fid_dir, self.cfg.data.name + '-train.npz')
        eval_fid_path = os.path.join(self.fid_dir, self.cfg.data.name + '-eval.npz')
        t_m0, t_s0, _ = load_statistics(train_fid_path)
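        # The generated activations are compared against the eval statistics in
        # two ways: (1) PRD curves computed from the raw Inception embeddings,
        # summarized by the max F_beta / F_{1/beta} pair, and (2) the Frechet
        # distance between the fitted Gaussians,
        # FID = ||mu_1 - mu_2||^2 + Tr(S_1 + S_2 - 2 (S_1 S_2)^{1/2}),
        # as implemented in thirdparty.fid.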
        e_m0, e_s0, e_a0 = load_statistics(eval_fid_path)

        # PRD needs the same number of embeddings on both sides; truncate to match.
        eval_size = e_a0.shape[0]
        if a.shape[0] != eval_size:
            # self.logger.info("PRD embedding sizes differ ({} vs {}). Automatically compensating.".format(a.shape[0], eval_size))
            if a.shape[0] > eval_size:
                a = a[:eval_size, :]
            elif a.shape[0] < eval_size:
                e_a0 = e_a0[:a.shape[0], :]

        precision, recall = compute_prd_from_embedding(e_a0, a)
        f_beta, f_beta_inv = prd_to_max_f_beta_pair(precision, recall, beta=8)
        # precision, recall = np.mean(precision), np.mean(recall)
        train_fid = calculate_frechet_distance(t_m0, t_s0, m, s)
        eval_fid = calculate_frechet_distance(e_m0, e_s0, m, s)
        return precision, recall, f_beta, f_beta_inv, train_fid, eval_fid, m, s, a

    def create_generator(self, vae_model, dataset, batch_size, num_total_samples, hypothese_count, seq_model=None, seq_cond_model=None):
        # Dispatch to the VAE-only or sequence-model sample generator; both are
        # implemented by the stage-specific subclasses.
        if seq_model is None:
            return self.create_generator_vae(vae_model, dataset, batch_size, num_total_samples, hypothese_count)
        else:
            return self.create_generator_seq(vae_model, seq_model, batch_size, num_total_samples, hypothese_count,
                                             seq_cond_model=seq_cond_model)
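
# ---------------------------------------------------------------
# Illustrative usage (a sketch, not part of this module): a concrete runner is
# expected to subclass BaseRunner, implement create_generator_vae() /
# create_generator_seq(), and call run_fid() during evaluation. Names such as
# `MyRunner`, `cfg`, `vae`, and `val_set` below are placeholders.
#
#     runner = MyRunner(cfg)
#     eval_fid = runner.run_fid(vae, val_set, hypothese_count=1,
#                               pred_dir=cfg.save_dir, stage='vae')
# ---------------------------------------------------------------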