# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import os

import numpy as np
import torch
import torch.distributed as dist

from fid import calculate_frechet_distance
from utils import (
    pickle_save,
    create_recon_plots
)

from .base import BaseRunner
from models.constant import CHART_TO_HEAD_IDX


class ContinuousRunner(BaseRunner):
    """Runner for the continuous (reconstruction) stage, with an optional discriminator."""

    def __init__(self, stage, cfg):
        super(ContinuousRunner, self).__init__(cfg)
        self.stage = stage
        self.loss_weights = cfg.train.loss_weights

    def training_step(self, models, inputs, opt_idx):
        """Run one forward/backward pass for a batch.

        opt_idx selects which model the outer loop is currently optimizing
        (0: 'continuous', 1: 'disc'); discriminator losses are mixed in
        whenever a 'disc' model is present.
        """
        opt_model = 'continuous' if opt_idx == 0 else 'disc'
        with self.autocast_smart_context_manager():
            x_hat, loss_dict, metric_log = models['continuous'](inputs)
            if 'disc' in models:
                disc_loss, disc_log = models['disc'](
                    loss_dict=loss_dict,
                    inputs=inputs['chart_data'],
                    reconstructions=x_hat,
                    optimizer_idx=opt_idx,
                    global_step=self.global_step
                )
                loss_dict = {**loss_dict, **disc_loss}
                metric_log = {**metric_log, **disc_log}

        # Weight each loss term by its entry in cfg.train.loss_weights (default 1.0).
        loss_log = {}
        total_loss = 0.0
        for name, loss in loss_dict.items():
            l_name = name.split('/')[-1].replace('_loss', '')
            weight = 1.0
            if hasattr(self.loss_weights, l_name):
                weight = getattr(self.loss_weights, l_name)
            total_loss += loss * weight
            loss_log[name] = loss.detach().cpu()

        if self.use_torch_dist:
            total_loss = total_loss.mean()
        if self.gradient_accum_steps > 1:
            total_loss = total_loss / self.gradient_accum_steps

        if self.do_grad_scaling:
            self.scaler.scale(total_loss).backward()
        else:
            total_loss.backward()

        return loss_log, metric_log

    def train(self, train_loader, models, tokenizers, opts, schs):
        """Run one training epoch with gradient accumulation and optional AMP grad scaling."""
        self.tracker.reset_all()
        tr_loss = torch.tensor(0.0).to(self.device_id)
        for m in models.values():
            m.train()
            m.zero_grad()
        for o in opts.values():
            if o is not None:
                o.zero_grad()

        iterator = train_loader.__iter__()
        steps_in_epoch = len(iterator)
        for step, (_, inputs) in enumerate(iterator):
            opt_count = 2 if 'disc' in models else 1
            for opt_idx in range(opt_count):
                opt_model = 'continuous' if opt_idx == 0 else 'disc'
                loss_log, metric_log = self.training_step(models, inputs, opt_idx=opt_idx)

                tr_loss_step = sum(list(loss_log.values()))
                tr_loss += tr_loss_step
                self.tracker.add_logs(split='train', log=loss_log, total_loss=tr_loss_step)
                self.tracker.add_metrics(split='train', metrics=metric_log, metric_name='continuous')

                if (step + 1) % self.gradient_accum_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accum_steps
                    steps_in_epoch <= self.gradient_accum_steps
                    and (step + 1) == steps_in_epoch
                ):
                    if self.do_grad_scaling:
                        self.scaler.unscale_(opts[opt_model])
                    if self.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(models[opt_model].parameters(), self.max_grad_norm)

                    if self.do_grad_scaling:
                        self.scaler.step(opts[opt_model])
                        self.scaler.update()
                    else:
                        opts[opt_model].step()

                    models[opt_model].zero_grad()
                    opts[opt_model].zero_grad()
                    self.global_step += 1
                    tr_loss = 0

            if isinstance(self.display, int) and step % self.display == 0 and step > 0:
                self.logger.info("E{:02d} GS: {:03d} {} {}".format(
                    self.epoch, self.global_step,
                    self.tracker.loss_str('iter'), self.tracker.metric_str('iter')))
                self.tracker.reset_interval('iter')

        self.logger.info("E{:02d} (train) {} {}".format(
            self.epoch, self.tracker.loss_str('epoch'), self.tracker.metric_str('epoch')))
        self.update_writer('train')
        if self.use_torch_dist:
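            # dist.barrier() blocks until every rank has finished the epoch, so all
            # processes move on to evaluation together (requires an initialized process group).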
            dist.barrier()

    def eval(self, val_loader, models, metric_key_prefix='eval', epoch=0, **kwargs):
        """Evaluation loop: log losses/metrics, optionally compute FID and save
        reconstruction plots, and return a scalar score for checkpoint selection."""
        self.tracker.reset_all()
        iterator = val_loader.__iter__()
        steps_in_epoch = len(iterator)

        if self.use_torch_dist:
            dist.barrier()
        for m in models.values():
            if m is not None:
                m.eval()

        fid_log = {}
        if self.use_fid:
            fid_log['train_fid'], fid_log['test_fid'] = self.compute_fid(val_loader, models)
            self.tracker.add_metrics(split=metric_key_prefix, metrics=fid_log, metric_name='fid')

        if (epoch + 1) % self.cfg.eval.sample_epoch == 0:
            epoch_dir = os.path.join(self.cfg.sample_dirs[self.stage], "{}".format(epoch))
            if not os.path.exists(epoch_dir):
                os.makedirs(epoch_dir, exist_ok=True)

        tr_loss = 0.0
        for step, (indices, inputs) in enumerate(iterator):
            with torch.no_grad():
                with self.autocast_smart_context_manager():
                    x_hat, loss_dict, metric_log = models['continuous'](inputs, is_train=False, split=metric_key_prefix)
                    if 'disc' in models:
                        disc_loss, disc_log = models['disc'](
                            loss_dict=loss_dict,
                            inputs=inputs,
                            reconstructions=x_hat,
                            optimizer_idx=0,
                            global_step=self.global_step
                        )
                        loss_dict = {**loss_dict, **disc_loss}
                        metric_log = {**metric_log, **disc_log}

            loss_log = {}
            tr_loss_step = 0.0
            for name, loss in loss_dict.items():
                tr_loss_step += loss.detach().cpu()
                loss_log[name] = loss.detach().cpu()

            tr_loss += tr_loss_step
            self.tracker.add_logs(split=metric_key_prefix, log=loss_log, total_loss=tr_loss_step)
            self.tracker.add_metrics(split=metric_key_prefix, metrics=metric_log, metric_name='continuous')

            if (step + 1) % self.gradient_accum_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accum_steps
                steps_in_epoch <= self.gradient_accum_steps
                and (step + 1) == steps_in_epoch
            ):
                tr_loss = 0

            if (step + 1) % self.cfg.eval.sample_interval == 0 and (epoch + 1) % self.cfg.eval.sample_epoch == 0:
                # self.to_vega_json(x_hat, prefix=metric_key_prefix, step=step, epoch=epoch)
                text_data = [val_loader.dataset.get_text_with_idx(ind) for ind in indices]
                create_recon_plots(inputs, x_hat, text_data, step, epoch_dir)

            if isinstance(self.display, int) and step % self.display == 0 and step > 0:
                self.logger.info("E{:02d} GS: {:04d} {} {}".format(
                    self.epoch, self.global_step,
                    self.tracker.loss_str('iter'), self.tracker.metric_str('iter')))
                self.tracker.reset_interval('iter')

        self.logger.info("E{:02d} (eval) {} {}".format(
            self.epoch, self.tracker.loss_str('epoch'), self.tracker.metric_str('epoch')))
        self.update_writer('eval')
        if self.use_torch_dist:
            dist.barrier()

        # Calculate scores for saving
        if self.use_fid:
            score = fid_log['test_fid']
        else:
            score_names = ['metric/eval/row/total', 'metric/eval/col/total',
                           'metric/eval/scale/total', 'metric/eval/cont/total']
            score = sum(self.tracker.get_loss('epoch', s) for s in score_names)
        return {'score': score}

    def generate(self, val_loader, models, tokenizers, metric_key_prefix='generate', epoch=0):
        """Sample codebook indices for each batch and reconstruct charts from them."""
        self.tracker.reset_all()
        iterator = val_loader.__iter__()
        steps_in_epoch = len(iterator)

        if self.use_torch_dist:
            dist.barrier()
        for m in models.values():
            if m is not None:
                m.eval()

        batch_size = self.bsz
        for step, (idx, inputs) in enumerate(iterator):
            if isinstance(self.display, int) and step % self.display == 0 and step > 0:
                self.logger.info("Eval | E{:02d} Step {:04d}/{:04d} ".format(self.epoch, step, self.cfg.eval.max_steps))
            if isinstance(self.cfg.eval.gen_steps, int) and step > self.cfg.eval.gen_steps:
                break

            # Sample the indices
            with torch.no_grad():
                with self.autocast_smart_context_manager():
                    cb_ind1 = models['continuous'].sample_codebook(inputs)
                    cb_ind2 = None
                    if len(cb_ind1) == 2:
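                        # sample_codebook returned two index tensors; unpack both so the
                        # second level can be passed to reconstruct_from_indices below.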
                        cb_ind1, cb_ind2 = cb_ind1
                    else:
                        cb_ind1 = cb_ind1[0]

            ct_idx = [CHART_TO_HEAD_IDX[ct] for ct in inputs['chart_data']['chart_type']]
            ct_idx = torch.tensor(ct_idx, dtype=torch.long, device=self.device).view(-1, 1)

            with torch.no_grad():
                with self.autocast_smart_context_manager():
                    samples = models['continuous'].reconstruct_from_indices(
                        ct_idx=ct_idx,
                        cb_ind1=cb_ind1,
                        cb_ind2=cb_ind2,
                        hypo_count=self.cfg.eval.hypo_count,
                        hypo_bsz=self.cfg.eval.hypo_bsz
                    )

        if self.use_torch_dist:
            dist.barrier()
        self.logger.info("Epoch {:02d} ({}) | {} ".format(epoch, metric_key_prefix, self.tracker.loss_str('epoch')))

    def save(self, loader, models, prefix):
        """Sample codebook indices for every example in the loader and pickle them
        alongside the corresponding dataset entries."""
        data_path = os.path.join(self.cfg.data_path, 'processed')
        save_dir = os.path.join(data_path, self.cfg.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, f'{prefix}.pkl')
        self.logger.info(f"Saving data @ {file_path}")

        iterator = loader.__iter__()
        if self.use_torch_dist:
            dist.barrier()
        for m in models.values():
            if m is not None:
                m.eval()

        container = []
        for _, (idx, inputs) in enumerate(iterator):
            # Sample the indices
            with torch.no_grad():
                with self.autocast_smart_context_manager():
                    cb_ind1 = models['continuous'].sample_codebook(inputs)
                    cb_ind2 = None
                    if len(cb_ind1) == 2:
                        cb_ind1, cb_ind2 = cb_ind1
                        cb_ind2 = cb_ind2.detach().cpu().tolist()
                    else:
                        cb_ind1 = cb_ind1[0]
                    cb_ind1 = cb_ind1.detach().cpu().tolist()

            # Get data from dataset
            data = loader.dataset.get_data_with_idx(idx)
            # When only one codebook level exists, pad with None so the zip below still works.
            if cb_ind2 is None:
                cb_ind2 = [None] * len(cb_ind1)
            for d, ind1, ind2 in zip(data, cb_ind1, cb_ind2):
                d['codebook'] = {}
                d['codebook'][0] = ind1
                d['codebook'][1] = ind2
                container.append(d)

        pickle_save(container, file_path)

    def compute_fid(self, loader, models):
        """Compute FID between reconstructions and the pre-computed train/test statistics."""
        assert 'fid' in models and 'continuous' in models
        iterator = loader.__iter__()
        if self.use_torch_dist:
            dist.barrier()
        for m in models.values():
            if m is not None:
                m.eval()

        act_container = []
        for (_, inputs) in iterator:
            with torch.no_grad():
                with self.autocast_smart_context_manager():
                    x_hat, _, _ = models['continuous'](inputs, is_train=False, temp=1.0)
                    activations, _, _ = models['fid'](x_hat)
            act_container.append(activations)

        act = np.concatenate(act_container, axis=0)
        mu = np.mean(act, axis=0)
        sigma = np.cov(act, rowvar=False)

        # Load existing FID statistics
        m1, s1, _ = self.fid_stats['train']
        m2, s2, _ = self.fid_stats['test']
        train_fid = calculate_frechet_distance(mu, sigma, m1, s1)
        test_fid = calculate_frechet_distance(mu, sigma, m2, s2)
        return train_fid, test_fid
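
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how a
# driver script might wire up this runner. `build_dataloaders` and
# `build_models` are hypothetical helpers standing in for whatever the
# surrounding project provides; only the ContinuousRunner calls follow the
# signatures defined above, and cfg.train.epochs is an assumed config field.
#
#     runner = ContinuousRunner(stage='continuous', cfg=cfg)
#     train_loader, val_loader = build_dataloaders(cfg)        # assumed helper
#     models, tokenizers, opts, schs = build_models(cfg)       # assumed helper
#     for epoch in range(cfg.train.epochs):
#         runner.train(train_loader, models, tokenizers, opts, schs)
#         result = runner.eval(val_loader, models, epoch=epoch)
#         # result['score'] can drive checkpoint selection.
# -----------------------------------------------------------------------------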