# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.distributed as dist

from utils import create_eval_dir

from .base import BaseRunner


class VAERunner(BaseRunner):

    def __init__(self, cfg):
        super(VAERunner, self).__init__(cfg)

    def train(self, train_loader, models, opts, hypothese_count=32):
        self.tracker.reset_all()
        for m in models.values():
            m.train()
        for o in opts.values():
            o.zero_grad()

        # Unwrap DDP if necessary so module-level helpers are reachable.
        vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

        iterator = train_loader.__iter__()
        steps_in_epoch = len(iterator)
        for step, (x, tgt) in enumerate(iterator):
            x = x.cuda(self.device_id)
            tgt = tgt.cuda(self.device_id)
            # Uniform dequantization: rescale and add U(0, 1/256) noise.
            x = x / 256. * 255. + torch.rand_like(x) / 256.

            # Output loss and log
            opt_count = 2 if self.use_gan else 1
            total_loss = 0
            for opt_idx in range(opt_count):
                # Optimize the generator (opt_idx == 0) or the discriminator (opt_idx == 1)
                opt_model = 'vae' if opt_idx == 0 else 'disc'
                x_hat, disc_inputs, loss, outputs, log = models['vae'](
                    x=x,
                    hypothese_count=hypothese_count,
                    optimizer_idx=opt_idx,
                    split='train')

                if 'disc' in models:
                    gan_loss, disc_log = models['disc'](
                        nll_loss=disc_inputs['nll_loss'],
                        codebook_loss=disc_inputs['cb_loss'],
                        inputs=x,
                        reconstructions=x_hat,
                        optimizer_idx=opt_idx,
                        global_step=self.global_step,
                        last_layer=vae_module.get_last_layer(),
                        split='train')
                    loss += gan_loss
                    log = {**log, **disc_log}

                # Store logs
                total_loss += loss.clone().detach().item()
                self.add_logs(split='train', log=log)

                # Compute gradients
                if self.gradient_accum_steps > 1:
                    loss = loss / self.gradient_accum_steps
                if self.do_grad_scaling:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()

                # Gradient accumulation: step either every `gradient_accum_steps`
                # batches or on the last batch of a short epoch.
                if (step + 1) % self.gradient_accum_steps == 0 or (
                        steps_in_epoch <= self.gradient_accum_steps
                        and (step + 1) == steps_in_epoch):
                    if self.do_grad_scaling:
                        self.scaler.step(opts[opt_model])
                        self.scaler.update()
                    else:
                        opts[opt_model].step()
                    opts[opt_model].zero_grad()

            self.add_logs(split='train', total_loss=total_loss)
            self.global_step += 1

            if isinstance(self.display, int) and step % self.display == 0 and step > 0:
                self.logger.info("E{:02d} S: {:03d} {}".format(
                    self.epoch, step, self.tracker.loss_str('iter')))
                self.tracker.reset_interval('iter')

        self.logger.info("E{:02d} (train) {}".format(
            self.epoch, self.tracker.loss_str('epoch')))
        self.update_writer('train')

    def eval(self, eval_loader, models, hypothese_count=32, epoch=0, gen_samples=True):
        self.tracker.reset_all()
        if self.use_torch_dist:
            dist.barrier()
        for m in models.values():
            m.eval()

        # Sample at least as many hypotheses as the plotting grid has columns.
        hypothese_count = max([hypothese_count, self.subplots_cfg.columns])
        vae_module = models['vae'].module if hasattr(models['vae'], 'module') else models['vae']

        pred_dir = create_eval_dir(self.cfg.sample_dirs['vae'], 'vae_eval', self.epoch)

        fid_score = float('inf')
        if self.cfg.eval.fid.use and epoch > self.cfg.train.epochs.warmup:
            fid_score = self.run_fid(vae_module, eval_loader, hypothese_count,
                                     pred_dir, stage='vae')
            self.writer.add_scalar("eval/fid_vae", fid_score, self.epoch)
        is_best = fid_score < self.best_score['vae']

        for step, (x, tgt) in enumerate(eval_loader.__iter__()):
            x = x.cuda(self.device_id)
            tgt = tgt.cuda(self.device_id)

            with torch.no_grad():
                x_hat, disc_inputs, loss, outputs, log = models['vae'](
                    x=x,
                    hypothese_count=hypothese_count,
                    optimizer_idx=0,
                    split='eval',
                    return_outputs=self.cfg.eval.return_outputs)

                if 'disc' in models:
                    gan_loss, disc_log = models['disc'](
                        nll_loss=disc_inputs['nll_loss'],
                        codebook_loss=disc_inputs['cb_loss'],
                        inputs=x,
                        reconstructions=x_hat,
                        optimizer_idx=0,
                        global_step=self.global_step,
                        last_layer=vae_module.get_last_layer(),
                        split='eval')
                    loss += gan_loss
                    log = {**log, **disc_log}

            total_loss = loss.clone().detach().item()
            self.add_logs(split='eval', log=log, total_loss=total_loss)

            # Create samples
            if self.cfg.eval.return_outputs and (gen_samples or is_best):
                self.save_reconst(x, x_hat, pred_dir, step)

        self.update_writer('eval')
        if self.use_torch_dist:
            dist.barrier()
        self.logger.info("Epoch {:02d} (eval) | {}".format(
            self.epoch, self.tracker.loss_str('epoch')))

        return fid_score

    def create_generator_vae(self, vae_model, dataset, batch_size,
                             num_total_samples, hypothese_count, topk=10):
        # Yield reconstructions until roughly `num_total_samples` images have been produced.
        total_samples = 0
        for _, (x, tgt) in enumerate(dataset.__iter__()):
            if total_samples > num_total_samples:
                break
            x = x.cuda(self.device_id)
            with torch.no_grad():
                x_hat, _, _, _, _ = vae_model(
                    x=x,
                    hypothese_count=hypothese_count,
                    split='eval')
            total_samples += x_hat.size(0)
            yield x_hat.float()
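

# ---------------------------------------------------------------
# Usage sketch (illustrative only). `build_vae`, `build_discriminator`,
# the Adam optimizer choice, and the config fields referenced below are
# assumptions, not defined in this module; the runner itself only expects
# `models` and `opts` dicts keyed by 'vae' (and optionally 'disc').
#
#   runner = VAERunner(cfg)
#   models = {'vae': build_vae(cfg).cuda(runner.device_id)}
#   if cfg.train.use_gan:
#       models['disc'] = build_discriminator(cfg).cuda(runner.device_id)
#   opts = {name: torch.optim.Adam(m.parameters(), lr=cfg.train.lr)
#           for name, m in models.items()}
#   for epoch in range(cfg.train.epochs.total):
#       runner.train(train_loader, models, opts)
#       runner.eval(eval_loader, models, epoch=epoch)
# ---------------------------------------------------------------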