# mvq/runner/seq.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import time
import numpy as np
from runner import vae
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.nn import functional as F

from utils import create_eval_dir
from .base import BaseRunner

class SEQRunner(BaseRunner):
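  """Runner for training and evaluating autoregressive sequence priors over VQ codebook indices."""
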
  def __init__(self, cfg):
    super(SEQRunner, self).__init__(cfg)
    self.infer_time = {'seq': [], 'rec': []}

  def train(self, train_loader, models, opts, hypothese_count=32):
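    """Train the autoregressive prior(s) over codebook indices for one epoch.

    Codebook indices are extracted from the VAE under torch.no_grad(); 'seq_base'
    is trained on the base indices and, for two-stage models, 'seq_cond' on the
    second set of indices conditioned on the base ones. Gradients are accumulated
    over `gradient_accum_steps` iterations before each optimizer step.
    """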

    if self.cfg.debug:
      self.logger.info("Starting training loop")
    
    # Check if sequence model is two stages
    two_stage_flag = 'seq_cond' in models

    self.tracker.reset_all()
    for m in models.values():
      m.train()

    opts['seq_base'].zero_grad()
    if two_stage_flag:
      opts['seq_cond'].zero_grad()
    
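    # Unwrap DistributedDataParallel (if used) to reach the underlying VAE module.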
    vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

    steps_in_epoch = len(train_loader)

    for step, (x, tgt) in enumerate(train_loader):

      if self.cfg.gpu.use:
        x = x.cuda(self.device_id)
        tgt = tgt.cuda(self.device_id)

      condition = tgt if self.cfg.model.gan.disc_conditional else None

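      # Extract discrete codebook indices from the inputs with the VAE (no gradients).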
      with torch.no_grad():
        cb_ind = vae_module.sample_codebook(x)

        if two_stage_flag:
          cb_ind, cb_ind1 = cb_ind

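      # Base sequence prior over the first set of indices; its forward pass returns the training loss.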
      _, _, loss = models['seq_base'](
        cb_ind, condition=condition)

      loss.backward()
      
      if (step + 1) % self.gradient_accum_steps == 0 or (
            # last step of an epoch that has fewer steps than gradient_accum_steps
            steps_in_epoch <= self.gradient_accum_steps
            and (step + 1) == steps_in_epoch):
        opts['seq_base'].step()
        opts['seq_base'].zero_grad()

      for t in ['epoch', 'iter']:
        self.tracker.add(t, 'ce1_loss', loss.item())

      if two_stage_flag:
        
        if self.cfg.debug:
          self.logger.info("seq_cond.forward")

        _, _, loss = models['seq_cond'](cb_ind1, condition=cb_ind)
        if self.cfg.debug:
          self.logger.info("seq_cond.backward")

        loss.backward()

        if (step + 1) % self.gradient_accum_steps == 0 or (
              # last step of an epoch that has fewer steps than gradient_accum_steps
              steps_in_epoch <= self.gradient_accum_steps
              and (step + 1) == steps_in_epoch):
          if self.cfg.debug:
            self.logger.info("seq_cond.step")
          opts['seq_cond'].step()
          opts['seq_cond'].zero_grad()
        for t in ['epoch', 'iter']:
          self.tracker.add(t, 'ce2_loss', loss.item())
        
      if isinstance(self.display, int) and step % self.display == 0 and step > 0:
        self.logger.info("E{:02d} S: {:03d} {}".format(self.epoch, step, self.tracker.loss_str('iter')))
        self.tracker.reset_interval('iter')

    if self.cfg.debug:
      self.logger.info("Finished training loop")

    self.update_writer('train')

    self.logger.info("E{:02d} (seq train) | {} ".format(
      self.epoch, self.tracker.loss_str('epoch')))
      
    if self.use_torch_dist:
      dist.barrier()

  def eval(self, eval_loader, models, hypothese_count=1, temp=1.0, gen_samples=True, **kwargs):
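    """Evaluate the sequence prior(s) and optionally write generated samples.

    Always uses a single hypothesis. Optionally computes FID over generated
    samples, then samples codebook indices from the prior(s), decodes them with
    the VAE, and writes sample grids (or transfer tiles) to an epoch-specific
    directory. Returns the FID score, or inf when FID evaluation is disabled.
    """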
    # Force a single hypothesis for sequence-model evaluation
    hypothese_count = 1

    # Check if sequence model is two stages
    two_stage_flag = 'seq_cond' in models

    self.tracker.reset_all()
    for m in models.values():
      m.eval()

    vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

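    # Spatial shapes of the two codebook index grids (swapped for the 'vq2' backbone).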
    p1 = self.seq_shape1
    p2 = self.seq_shape2
    if self.cfg.model.backbone == 'vq2':
      p1, p2 = p2, p1

    pred_dir = create_eval_dir(self.cfg.sample_dirs['seq_base'], 'seq_eval', self.epoch)
    if self.cfg.debug:
      self.logger.info("Start eval loop")

    fid_score = float('inf')
    if self.cfg.eval.fid.use:
      # Only one hypothesis for testing
      fid_score = self.run_fid(models['vae'], eval_loader, 1, 
        pred_dir, stage='seq', seq_model=models['seq_base'], 
        seq_cond_model=models.get('seq_cond'))
      self.writer.add_scalar('eval/fid_seq', fid_score, self.epoch)
    
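    # Sample base indices from the unconditional prior; for two-stage models, sample
    # the second set of indices conditioned on them.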
    ind_1 = self.sample_seq_model(models['seq_base'], p1, temp=temp)
    ind_2 = None
    
    if two_stage_flag:
      ind_2 = self.sample_seq_model(models['seq_cond'], 
        p2, temp=temp, condition=Variable(ind_1))

      
    if self.cfg.eval.return_outputs and self.transfer_cfg.use and gen_samples:
      outputs = {}
      with torch.no_grad():
        out = vae_module.reconstruct_from_indices(ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)
      xd_hat = out['xd_hat'].unsqueeze(1)
      
      outputs['xc_hat'] = out['xc_hat']
      outputs['xd_hat'] = xd_hat

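      # Start from an all-zero conditional index map and sweep its last column through
      # every second-level embedding, collecting the resulting reconstructions as tiles.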
      cond_ind = torch.zeros_like(ind_2)
      num_embed2 = self.cfg.model.vq.n_emb2
      
      names = ind_2.view(-1).detach().cpu().tolist()
      for embed_idx in range(num_embed2):
        cond_ind[:, :, -1] = embed_idx
        
        with torch.no_grad():
          out = vae_module.reconstruct_from_indices(
            ind_1=ind_1, ind_2=cond_ind, hypothese_count=hypothese_count)
          xd_hat = out['xd_hat'].unsqueeze(1)
          outputs['xd_hat'] = torch.concat([outputs['xd_hat'], xd_hat], dim=1)
      
      self.create_transfer_tiles(outputs, pred_dir, names)

    elif self.cfg.eval.return_outputs and gen_samples:
      with torch.no_grad():
        outputs = vae_module.reconstruct_from_indices(
          ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)
      self.create_center_subplot(outputs, pred_dir)

    self.update_writer('eval')
    if self.cfg.debug:
      self.logger.info("Finished eval loop")

    if self.use_torch_dist:
      dist.barrier()

    #self.logger.info("Epoch {:02d} (seq eval) | {} ".format(self.epoch, self.tracker.loss_str('epoch')))
    return fid_score

  def create_generator_seq(self, vae_model, seq_model, batch_size, num_total_samples, hypothese_count, seq_cond_model=None, temp=1.0):
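    """Yield batches of images decoded from codebook indices sampled by the prior(s).

    Indices are sampled from `seq_model` (and, if given, from `seq_cond_model`
    conditioned on them), decoded with the VAE, and yielded batch by batch until
    roughly `num_total_samples` images have been produced.
    """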

    num_iters = int(np.ceil(num_total_samples / batch_size))
    total_samples = 0
    ind_2 = None

    vae_module = vae_model.module if hasattr(vae_model,'module') else vae_model
    
    p1 = self.seq_shape1
    p2 = self.seq_shape2
    if self.cfg.model.backbone == 'vq2':
      p1, p2 = p2, p1

    # Measure inference time
    self.infer_time = {'seq': [], 'rec': []}

    for _ in range(num_iters):
      if total_samples >= num_total_samples:
        break
      
      ind_1 = self.sample_seq_model(seq_model, p1, temp=temp)

      if seq_cond_model is not None:
        ind_2 = self.sample_seq_model(seq_cond_model, 
          p2, temp=temp, condition=Variable(ind_1))

      with torch.no_grad():
        outputs = vae_module.reconstruct_from_indices(
          ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)

      if 'xd_hat' not in outputs:
        out = outputs['xc_hat']
      else:
        out = torch.flatten(outputs['xd_hat'], start_dim=0, end_dim=1) 

      total_samples += out.size(0)
      yield out.float()
    

  def sample_seq_model(self, model, size, temp=1.0, condition=None):
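    """Autoregressively sample a grid of codebook indices of the given size.

    At each step the model scores the partially filled grid and the entry at the
    current position is drawn from a softmax over logits scaled by `temp`. An
    optional `condition` (e.g. base-level indices) is passed through to the model.
    """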
    cache = {}
    ind = torch.zeros(size, dtype=torch.long)

    if self.cfg.gpu.use:
      ind = ind.cuda(self.device_id)

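    # Fill the index grid one position at a time: re-score the partially sampled grid
    # (reusing the model cache) and draw entry (j, i) from the temperature-scaled softmax.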
    with torch.no_grad():
      for i in range(size[-1]):
        for j in range(size[-2]):
          out, cache, _ = model(Variable(ind), condition=condition, cache=cache)
          probs = F.softmax(out[:, :, j, i] / temp, dim=1).data
          val = torch.multinomial(probs, 1).view(-1)
          ind[:, j, i] = val

    return ind