# mvq/runner/base.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import copy
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from torchvision.transforms.functional import to_pil_image

from thirdparty.fid.fid_score import compute_statistics_of_generator, load_statistics, calculate_frechet_distance, save_statistics
from thirdparty.fid.inception import InceptionV3
from thirdparty.prds.prd_score import compute_prd_from_embedding, prd_to_max_f_beta_pair
from thirdparty.prds.prd_score import plot as plot_prd

from utils import (
  Logger, Writer, ResultTracker,
  save_mega_subplot, 
  tile_image,
  save_fid_line, 
  average_tensor,
)

class BaseRunner(object):
  def __init__(self, cfg):
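    # Common runner state: training counters, config handles, logging, latent grid shapes,
    # FID/eval settings, metric tracking, and mixed-precision flags.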
    self.epoch = 0
    self.global_step = 1

    self.cfg = cfg
    self.bsz = getattr(cfg.batch_size_per_gpu, cfg.cur_stage)

    self.device_id = self.local_rank()

    self.use_torch_dist = cfg.torch_dist.use
    self.display = cfg.train.intervals.display

    self.logger = Logger(self.rank(), cfg.save_dir) 
    self.writer = Writer(self.rank(), cfg.save_dir) 
    
    self.logger.info("Runner Initialized - Rank=[{}/{}]".format(self.local_rank(), self.rank()))

    self.data_name = cfg.data.name
    self.use_gan = cfg.model.gan.use

    self.gradient_accum_steps = self.cfg.train.gradient_accum_steps

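    # Spatial (batch, height, width) shape of each discrete latent sequence viewed as a grid.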
    if self.cfg.model.num_latent_space == 1:
      self.seq_shape1 = [self.bsz, self.cfg.model.discrete_seq_len[-1], self.cfg.model.discrete_seq_len[-1]]
    else:
      p1 = int(np.ceil(self.cfg.model.discrete_seq_len[0] ** (1/2))) 
      p2 = int(np.ceil(self.cfg.model.discrete_seq_len[1] ** (1/2))) 
      self.seq_shape1 = [self.bsz, p1, p1]
      self.seq_shape2 = [self.bsz, p2, p2]

    #Metrics
    self.best_score = {} 

    self.transfer_cfg = self.cfg.eval.transfer_test
    self.subplots_cfg = self.cfg.eval.save_imgs.subplots
    self.fid_cfg = self.cfg.eval.fid
    self.fid_device = 'cuda:{}'.format(self.device_id) 
    self.fid_dir = os.path.join(self.cfg.data_path, 'fid-stats')
    self.fid_file = os.path.join(self.fid_dir, self.cfg.data.name + '-train.npz')

    if not os.path.exists(self.fid_file):
      self.logger.info("FID file does not exist: {}".format(self.fid_file))
      raise FileNotFoundError("Please precompute FID statistics first: python precompute_fid.py ...")

    self.total_fid_samples = self.cfg.eval.fid.samples
    if self.data_name.startswith('imagenet'):
      self.total_fid_samples = 50000
    elif self.data_name.startswith('celeba'):
      self.total_fid_samples = 19961


    self.metrics = []
    self.metric_names = [
      'mse', 'dist', 'gan', 'disc',
      'cb', 'cb1', 'cb2', 'nll',
      'rec', 'p', 'ce1', 'ce2','g', 
      'var', 'mhd',
      ]
    
    self.print_names = [
      'mse', 'dist', 'gan', 'disc',
      'cb', 'cb1', 'cb2', 
      'rec', 'p', 'ce1', 'ce2',
      'zc_dlen', 'var', 'mhd'
      ]
    
    self.tracker = ResultTracker(['epoch', 'iter'], print_names=self.print_names)

    self.lr_scheduler = None
    self.scaler = None
    self.max_grad_norm = cfg.train.max_grad_norm

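    # Mixed precision: fp16 enables torch.cuda.amp autocast and gradient scaling.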
    self.use_amp = False
    self.do_grad_scaling = False
    self.use_fp16 = self.cfg.fp16.use

    if self.use_fp16: 
      self.use_amp = True
      self.do_grad_scaling = True
      self.amp_dtype = torch.float16 if self.cfg.fp16.use else torch.bfloat16
      self.scaler = torch.cuda.amp.GradScaler()


  def local_rank(self):
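    # Per-node rank from the LOCAL_RANK environment variable (set by torchrun); defaults to 0.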
    r = os.environ.get("LOCAL_RANK")
    r = 0 if r is None else int(r)
    return r

  def rank(self):
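    # Global rank from the RANK environment variable; defaults to 0 for single-process runs.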
    r = os.environ.get("RANK")
    r = 0 if r is None else int(r)
    return r

  def update_writer(self, split):
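    # Write the tracked per-epoch losses for the given split to the summary writer.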
    for l in self.metric_names:
      l_name = '{}/{}_loss'.format(split, l)
      loss = self.tracker.get_loss('epoch', l_name)
      if loss:
        self.writer.add_scalar(l_name, loss, self.epoch)
  
  def add_logs(self, split, log=None, total_loss=None):
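    # Add per-metric losses (and optionally the total loss) to both the epoch and iteration trackers.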
    for t in ['epoch', 'iter']:
      if log is not None:
        for l in self.metric_names:
          l_name = '{}/{}_loss'.format(split, l)
          if l_name in log:
            self.tracker.add(t, l_name, log[l_name])

      if total_loss is not None:
        self.tracker.add(t, '{}/total_loss'.format(split), total_loss)
    
  def save_reconst(self, x, x_hat, pred_dir, step, num_display=16):
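    # On rank 0, save a tiled image of up to num_display reconstructions; all ranks sync on a barrier.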
    if self.rank() == 0:
      x = x.cpu() 
      x_hat = x_hat.cpu() 
      bsz = min(x.size(0), num_display)
      n_sq = int(np.ceil(bsz ** (1/2)))
      h = n_sq
      w = n_sq

      #x = x[:bsz]
      #tiled_x = tile_image(x, h=h, w=w)
      # save_x = os.path.join(pred_dir, 'recon-E{}-{}-{}-original.png'.format(self.epoch, step, self.rank()))
      # pil_x = to_pil_image(tiled_x)
      # pil_x.save(save_x)

      x_hat = x_hat[:bsz]
      tiled_x_hat = tile_image(x_hat, h=h, w=w)
      save_x_hat = os.path.join(pred_dir, 'recon-E{}-{}-{}-recon.png'.format(self.epoch, step, self.rank()))
      pil_x_hat = to_pil_image(tiled_x_hat)
      pil_x_hat.save(save_x_hat)

    if self.use_torch_dist:
      dist.barrier()

  def create_subplot(self, images, outputs, pred_dir, step, num_display=25):
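    # On rank 0, save per-example subplots (original, center, winner, top-k hypotheses) and
    # tiled grids of all sampled hypotheses with the center in the first cell.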

    if 'xk_hat' in outputs and self.subplots_cfg.active and self.rank() == 0:
      images = images.permute(0, 2, 3, 1).cpu()
      centers = outputs['xc_hat'].permute(0, 2, 3, 1).cpu()
      win_sample = outputs['xw_hat'].permute(0, 2, 3, 1).cpu()
      topk_sample = outputs['xk_hat'].permute(0, 1, 3, 4, 2).cpu()

      titles = ['Original', 'C', 'W'] + ['{}'.format(i + 1) for i in range(self.subplots_cfg.columns - 3)]

      plots = []
      for i, (img, c, w, samp) in enumerate(zip(images, centers, win_sample, topk_sample)):
        plot = [img, c, w]
        for s in samp[:self.subplots_cfg.columns - 3]:
          plot.append(s)
        plots.append(plot)

        if (i + 1) % self.subplots_cfg.rows == 0:
          save_path = os.path.join(pred_dir, 'subplots-E{}-{}-{}-{}.png'.format(self.epoch, step, i, self.rank()))
          save_mega_subplot(save_path, plots, titles)
          plots = []

      all_sample = outputs['xd_hat'].cpu()
      centers = centers.permute(0, 3, 1, 2)
      for i, (c, samples) in enumerate(zip(centers, all_sample)):
        bsz = min(samples.size(0), num_display)
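        # Largest grid side (>= 3, at most sqrt(num_display)) whose square fits in the available samples.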
        n = max([didx for didx in range(3, int(num_display ** (1/2) + 1)) if didx ** 2 <= bsz])
        n_sq = int(n ** 2)
        #n = num_display ** (1/2) #int(samples.size(0) ** (1/2))
        samples[0] = c
        samples = samples[:n_sq]
        tiled_image = tile_image(samples, n)

        save_path = os.path.join(pred_dir, 'tiles-E{}-{}-{}-{}.png'.format(self.epoch, step, i, self.rank()))
        pil_image = to_pil_image(tiled_image)
        pil_image.save(save_path)
        
    if self.use_torch_dist:
      dist.barrier()

  def create_center_subplot(self, outputs, pred_dir):
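    # On rank 0, tile each center reconstruction with its sampled hypotheses into a square grid.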
    if self.subplots_cfg.active and 'xd_hat' in outputs and self.rank() == 0:
      xc_hat = outputs['xc_hat'].cpu()
      xd_hat = outputs['xd_hat'].cpu()

      for idx, (c, ds) in enumerate(zip(xc_hat, xd_hat)):
        n = int(ds.size(0) ** (1/2))
        ds[0] = c
        tiled_image = tile_image(ds, n)
        save_path = os.path.join(pred_dir, 'eval-tiles-E{}-{}-{}.png'.format(self.epoch, idx, self.rank()))
        pil_image = to_pil_image(tiled_image)
        pil_image.save(save_path)
    
    if self.use_torch_dist:
      dist.barrier()

  def create_transfer_tiles(self, outputs, pred_dir, names):
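    # On rank 0, lay out sequence-transfer reconstructions as an image grid per example,
    # with the center reconstruction placed in the first column.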

    if self.rank() == 0:
      centers = outputs['xc_hat'].cpu()
      recons = outputs['xd_hat'].cpu()

      for idx, (c, rec, n) in enumerate(zip(centers, recons, names)):
        h = self.cfg.model.vq.n_emb2 + 1
        w = h
        rec = rec[:,:w,:,:,:]
        #Replace first sample with center
        rec[:,0,:,:,:] = c

        channels, height, width = rec.size(2), rec.size(3), rec.size(4)
        rec = rec.permute(2, 0, 3, 1, 4)                              # channels, rows, height, cols, width
        rec = rec.contiguous().view(channels, h * height, w * width)

        save_path = os.path.join(pred_dir, 'seq-transfer-tiles-E{}-{}-{}-{}.png'.format(self.epoch, idx, self.rank(), n))
        pil_image = to_pil_image(rec)
        pil_image.save(save_path)

    if self.use_torch_dist:
      dist.barrier()

  def run_fid(self, vae_model, dataset, hypothese_count, pred_dir, stage, seq_model=None, seq_cond_model=None):
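    # Compute FID and precision/recall for generated samples, log the scores, and save
    # PRD plots plus an FID score chart to pred_dir.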
    if not os.path.exists(self.fid_file):
      return float('inf')

    if self.use_torch_dist:
      dist.barrier()

    scores = self.test_vae_fid(vae_model, dataset, hypothese_count, seq_model=seq_model, seq_cond_model=seq_cond_model)
    precision, recall, f_beta, f_beta_inv, train_fid, eval_fid, m, s, a = scores
    self.writer.add_scalar("{}/fid_train".format(stage), train_fid, self.epoch)
    self.writer.add_scalar("{}/fid_eval".format(stage), eval_fid, self.epoch)
    self.writer.add_scalar("{}/f_beta".format(stage), f_beta, self.epoch)
    self.writer.add_scalar("{}/f_beta_inv".format(stage), f_beta_inv, self.epoch)
    self.logger.info("[eval-{}] fid: train={:.4f} eval={:.4f} fbeta: {:.4f} fbeta_inv: {:.4f}".format(
        stage, train_fid, eval_fid, f_beta, f_beta_inv))

    metrics = {
        'rank': self.rank(), 'epoch': self.epoch,
        'precision': precision, 'recall': recall,
        'f_beta': f_beta, 'f_beta_inv': f_beta_inv,
        'train_fid': train_fid, 'eval_fid': eval_fid
        }

    #Save metrics to pickle
    save_path = os.path.join(pred_dir, 'metrics-{}-{}.pkl'.format(stage, self.rank()))
    self.metrics.append(metrics)
    #pickle_save(self.metrics, save_path)

    #Save activations for later - activations are too big.
    save_path = os.path.join(pred_dir, 'activations-{}-{}.npz'.format(stage, self.rank()))
    #save_statistics(save_path, m=m,s=s,a=a)

    #Make a chart for fid train and eval
    save_path = os.path.join(pred_dir, 'fidscore_{}.png'.format(self.rank()))
    save_fid_line(self.metrics, save_path)

    save_path = os.path.join(pred_dir, 'prdplot-single-{}.png'.format(self.rank()))
    plot_prd([[precision, recall]], out_path=save_path)

    plot_input = []
    labels = []
    for met in self.metrics[-5:]:
      labels.append(met['epoch'])
      plot_input.append([met['precision'], met['recall']])

    save_path = os.path.join(pred_dir, 'prdplot-all-{}.png'.format(self.rank()))
    plot_prd(plot_input, labels=labels, out_path=save_path)

    return eval_fid

  def test_vae_fid(self, vae_model, dataset, hypothese_count, seq_model=None, seq_cond_model=None):
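    # Generate samples, extract InceptionV3 features, and compare their statistics against
    # precomputed train/eval statistics to obtain FID and PRD scores.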
    dims = 2048
    device = self.fid_device
    if self.use_torch_dist:
      size = float(dist.get_world_size())
      num_sample_per_gpu = int(np.ceil(self.total_fid_samples / size))
    else:
      num_sample_per_gpu = self.total_fid_samples

    g = self.create_generator(vae_model, dataset, self.bsz, num_sample_per_gpu, 
      hypothese_count, seq_model=seq_model, seq_cond_model=seq_cond_model)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx], model_dir=self.fid_dir).to(device)
    m, s, a = compute_statistics_of_generator(g, model, self.bsz, dims, device, max_samples=num_sample_per_gpu)

    # average m, s and a across ranks
    if self.use_torch_dist:
      dist.barrier()

      m = torch.from_numpy(m).cuda(self.device_id)
      s = torch.from_numpy(s).cuda(self.device_id)
      a = torch.from_numpy(a).cuda(self.device_id)

      average_tensor(m)
      average_tensor(s)
      average_tensor(a)
      m = m.cpu().numpy()
      s = s.cpu().numpy()
      a = a.cpu().numpy()

    # load precomputed m, s
    train_fid_path = os.path.join(self.fid_dir, self.cfg.data.name + '-train.npz')
    eval_fid_path = os.path.join(self.fid_dir, self.cfg.data.name + '-eval.npz')

    t_m0, t_s0, _ = load_statistics(train_fid_path)
    e_m0, e_s0, e_a0 = load_statistics(eval_fid_path)

    eval_size = e_a0.shape[0]
    if a.shape[0] != eval_size:
      #self.logger.info("PRD Embedding sizes not the same ({}={}). Automatically compensating".format(a.shape[0], eval_size))
      if a.shape[0] > eval_size:
        a = a[:eval_size, :]
      elif a.shape[0] < eval_size:
        e_a0 = e_a0[:a.shape[0], :]

    precision, recall = compute_prd_from_embedding(e_a0, a)
    f_beta, f_beta_inv = prd_to_max_f_beta_pair(precision, recall, beta=8)

    #precision, recall = np.mean(precision), np.mean(recall)
    train_fid = calculate_frechet_distance(t_m0, t_s0, m, s)
    eval_fid = calculate_frechet_distance(e_m0, e_s0, m, s)

    return precision, recall, f_beta, f_beta_inv, train_fid, eval_fid, m, s, a

  def create_generator(self, vae_model, dataset, batch_size, num_total_samples, hypothese_count, seq_model=None, seq_cond_model=None):
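    # Return the sample generator used for FID: plain VAE sampling, or sequence-model sampling when seq_model is provided.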
    if seq_model is None:
      return self.create_generator_vae(vae_model, dataset, batch_size, num_total_samples, hypothese_count)
    else:
      return self.create_generator_seq(vae_model, seq_model, batch_size, num_total_samples, hypothese_count, seq_cond_model=seq_cond_model)