# honeyplotnet/runner/seq.py
# ---------------------------------------------------------------
# Copyright (c) ____________________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

from filelock import FileLock
from typing import Dict, NamedTuple, Optional, Tuple, Union

import os
import nltk  
import time
import numpy as np
import torch
from torch.distributions.categorical import Categorical

from transformers.trainer_pt_utils import (
  find_batch_size, nested_concat, 
  nested_numpify, nested_truncate,
)

from transformers.trainer_utils import ( 
  EvalPrediction, denumpify_detensorize
)

from models.constant import CHART_TO_HEAD_IDX, HEAD_IDX_TO_CHART, UNIQ_CHART_HEADS
from dataset.base import shift_tokens_right_pad
from runner.text import ChartTextRunner
from fid import calculate_frechet_distance
from utils.constant import TASK2IDX

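# Make sure the NLTK 'punkt' tokenizer is available; the file lock guards against
# concurrent downloads when several processes start at the same time.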
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

class EvalLoopOutputwInputs(NamedTuple):
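    '''
    Like transformers' EvalLoopOutput, but with separate label ids for the text and
    data (codebook) heads, plus the raw encoder inputs used for each prediction.
    '''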
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids_text: Optional[np.ndarray]
    label_ids_code: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
    inputs: Optional[np.ndarray]

class SeqRunner(ChartTextRunner):
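  '''
  Runner for the sequence stage: one encoder-decoder ('seq') model with two output heads,
  one generating chart text and one generating discrete data tokens whose targets are
  sampled on the fly from the continuous (codebook) model.

  Rough usage sketch (model/loader construction lives elsewhere; the stage key is
  presumably 'seq'):

      runner = SeqRunner(stage, cfg)
      runner.train(train_loader, models, tokenizers, opts, schs)
      scores = runner.eval(val_loader, models, tokenizers)
  '''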
  def __init__(self, stage, cfg):
    super(SeqRunner, self).__init__(stage, cfg)
    self.loss_fn_ = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

  def set_model_output(self, models, mode):
    '''
    Swaps between text and data generation mode.
    '''
    if hasattr(models['seq'], 'module'):
      models['seq'].module.set_output(mode)
    else:
      models['seq'].set_output(mode) 

  def compute_loss(self, models, inputs, text_lbls=None, data_lbls=None, return_outputs=False, text_tasks=None):
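    '''
    Runs the 'seq' model with both heads active and computes a cross-entropy loss per
    head, averaged over the non-ignored positions of each sample. When `text_tasks` is
    given, text losses are kept only for text-task samples and data losses only for
    data-task samples. Losses are weighted by cfg.train.loss_weights.text / .code.
    '''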

    text_inputs = self._prepare_inputs(inputs['text'])
    
    self.set_model_output(models, 'both')

    text_logits, data_logits = models['seq'](**text_inputs)

    text_mask, data_mask = None, None
    if text_tasks is not None:
      text_mask = torch.tensor(text_tasks, dtype=torch.float32, device=self.device)
      data_mask = (text_mask == 0).to(torch.float32)

    loss = {}
    if text_lbls is not None:
      text_lbls = text_lbls.to(self.device)
      batch_sz, label_len = text_lbls.shape
      text_logits = text_logits[:,:label_len,:]
      flat_text_logits = torch.flatten(text_logits, start_dim=0, end_dim=1)
      text_loss = self.loss_fn_(flat_text_logits, text_lbls.view(-1))
      text_loss = torch.nan_to_num(text_loss, nan=1e-9)
      text_loss = torch.reshape(text_loss, (batch_sz, label_len))
      non_zeros_count = (text_loss > 1e-5).to(torch.float32)
      text_loss = text_loss.sum(-1) / non_zeros_count.sum(-1) if non_zeros_count.sum() > 0 else text_loss.sum()
      if text_mask is not None:
        text_loss = text_loss * text_mask
        text_loss = text_loss.sum() / text_mask.sum()  if text_mask.sum() > 0 else text_loss.sum()
      else:
        text_loss = text_loss.mean()
      
      loss['text'] = text_loss.mean() * self.cfg.train.loss_weights.text

    if data_lbls is not None:
      data_lbls = data_lbls.to(self.device)
      batch_sz, label_len = data_lbls.shape
      data_logits = data_logits[:,:label_len,:]
      flat_data_logits = torch.flatten(data_logits, start_dim=0, end_dim=1)
      data_loss = self.loss_fn_(flat_data_logits, data_lbls.view(-1))
      data_loss = torch.nan_to_num(data_loss, nan=1e-9)
      data_loss = torch.reshape(data_loss, (batch_sz, label_len))
      non_zeros_count = (data_loss > 1e-5).to(torch.float32)
      data_loss = data_loss.sum(-1) / non_zeros_count.sum(-1) if non_zeros_count.sum() > 0 else data_loss.sum()
      if data_mask is not None:
        data_loss = data_loss * data_mask
        data_loss = data_loss.sum() / data_mask.sum() if data_mask.sum() > 0 else data_loss.sum()
      else:
        data_loss = data_loss.mean()
      
      loss['data'] = data_loss.mean() * self.cfg.train.loss_weights.code
    
    outputs = {'text': text_logits, 'data': data_logits}
    return (loss, outputs, ) if return_outputs else loss
  
  def sample_data_labels(self, models, inputs, tokenizers):
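    '''
    Samples VQ codebook indices for the chart data from the continuous model (eval mode,
    no gradients) and builds the data target sequence
    [chart_type, cb1..., (cb2...), eos] using the vocabulary offsets below.
    Returns the label ids and the right-shifted decoder input ids.
    '''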
    
    models['continuous'].eval()

    with self.autocast_smart_context_manager():
      with torch.no_grad():
        if hasattr(models['continuous'], 'module'):
          cb1 = models['continuous'].module.sample_codebook(inputs['data'])
        else:
          cb1 = models['continuous'].sample_codebook(inputs['data'])

      if len(cb1) == 2:
        cb1, cb2 = cb1
      else:
        cb1, cb2 = cb1[0], None

      ct_idx = [CHART_TO_HEAD_IDX[ct] for ct in inputs['data']['chart_type']]
      ct_idx = torch.tensor(ct_idx, dtype=torch.long, device=self.device).view(-1,1)

      #################################
      # Offsets applied so the data tokens do not collide with the reserved low
      # ids (< 2) or with the chart-type ids:
      # ct_idx: + 2
      # cb1   : + 2 + len(UNIQ_CHART_HEADS)
      # cb2   : + 2 + len(UNIQ_CHART_HEADS)   (the extra + n_emb1 offset is left disabled,
      #                                        so cb1 and cb2 share the same id range)
      # e.g. with 3 unique chart heads, chart types use ids 2-4 and codebook indices start at 5.
      ct_idx = ct_idx + 2
      cb1    = cb1    + 2 + len(UNIQ_CHART_HEADS)
      if cb2 is not None:
        cb2  = cb2    + 2 + len(UNIQ_CHART_HEADS)  # + self.cfg.model.continuous_data.vq.n_emb1

      code_labels = torch.cat([ct_idx, cb1], dim=-1)
      if cb2 is not None:
        code_labels = torch.cat([code_labels, cb2], dim=-1)

      eos_token_id = tokenizers['seq'].eos_token_id
      pad_token_id = tokenizers['seq'].pad_token_id

      # Append the eos token to every label sequence
      padding = torch.ones([code_labels.shape[0], 1], dtype=torch.long, device=self.device)
      code_labels = torch.cat([code_labels, padding * eos_token_id], dim=-1)
      # Build the decoder inputs by shifting the labels one position to the right
      # (id 0 is reserved for the start token)
      decoder_input_ids = shift_tokens_right_pad(code_labels, pad_token_id=pad_token_id)

    return code_labels, decoder_input_ids

  def training_step(self, models, inputs, tokenizers, text_tasks=None):
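    '''
    One optimization step: samples the data labels on the fly, computes the weighted
    text/data losses, scales for gradient accumulation and calls backward().
    Returns the detached per-head losses for logging.
    '''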
    
    models['seq'].train()

    data_lbls, decoder2_input_ids = self.sample_data_labels(models, inputs, tokenizers)
    
    text_lbls = inputs['text'].pop("labels")
    inputs['text']['decoder2_input_ids'] = decoder2_input_ids

    loss_dict = self.compute_loss(models, inputs, 
        text_lbls=text_lbls, data_lbls=data_lbls, text_tasks=text_tasks)
  
    loss_log = {}
    total_loss = 0.0
    for name, loss in loss_dict.items():
      total_loss += loss
      loss_log[name] = loss.detach().cpu()
  
    if self.gradient_accum_steps > 1:
      total_loss = total_loss / self.gradient_accum_steps
    
    if self.do_grad_scaling:
      self.scaler.scale(total_loss).backward()
    else:
      total_loss.backward()

    return loss_log

  def train(self, train_loader, models, tokenizers, opts, schs):   
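    '''
    Runs one training epoch with gradient accumulation and optional AMP gradient
    scaling; the learning-rate scheduler is stepped once per epoch.
    '''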

    self.tracker.reset_all()

    tr_loss = torch.tensor(0.0).to(self.device_id)

    for stage, m in models.items():
      if stage == self.stage:
        m.train()
        m.zero_grad()
      else:
        m.eval()
    
    for o in opts.values():
      if o is not None:
        o.zero_grad()    

    iterator = train_loader.__iter__()
    steps_in_epoch = len(iterator)
    text_tasks = None
    for step, (_, inputs) in enumerate(iterator):

      # if self.debug and step > 1:
      #    break
      
      if self.cfg.model.seperate_data_task:
        text_tasks = [1 if t != 'data' else 0 for t in  inputs['task']]

      loss_log = self.training_step(models, inputs, tokenizers, text_tasks=text_tasks)

      tr_loss_step = sum(list(loss_log.values()))
      tr_loss += tr_loss_step
      
      self.tracker.add_logs(split='train', log=loss_log, total_loss=tr_loss_step)


      if (step + 1) % self.gradient_accum_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accum_steps
                    steps_in_epoch <= self.gradient_accum_steps
                    and (step + 1) == steps_in_epoch
                ):
        
        if self.do_grad_scaling:
          self.scaler.unscale_(opts[self.stage])
        if self.max_grad_norm > 0:
          torch.nn.utils.clip_grad_norm_(models[self.stage].parameters(), self.max_grad_norm)

        if self.do_grad_scaling:
          self.scaler.step(opts[self.stage])
          self.scaler.update()
        else:
          opts[self.stage].step()

        models[self.stage].zero_grad()
        opts[self.stage].zero_grad()

        self.global_step += 1
        tr_loss = 0
        
      if isinstance(self.display, int) and step % self.display == 0 and step > 0:
        self.logger.info("E{:02d} GS: {:03d} {}".format(
          self.epoch, self.global_step, self.tracker.loss_str('iter')))
        
        self.tracker.reset_interval('iter')

    if schs[self.stage] is not None:
       schs[self.stage].step()
       
    self.logger.info("E{:02d} (train) {}".format(self.epoch, self.tracker.loss_str('epoch')))
    self.update_writer('train')

  def prediction_step(self, models, tokenizers, inputs, prediction_loss_only=False, text_tasks=None, generate_tokens=False):
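    '''
    Evaluation step. When `generate_tokens` is True, text and data tokens are produced
    with generate(); when labels are present, the same losses as in training are
    computed. Returns (loss, text_tokens, data_tokens, inputs, text_labels, data_labels).
    '''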
    
    has_labels = 'labels' in inputs['text']
    
    for m in models.values():
       m.eval()

    text_inputs = self._prepare_inputs(inputs['text'])
    text_tokens = None
    data_tokens = None
    if generate_tokens:
      gen_kwargs = {
              "input_ids": text_inputs["input_ids"],
              "max_length": self._max_length,
              "num_beams": self._num_beams,
              "synced_gpus": False,
              "repetition_penalty": self._repetition_penalty,
              "temperature": self._gen_temperature,
          }

      if "attention_mask" in text_inputs:
        gen_kwargs["attention_mask"] = text_inputs.get("attention_mask", None)
      if "global_attention_mask" in inputs:
        gen_kwargs["global_attention_mask"] = text_inputs.get("global_attention_mask", None)

      tasks = self.cfg.data.dataset.tasks

      #Generate text
      self.set_model_output(models, 'text')
      
      if any(t in tasks for t in ['categorical','series_name','axis','caption']):
        with torch.no_grad():
          if models['seq'].__class__.__name__ == 'DistributedDataParallel':
            text_tokens = models['seq'].module.generate(**gen_kwargs)      
          else:
            text_tokens = models['seq'].generate(**gen_kwargs)

      #Generate data
      for l in ['max_length', 'num_beams', 'synced_gpus', 'repetition_penalty', 'temperature']:
        gen_kwargs.pop(l)

      self.set_model_output(models, 'data')
      
      if 'data' in tasks:
        with torch.no_grad():
          if models['seq'].__class__.__name__ == 'DistributedDataParallel':
            data_tokens = models['seq'].module.generate(**gen_kwargs)      
          else:
            data_tokens = models['seq'].generate(**gen_kwargs)

      # in case the batch is shorter than max length, the output should be padded
      if text_tokens is not None and text_tokens.shape[-1] < self._max_length:
          text_tokens = self._pad_tensors_to_max_len(
            text_tokens, self._max_length, models['seq'], tokenizers['seq'])

    tr_loss = 0.0
    text_lbls = None
    data_lbls = None
    if has_labels:
      with torch.no_grad():
        text_lbls = inputs['text'].pop("labels")

        data_lbls, decoder2_input_ids = self.sample_data_labels(models, inputs, tokenizers=tokenizers)
        inputs['text']['decoder2_input_ids'] = decoder2_input_ids

        losses, outputs = self.compute_loss(
           models, inputs, text_lbls=text_lbls, data_lbls=data_lbls, 
           return_outputs=True, text_tasks=text_tasks)
        
        if not generate_tokens:
          text_logits, data_logits = outputs['text'], outputs['data']
          text_tokens = Categorical(logits=text_logits).sample()
          data_tokens = Categorical(logits=data_logits).sample()

        loss_log = {}
        for k,v in losses.items():
          tr_loss += v.detach()
          loss_log[k] = v.cpu().detach().item()

        self.tracker.add_logs(split='eval', log=loss_log, total_loss=tr_loss)
        
    if prediction_loss_only:
        return (tr_loss, None, None, inputs, None, None)

    if text_lbls is not None and text_lbls.shape[-1] < self._max_length:
        text_lbls = self._pad_tensors_to_max_len(text_lbls, 
          self._max_length, models['seq'], tokenizers['seq'])

    return (tr_loss, text_tokens, data_tokens, inputs, text_lbls, data_lbls)
  
  def eval_loop(self, cur_stage, loader, models, tokenizers, metric_key_prefix='eval', prediction_loss_only=False, step_count=None):
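    '''
    Evaluation loop following the structure of transformers.Trainer.evaluation_loop:
    accumulates losses, predictions, labels, inputs and (optionally) FID activations
    across steps and processes, then computes text metrics and FID scores.
    '''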
    
    self.tracker.reset_all()
    iterator = loader.__iter__()
    steps_in_epoch = len(iterator)

    models[cur_stage].eval()

    # Initialize containers
    # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
    losses_host = None
    preds_host_text = None
    preds_host_data = None
    labels_host_text = None
    labels_host_data = None
    fidact_host_data = None
    inputs_host = None
    task_host = None

    # losses/preds/labels on CPU (final containers)
    all_losses = None
    all_text_preds = None
    all_data_preds = None
    all_labels_text = None
    all_labels_data = None
    all_fidact_data = None
    all_inputs = None
    all_tasks = None

    text_tasks = None
    fid_act = None
    task_idx = None

    observed_num_examples = 0
    batch_size = self.bsz

    start_time = time.time()

    max_step_count = self.cfg.eval.max_steps if step_count is None and isinstance(self.cfg.eval.max_steps, int) else step_count
    max_step_count = min(steps_in_epoch, max_step_count) if max_step_count is not None else steps_in_epoch

    for step, (_, inputs) in enumerate(iterator): 

      if (self.debug and step >= 6) or (max_step_count is not None and step >= max_step_count):
         break 

      if isinstance(self.cfg.eval.display_interval, int) and step % self.cfg.eval.display_interval == 0 and step > 0:
        self.logger.info("Eval | {}/{} Time elapsed : {:.2f}s".format(step, max_step_count, time.time() - start_time))
      
      if self.cfg.model.seperate_data_task:
        text_tasks = [1 if t != 'data' else 0 for t in  inputs['task']]
        task_idx = torch.tensor([TASK2IDX[t] for t in inputs['task']], dtype=torch.long, device=self.device)

      # Update the observed num examples
      observed_batch_size = find_batch_size(inputs)
      if observed_batch_size is not None:
          observed_num_examples += observed_batch_size
          # For batch samplers, batch_size is not known by the dataloader in advance.
          if batch_size is None:
              batch_size = observed_batch_size

      start_time = time.time()

      loss, text_tokens, data_tokens, inputs, text_lbls, data_lbls = self.prediction_step(
         models, tokenizers=tokenizers,
        inputs=inputs, prediction_loss_only=prediction_loss_only,
        text_tasks=text_tasks, 
        generate_tokens=True)
      if self.cfg.eval.fid:
        fid_act = self.compute_fid_acts(models, data_tokens)

      inputs_decode = inputs['text']["input_ids"] 

      # Update containers on host
      if loss is not None:
        losses = self._nested_gather(loss.repeat(batch_size))
        losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)

      if text_lbls is not None:
        text_lbls = text_lbls.to(self.device)
        text_lbls = self._pad_across_processes(text_lbls)
        text_lbls = self._nested_gather(text_lbls)
        labels_host_text = text_lbls if labels_host_text is None else nested_concat(labels_host_text, text_lbls, padding_index=-100)

      if data_lbls is not None:
        data_lbls = data_lbls.to(self.device)
        data_lbls = self._pad_across_processes(data_lbls)
        data_lbls = self._nested_gather(data_lbls)
        labels_host_data = data_lbls if labels_host_data is None else nested_concat(labels_host_data, data_lbls, padding_index=-100)

      if fid_act is not None:
        fid_act = fid_act.to(self.device)
        fid_act = self._pad_across_processes(fid_act)
        fid_act = self._nested_gather(fid_act)
        fidact_host_data = fid_act if fidact_host_data is None else nested_concat(fidact_host_data, fid_act, padding_index=-100)

      if inputs_decode is not None:
        inputs_decode = inputs_decode.to(self.device)
        inputs_decode = self._pad_across_processes(inputs_decode)
        inputs_decode = self._nested_gather(inputs_decode)
        inputs_host = (
            inputs_decode
            if inputs_host is None
            else nested_concat(inputs_host, inputs_decode, padding_index=-100)
        )
      
      if text_tokens is not None:
        text_tokens = self._pad_across_processes(text_tokens)
        text_tokens = self._nested_gather(text_tokens)
        preds_host_text = text_tokens if preds_host_text is None else nested_concat(preds_host_text, text_tokens, padding_index=-100)

      if data_tokens is not None:
        data_tokens = data_tokens.contiguous()
        data_tokens = self._pad_across_processes(data_tokens)
        data_tokens = self._nested_gather(data_tokens)
        preds_host_data = data_tokens if preds_host_data is None else nested_concat(preds_host_data, data_tokens, padding_index=-100)

      if task_idx is not None:
        task_idx = self._pad_across_processes(task_idx)
        task_idx = self._nested_gather(task_idx)
        task_host = task_idx if task_host is None else nested_concat(task_host, task_idx, padding_index=-100)

      # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
      if self.eval_accumulation_steps is not None and (step + 1) % self.eval_accumulation_steps == 0:
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host_text is not None:
            text_tokens = nested_numpify(preds_host_text)
            all_text_preds = text_tokens if all_text_preds is None else nested_concat(all_text_preds, text_tokens, padding_index=-100)
        
        if preds_host_data is not None:
            data_tokens = nested_numpify(preds_host_data)
            all_data_preds = data_tokens if all_data_preds is None else nested_concat(all_data_preds, data_tokens, padding_index=-100)

        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode
                if all_inputs is None
                else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host_text is not None:
            text_lbls = nested_numpify(labels_host_text)
            all_labels_text = (
                text_lbls if all_labels_text is None else nested_concat(all_labels_text, text_lbls, padding_index=-100)
            )
        if labels_host_data is not None:
            data_lbls = nested_numpify(labels_host_data)
            all_labels_data = (
                data_lbls if all_labels_data is None else nested_concat(all_labels_data, data_lbls, padding_index=-100)
            )
        
        if fidact_host_data is not None:
            fid_act = nested_numpify(fidact_host_data)
            all_fidact_data = (
                fid_act if all_fidact_data is None else nested_concat(all_fidact_data, fid_act, padding_index=-100)
            )

        if task_host is not None:
            task_idx = nested_numpify(task_host)
            all_tasks = task_idx if all_tasks is None else nested_concat(all_tasks, task_idx, padding_index=-100)

        # Set back to None to begin a new accumulation
        losses_host, preds_host_text, preds_host_data, inputs_host, \
          labels_host_text, labels_host_data, task_host, fidact_host_data = \
          None, None, None, None, None, None, None, None


    # Gather all remaining tensors and put them back on the CPU
    if losses_host is not None:
        losses = nested_numpify(losses_host)
        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
    if preds_host_text is not None:
        text_tokens = nested_numpify(preds_host_text)
        all_text_preds = text_tokens if all_text_preds is None else nested_concat(all_text_preds, text_tokens, padding_index=-100)
    if preds_host_data is not None:
        data_tokens = nested_numpify(preds_host_data)
        all_data_preds = data_tokens if all_data_preds is None else nested_concat(all_data_preds, data_tokens, padding_index=-100)
    if inputs_host is not None:
        inputs_decode = nested_numpify(inputs_host)
        all_inputs = inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
    if labels_host_text is not None:
        text_lbls = nested_numpify(labels_host_text)
        all_labels_text = text_lbls if all_labels_text is None else nested_concat(all_labels_text, text_lbls, padding_index=-100)
    if labels_host_data is not None:
        data_lbls = nested_numpify(labels_host_data)
        all_labels_data = data_lbls if all_labels_data is None else nested_concat(all_labels_data, data_lbls, padding_index=-100)
    if fidact_host_data is not None:
        fid_act = nested_numpify(fidact_host_data)
        all_fidact_data = fid_act if all_fidact_data is None else nested_concat(all_fidact_data, fid_act, padding_index=-100)
    if task_host is not None:
        task_idx = nested_numpify(task_host)
        all_tasks = task_idx if all_tasks is None else nested_concat(all_tasks, task_idx, padding_index=-100)

    num_samples = observed_num_examples
    # The number of losses has been rounded to a multiple of batch_size (and, in distributed training, the number
    # of samples has been rounded to a multiple of batch_size), so we truncate to the number of observed examples.
    if all_losses is not None:
        all_losses = all_losses[:num_samples]
    if all_text_preds is not None:
        all_text_preds = nested_truncate(all_text_preds, num_samples)
    if all_data_preds is not None:
        all_data_preds = nested_truncate(all_data_preds, num_samples)
    if all_labels_text is not None:
        all_labels_text = nested_truncate(all_labels_text, num_samples)
    if all_labels_data is not None:
        all_labels_data = nested_truncate(all_labels_data, num_samples)
    if all_fidact_data is not None:
        all_fidact_data = nested_truncate(all_fidact_data, num_samples)
    if all_inputs is not None:
        all_inputs = nested_truncate(all_inputs, num_samples)
    if all_tasks is not None:
       all_tasks = all_tasks[:num_samples]
  
    # Metrics for text
    metrics = {}    
    fid_train, fid_test = None, None
    text_preds = all_text_preds
    text_labels = all_labels_text

    if self.compute_fid is not None and all_fidact_data is not None and all_tasks is not None and self.cfg.eval.fid:
      data_indices = np.where(all_tasks == TASK2IDX['data'])[0]
      data_acts   = all_fidact_data[data_indices]
      fid_train, fid_test = self.compute_fid(data_acts)

    if self.compute_metrics is not None and text_preds is not None and text_labels is not None:
      if all_tasks is not None:
        text_indices = np.where(all_tasks != TASK2IDX['data'])[0]
        text_preds   = all_text_preds[text_indices]
        text_labels  = all_labels_text[text_indices]

      if text_preds.shape[0] > 0:
        metrics = self.compute_metrics(EvalPrediction(predictions=text_preds, 
                          label_ids=text_labels),
                          tokenizers[cur_stage])
    
    metrics = denumpify_detensorize(metrics)

    if all_losses is not None:
        metrics["loss"] = all_losses.mean().item()
    
    if fid_train is not None:
       metrics['fid_train'] = fid_train
       metrics['fid_test'] = fid_test

    return EvalLoopOutputwInputs(predictions=all_text_preds, label_ids_text=all_labels_text,  label_ids_code=all_labels_data, 
      metrics=metrics, num_samples=num_samples, inputs=all_inputs), all_inputs


  def eval(self, val_loader, models, tokenizers, metric_key_prefix='eval', epoch=0, prediction_loss_only=False, step_count=None, **kwargs):
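    '''
    Runs the evaluation loop, decodes predictions, computes KSM scores, optionally
    writes generated text samples to disk, and returns {'score': ...} selected by
    cfg.model.seq.opt_mode (0: loss, 1: rouge2, 2: fid_test).
    '''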
    
    for m in models.values():
      if m is not None:
        m.eval()

    predict_results, all_inputs = self.eval_loop(self.stage, val_loader, models, tokenizers, metric_key_prefix=metric_key_prefix, prediction_loss_only=prediction_loss_only, step_count=step_count)

    if self.rank() == 0 and step_count is None and predict_results.predictions is not None:
      predictions = tokenizers[self.stage].batch_decode(
          predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
      )

      if all_inputs is not None:
        contexts = tokenizers[self.stage].batch_decode(
            all_inputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
      else:
        contexts = [''] * len(predictions)

      ksm_scores = self.compute_ksm_scores(
          predict_results.predictions, predict_results.label_ids_text, 
          contexts, models, tokenizers, seperator='<SEP>')
      
      self.tracker.add_metrics(ksm_scores, metric_key_prefix, 'ksm')

      if self.cfg.eval.write_samples:
        print(f"Writing samples to output: {self.cfg.sample_dirs[self.stage]}")
        for pidx, (pred, context) in enumerate(zip(predictions, contexts)):
          task = context.split(':')[0].lower().strip()
          if 'data' not in task:
            task_directory = os.path.join(self.cfg.sample_dirs[self.stage], task)
            os.makedirs(task_directory, exist_ok=True)

            output_prediction_file = os.path.join(
                task_directory, "e{}_{}.txt".format(epoch, pidx))
            
            text = "OUTPUT: {} \n\n INPUT: {}".format(pred, context)
            with open(output_prediction_file, "w") as writer:
              writer.write(text)

    metrics = predict_results.metrics
    self.tracker.add_metrics(metrics, metric_key_prefix, 'rouge')

    self.logger.info("E{:02d} (eval) {} {}".format(self.epoch, self.tracker.loss_str('epoch'), self.tracker.metric_str('epoch', stage=self.stage)))

    #return predict_results
    opt_mode = self.cfg.model.seq.opt_mode
    outputs = {}
    if opt_mode == 0:
      outputs = {'score': metrics.get('loss')}
    elif opt_mode == 1:
      outputs = {'score': metrics.get('rouge2')}
    elif opt_mode == 2:
      outputs = {'score': metrics.get('fid_test')}
    return outputs

  def compute_fid_acts(self, models, data_tokens):
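    '''
    Reverses the vocabulary offsets applied in sample_data_labels, reconstructs chart
    data from the codebook indices with the continuous model, and returns the FID
    network activations for the reconstructed samples.
    '''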
    start_time = time.time()
    assert 'fid' in models and 'continuous' in models
    for m in models.values():
      if m is not None:
        m.eval()
    
    emb_len1 = self.cfg.model.continuous_data.vq.emb_len1
    emb_len2 = self.cfg.model.continuous_data.vq.emb_len2
    n_emb1 = self.cfg.model.continuous_data.vq.n_emb1
    n_emb2 = self.cfg.model.continuous_data.vq.n_emb2

    ct_idx = data_tokens[:,:1]
    cb_ind1 = data_tokens[:,1:1 + emb_len1]
    cb_ind2 = data_tokens[:,1 + emb_len1:1 + emb_len1 + emb_len2]
    
    ct_idx = ct_idx - 2
    cb_ind1    = cb_ind1 - 2 - len(UNIQ_CHART_HEADS)
    if cb_ind2 is not None:
      cb_ind2  = cb_ind2 - 2 - len(UNIQ_CHART_HEADS) 

    ct_idx = torch.clamp(ct_idx, min=0, max=len(UNIQ_CHART_HEADS)-1)
    cb_ind1 = torch.clamp(cb_ind1, min=0, max=n_emb1-1)
    cb_ind2 = torch.clamp(cb_ind2, min=0, max=n_emb2-1)

    kwargs = {
      'cb_ind1': cb_ind1,
      'cb_ind2': cb_ind2,
      'ct_idx': ct_idx,
      'temp': self.cfg.eval.gen_temperature,
      'hypo_count': self.cfg.eval.hypo_count,
      'hypo_bsz': self.cfg.eval.hypo_bsz
    }


    with torch.no_grad():
        with self.autocast_smart_context_manager():
          if models['continuous'].__class__.__name__ == 'DistributedDataParallel':
            x_hat = models['continuous'].module.reconstruct_from_indices(**kwargs)
          else:
            x_hat = models['continuous'].reconstruct_from_indices(**kwargs)


          x_hat['chart_type'] = [HEAD_IDX_TO_CHART[m] for m in ct_idx.view(-1).detach().cpu().numpy()]
          activations, _, _ = models['fid'](x_hat)
          
          #activations = torch.from_numpy(activations)

    return activations

  def compute_fid(self, act_container):
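    '''
    Computes the Frechet distance between the activation statistics of the generated
    samples and the precomputed train/test statistics stored in self.fid_stats.
    '''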
    
    #act = np.concatenate(act_container, axis=0)
    mu = np.mean(act_container, axis=0)
    sigma = np.cov(act_container, rowvar=False)
    
    #Load existing fid scores
    m1, s1, _ = self.fid_stats['train']
    m2, s2, _ = self.fid_stats['test']

    train_fid = calculate_frechet_distance(mu, sigma, m1, s1)
    test_fid = calculate_frechet_distance(mu, sigma, m2, s2)

    return train_fid, test_fid