# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.distributed as dist

from utils import create_eval_dir

from .base import BaseRunner

class VAERunner(BaseRunner):
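  """Runner that trains and evaluates the VAE model, optionally together with an
  adversarial discriminator (`models['disc']`).

  Typical usage (a sketch; epoch/checkpoint bookkeeping is assumed to be handled
  by BaseRunner or the training script):

      runner = VAERunner(cfg)
      for epoch in range(num_epochs):
        runner.train(train_loader, models, opts)
        fid = runner.eval(eval_loader, models, epoch=epoch)
  """
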
  def __init__(self, cfg):
    super().__init__(cfg)

  def train(self, train_loader, models, opts, hypothese_count=32):
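    """Run one training epoch.

    Args:
      train_loader: iterable of (image, target) batches.
      models: dict holding the 'vae' model and, when adversarial training is
        enabled, the 'disc' discriminator.
      opts: dict of optimizers keyed like `models`.
      hypothese_count: number of hypotheses requested from the VAE forward pass
        (forwarded as-is to `models['vae']`).
    """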

    self.tracker.reset_all()

    for m in models.values():
      m.train()

    for o in opts.values():
      o.zero_grad()
    # Unwrap the (Distributed)DataParallel wrapper, if any, to reach the underlying VAE module
    vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']
    
    steps_in_epoch = len(train_loader)
    for step, (x, tgt) in enumerate(train_loader):
      x = x.cuda(self.device_id)
      tgt = tgt.cuda(self.device_id) 
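      # Uniform dequantization (assuming inputs in [0, 1]): rescale by 255/256 and
      # add U(0, 1/256) noise so discrete pixel intensities become continuous in [0, 1)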
      x = x / 256. * 255. + torch.rand_like(x) / 256.

      # Run the optimization pass(es) for this batch and accumulate losses for logging
      opt_count = 2 if self.use_gan else 1
      total_loss = 0
      
      for opt_idx in range(opt_count):
        # opt_idx 0 optimizes the VAE/generator, opt_idx 1 the discriminator
        opt_model = 'vae' if opt_idx == 0 else 'disc'

        x_hat, disc_inputs, loss, outputs, log = models['vae'](
          x=x, 
          hypothese_count=hypothese_count, 
          optimizer_idx=opt_idx, 
          split='train')

        if 'disc' in models:
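          # Discriminator/GAN loss term for the pass selected by optimizer_idx,
          # together with its own log dict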
          gan_loss, disc_log = models['disc']( 
              nll_loss=disc_inputs['nll_loss'],
              codebook_loss=disc_inputs['cb_loss'],
              inputs=x, 
              reconstructions=x_hat, 
              optimizer_idx=opt_idx, 
              global_step=self.global_step,
              last_layer=vae_module.get_last_layer(),
              split='train'
              )
          
          loss += gan_loss
          log = {**log, **disc_log}

        # Store logs
        total_loss += loss.detach().item()
        self.add_logs(split='train', log=log)

        # Compute gradients
        if self.gradient_accum_steps > 1:
          loss = loss / self.gradient_accum_steps

        if self.do_grad_scaling:
          # Mixed precision: scale the loss before backward() so small gradients do
          # not underflow in fp16; scaler.step()/update() later unscales them and
          # skips the optimizer step if they contain inf/NaN
          self.scaler.scale(loss).backward()
        else:
          loss.backward()

        # Gradient accumulation: only step the optimizer every `gradient_accum_steps`
        # micro-batches (or on the final step of a short epoch)
        if (step + 1) % self.gradient_accum_steps == 0 or (
            # last step in epoch but step is always smaller than gradient_accum_steps
            steps_in_epoch <= self.gradient_accum_steps
            and (step + 1) == steps_in_epoch):
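          # Apply the accumulated gradients: step the optimizer for the model updated
          # on this pass, then clear its gradients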

          if self.do_grad_scaling:
            #scale_before = self.scaler.get_scale()
            self.scaler.step(opts[opt_model])
            self.scaler.update()
            #scale_after = self.scaler.get_scale()
            #optimizer_was_run = scale_before <= scale_after
            opts[opt_model].zero_grad()
          else:
            opts[opt_model].step()
            opts[opt_model].zero_grad()

      self.add_logs(split='train', total_loss=total_loss)

      self.global_step += 1

      if isinstance(self.display, int) and step % self.display == 0 and step > 0:
        self.logger.info("E{:02d} S: {:03d} {}".format(
          self.epoch, step, self.tracker.loss_str('iter')))
        self.tracker.reset_interval('iter')

    self.logger.info("E{:02d} (train) {}".format(
      self.epoch, self.tracker.loss_str('epoch')))
    self.update_writer('train')

  def eval(self, eval_loader, models, hypothese_count=32, epoch=0, gen_samples=True):
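    """Run evaluation: optionally compute FID, log the losses, and save reconstructions.

    Returns the FID score, or `inf` when FID was not computed for this epoch.
    """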

    self.tracker.reset_all()

    if self.use_torch_dist:
      dist.barrier()

    for m in models.values():
      m.eval()

    # Request at least as many hypotheses as there are subplot columns, presumably
    # so the reconstruction grid can be filled
    hypothese_count = max(hypothese_count, self.subplots_cfg.columns)
    vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

    pred_dir = create_eval_dir(self.cfg.sample_dirs['vae'], 'vae_eval', self.epoch)

    fid_score = float('inf')
    is_best = False  # defined up front so it is always available below
    if self.cfg.eval.fid.use and epoch > self.cfg.train.epochs.warmup:

      fid_score = self.run_fid(vae_module, eval_loader, hypothese_count, pred_dir, stage='vae')
      self.writer.add_scalar("eval/fid_vae", fid_score, self.epoch)
      is_best = fid_score < self.best_score['vae']

    for step, (x, tgt) in enumerate(eval_loader):
      x = x.cuda(self.device_id)
      tgt = tgt.cuda(self.device_id)

      with torch.no_grad():

        x_hat, disc_inputs, loss, outputs, log = models['vae'](
          x=x, 
          hypothese_count=hypothese_count, 
          optimizer_idx=0, 
          split='eval',
          return_outputs=self.cfg.eval.return_outputs)

        if 'disc' in models:
          gan_loss, disc_log = models['disc']( 
              nll_loss=disc_inputs['nll_loss'],
              codebook_loss=disc_inputs['cb_loss'],
              inputs=x, 
              reconstructions=x_hat, 
              optimizer_idx=0, 
              global_step=self.global_step,
              last_layer=vae_module.get_last_layer(),
              split='eval'
              )
          loss += gan_loss
          log = {**log, **disc_log}
      
      total_loss = loss.detach().item()
      self.add_logs(split='eval', log=log, total_loss=total_loss)
      # Save reconstruction samples
      if self.cfg.eval.return_outputs and (gen_samples or is_best):
        self.save_reconst(x, x_hat, pred_dir, step)

    self.update_writer('eval')

    if self.use_torch_dist:
      dist.barrier()

    self.logger.info("Epoch {:02d} (eval) | {}".format(self.epoch, self.tracker.loss_str('epoch')))
    return fid_score

  def create_generator_vae(self, vae_model, dataset, batch_size, num_total_samples, hypothese_count, topk=10):
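    """Yield batches of VAE reconstructions until roughly `num_total_samples`
    images have been produced (e.g. as the generated stream of an FID
    computation). `batch_size` and `topk` are currently unused here.
    """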
    total_samples = 0

    for _, (x, tgt) in enumerate(dataset):
      if total_samples >= num_total_samples:
        break

      x = x.cuda(self.device_id)

      with torch.no_grad():
        x_hat, _, _, _, _ = vae_model(
          x=x, 
          hypothese_count=hypothese_count, split='eval')
      
      total_samples += x_hat.size(0)
      yield x_hat.float()