# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import time

import numpy as np
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.nn import functional as F

from runner import vae
from utils import create_eval_dir

from .base import BaseRunner


class SEQRunner(BaseRunner):
    """Runner for training and evaluating autoregressive sequence models
    over VQ-VAE codebook indices."""

    def __init__(self, cfg):
        super(SEQRunner, self).__init__(cfg)
        self.infer_time = {'seq': [], 'rec': []}

    def train(self, train_loader, models, opts, hypothese_count=32):
        """Train the base (and optional conditional) sequence model on
        codebook indices sampled from the frozen VAE."""
        if self.cfg.debug:
            self.logger.info("Starting training loop")

        # Check if sequence model is two stages
        two_stage_flag = 'seq_cond' in models

        self.tracker.reset_all()
        for m in models.values():
            m.train()

        opts['seq_base'].zero_grad()
        if two_stage_flag:
            opts['seq_cond'].zero_grad()

        vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

        steps_in_epoch = len(train_loader)
        for step, (x, tgt) in enumerate(train_loader):
            if self.cfg.gpu.use:
                x = x.cuda(self.device_id)
                tgt = tgt.cuda(self.device_id)

            condition = tgt if self.cfg.model.gan.disc_conditional else None

            # Codebook indices come from the frozen VAE; no gradients needed.
            with torch.no_grad():
                cb_ind = vae_module.sample_codebook(x)
            if two_stage_flag:
                cb_ind, cb_ind1 = cb_ind

            _, _, loss = models['seq_base'](cb_ind, condition=condition)
            loss.backward()

            if (step + 1) % self.gradient_accum_steps == 0 or (
                    # or the last step of an epoch that has fewer steps than gradient_accum_steps
                    steps_in_epoch <= self.gradient_accum_steps and (step + 1) == steps_in_epoch):
                opts['seq_base'].step()
                opts['seq_base'].zero_grad()

            for t in ['epoch', 'iter']:
                self.tracker.add(t, 'ce1_loss', loss.item())

            if two_stage_flag:
                if self.cfg.debug:
                    self.logger.info("seq_cond.forward")
                _, _, loss = models['seq_cond'](cb_ind1, condition=cb_ind)
                if self.cfg.debug:
                    self.logger.info("seq_cond.get_loss")
                #loss = seqc_module.get_loss(logits, cb_ind1)
                if self.cfg.debug:
                    self.logger.info("seq_cond.backward")
                loss.backward()

                if (step + 1) % self.gradient_accum_steps == 0 or (
                        # or the last step of an epoch that has fewer steps than gradient_accum_steps
                        steps_in_epoch <= self.gradient_accum_steps and (step + 1) == steps_in_epoch):
                    if self.cfg.debug:
                        self.logger.info("seq_cond.step")
                    opts['seq_cond'].step()
                    opts['seq_cond'].zero_grad()

                for t in ['epoch', 'iter']:
                    self.tracker.add(t, 'ce2_loss', loss.item())

            if isinstance(self.display, int) and step % self.display == 0 and step > 0:
                self.logger.info("E{:02d} S: {:03d} {}".format(
                    self.epoch, step, self.tracker.loss_str('iter')))
                self.tracker.reset_interval('iter')

        if self.cfg.debug:
            self.logger.info("Finished training loop")

        self.update_writer('train')
        self.logger.info("E{:02d} (seq train) | {} ".format(
            self.epoch, self.tracker.loss_str('epoch')))

        if self.use_torch_dist:
            dist.barrier()

    def eval(self, eval_loader, models, hypothese_count=1, temp=1.0,
             gen_samples=True, **kwargs):
        """Evaluate the sequence model(s); optionally compute FID and save samples."""
        # Force a single hypothesis during evaluation
        hypothese_count = 1

        # Check if sequence model is two stages
        two_stage_flag = 'seq_cond' in models

        self.tracker.reset_all()
        for m in models.values():
            m.eval()

        vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

        p1 = self.seq_shape1
        p2 = self.seq_shape2
        if self.cfg.model.backbone == 'vq2':
            p1, p2 = p2, p1

        pred_dir = create_eval_dir(self.cfg.sample_dirs['seq_base'], 'seq_eval', self.epoch)

        if self.cfg.debug:
            self.logger.info("Start eval loop")

        fid_score = float('inf')
        if self.cfg.eval.fid.use:
            # Only one hypothesis for testing
            fid_score = self.run_fid(models['vae'], eval_loader, 1, pred_dir,
                                     stage='seq', seq_model=models['seq_base'],
                                     seq_cond_model=models.get('seq_cond'))
            self.writer.add_scalar('eval/fid_seq', fid_score, self.epoch)

        ind_1 = self.sample_seq_model(models['seq_base'], p1, temp=temp)
        ind_2 = None
        if two_stage_flag:
            ind_2 = self.sample_seq_model(models['seq_cond'], p2, temp=temp,
                                          condition=Variable(ind_1))

        if self.cfg.eval.return_outputs and self.transfer_cfg.use and gen_samples:
            outputs = {}
            with torch.no_grad():
                out = vae_module.reconstruct_from_indices(
                    ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)
            xd_hat = out['xd_hat'].unsqueeze(1)
            outputs['xc_hat'] = out['xc_hat']
            outputs['xd_hat'] = xd_hat

            # Zero the conditional indices and sweep their last position over
            # every codebook entry to produce transfer variations.
            cond_ind = torch.zeros_like(ind_2)
            num_embed2 = self.cfg.model.vq.n_emb2
            names = ind_2.view(-1).detach().cpu().tolist()
            for embed_idx in range(num_embed2):
                cond_ind[:, :, -1] = embed_idx
                with torch.no_grad():
                    out = vae_module.reconstruct_from_indices(
                        ind_1=ind_1, ind_2=cond_ind, hypothese_count=hypothese_count)
                xd_hat = out['xd_hat'].unsqueeze(1)
                outputs['xd_hat'] = torch.cat([outputs['xd_hat'], xd_hat], dim=1)

            self.create_transfer_tiles(outputs, pred_dir, names)
        elif self.cfg.eval.return_outputs and gen_samples:
            with torch.no_grad():
                outputs = vae_module.reconstruct_from_indices(
                    ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)
            self.create_center_subplot(outputs, pred_dir)

        self.update_writer('eval')

        if self.cfg.debug:
            self.logger.info("Finished eval loop")

        if self.use_torch_dist:
            dist.barrier()

        #self.logger.info("Epoch {:02d} (seq eval) | {} ".format(self.epoch, self.tracker.loss_str('epoch')))
        return fid_score

    def create_generator_seq(self, vae_model, seq_model, batch_size, num_total_samples,
                             hypothese_count, seq_cond_model=None, temp=1.0):
        """Yield batches of images decoded from indices sampled by the sequence model(s)."""
        num_iters = int(np.ceil(num_total_samples / batch_size))
        total_samples = 0
        ind_2 = None
        vae_module = vae_model.module if hasattr(vae_model, 'module') else vae_model

        p1 = self.seq_shape1
        p2 = self.seq_shape2
        if self.cfg.model.backbone == 'vq2':
            p1, p2 = p2, p1

        # Measure inference time
        self.infer_time = {'seq': [], 'rec': []}

        for _ in range(num_iters):
            if total_samples > num_total_samples:
                break
            ind_1 = self.sample_seq_model(seq_model, p1, temp=temp)
            if seq_cond_model is not None:
                ind_2 = self.sample_seq_model(seq_cond_model, p2, temp=temp,
                                              condition=Variable(ind_1))

            with torch.no_grad():
                outputs = vae_module.reconstruct_from_indices(
                    ind_1=ind_1, ind_2=ind_2, hypothese_count=hypothese_count)

            if 'xd_hat' not in outputs:
                out = outputs['xc_hat']
            else:
                out = torch.flatten(outputs['xd_hat'], start_dim=0, end_dim=1)

            # Count generated samples against the requested total.
            total_samples += out.size(0)
            yield out.float()

    def sample_seq_model(self, model, size, temp=1.0, condition=None):
        """Autoregressively sample a grid of codebook indices, one position at a time."""
        cache = {}
        ind = torch.zeros(size, dtype=torch.long)
        if self.cfg.gpu.use:
            ind = ind.cuda(self.device_id)

        with torch.no_grad():
            for i in range(size[-1]):
                for j in range(size[-2]):
                    out, cache, _ = model(Variable(ind), condition=condition, cache=cache)
                    probs = F.softmax(out[:, :, j, i] / temp, dim=1).data
                    val = torch.multinomial(probs, 1).view(-1)
                    ind[:, j, i] = val
        return ind
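

# ---------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The exact
# shape of `cfg`, the `models`/`opts` dicts, and the data loaders is
# assumed to be set up by the surrounding training script; names such as
# `cfg.train.epochs` below are placeholders, not confirmed config keys.
#
#   runner = SEQRunner(cfg)
#   for epoch in range(cfg.train.epochs):
#       runner.epoch = epoch
#       runner.train(train_loader, models, opts)
#       fid_score = runner.eval(eval_loader, models)
# ---------------------------------------------------------------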