# honeyplotnet/runner/text.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import os
from typing import Dict, NamedTuple, Optional, Tuple, Union

import nltk
import numpy as np
import torch
import torch.distributed as dist
from evaluate import load as load_metric
from filelock import FileLock
from torch.nn import functional as F

from .base import BaseRunner

from transformers.trainer_pt_utils import (
  find_batch_size, 
  nested_concat,
  nested_numpify,
  nested_truncate,
)

from transformers.trainer_utils import ( 
  EvalPrediction, denumpify_detensorize
)


from utils.constant import TASK2PREPEND
from utils.ksm_scores import (
  yake_text, embed_and_encode, calc_similarity
)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


class EvalLoopOutputwInputs(NamedTuple):
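    """EvalLoopOutput-style result that also carries the gathered model inputs."""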
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
    inputs: Optional[np.ndarray]

class ChartTextRunner(BaseRunner):
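  """Runner for the seq2seq text stages (e.g. chart text or captions): training, ROUGE evaluation and keyword similarity (KSM) scoring."""
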
  def __init__(self, stage, cfg):
    super().__init__(cfg)
    self.stage = stage

    self.ignore_pad_token_for_loss = self.cfg.model.seq.hf_model.ignore_pad_token_for_loss
    self.include_inputs_for_metrics = self.cfg.eval.include_inputs_for_metrics
    self.eval_accumulation_steps = self.cfg.eval.eval_accumulation_steps

    self._max_length = cfg.eval.max_length
    self._num_beams = cfg.eval.num_beams
    self._repetition_penalty = cfg.eval.repetition_penalty
    self._gen_temperature = cfg.eval.gen_temperature

    self.rouge = load_metric("rouge")

    self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

  def postprocess_text(self, preds, labels):
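    """Strip whitespace and put each sentence on its own line, as rougeLSum expects."""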
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

  def compute_metrics(self, eval_preds, tokenizer):
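    """Decode predictions and labels, then compute ROUGE scores and the mean generation length."""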
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if self.ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)

    result = self.rouge.compute(
      predictions=decoded_preds, 
      references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

  def compute_loss(self, model, inputs, return_outputs=False):
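    """Compute cross-entropy over flattened logits when labels are popped (label_smoothing_factor != 0); otherwise use the loss returned by the model."""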

    if "labels" in inputs and self.cfg.train.label_smoothing_factor != 0:
        labels = inputs.pop("labels")
    else:
        labels = None

    outputs = model(**inputs)

    if labels is not None:
        flat_logits = torch.flatten(outputs.logits, start_dim=0, end_dim=1)
        labels = labels.view(-1)
        loss = self.loss_fn(flat_logits, labels)
    else:
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    return (loss, outputs) if return_outputs else loss

  def training_step(self, model, inputs):
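    """Run a single forward/backward pass under autocast, scaling the loss for gradient accumulation and AMP."""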
    model.train()
    inputs = self._prepare_inputs(inputs)

    with self.autocast_smart_context_manager():
      loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

    if self.use_torch_dist:
      loss = loss.mean() 
    
    if self.gradient_accum_steps > 1:
      loss = loss / self.gradient_accum_steps
    
    if self.do_grad_scaling:
      self.scaler.scale(loss).backward()
    else:
      loss.backward()

    return loss.detach()

  def train(self, train_loader, models, tokenizers, opts, schs):
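    """Train for one epoch with gradient accumulation, gradient clipping and optional AMP grad scaling."""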

    self.tracker.reset_all()

    tr_loss = torch.tensor(0.0).to(self.device_id)
    
    models[self.stage].zero_grad()
    opts[self.stage].zero_grad()
    iterator = iter(train_loader)

    steps_in_epoch = len(train_loader)

    for step, model_inputs in enumerate(iterator):
      
      tr_loss_step = self.training_step(models[self.stage], model_inputs)
      
      tr_loss += tr_loss_step


      if (step + 1) % self.gradient_accum_steps == 0 or (
        # last step in epoch but step is always smaller than gradient_accum_steps
        steps_in_epoch <= self.gradient_accum_steps
        and (step + 1) == steps_in_epoch
      ):
        
        if self.do_grad_scaling:
          self.scaler.unscale_(opts[self.stage])
        if self.max_grad_norm > 0:
          torch.nn.utils.clip_grad_norm_(models[self.stage].parameters(), self.max_grad_norm)

        if self.do_grad_scaling:
          self.scaler.step(opts[self.stage])
          self.scaler.update()
        else:
          opts[self.stage].step()

        models[self.stage].zero_grad()
        opts[self.stage].zero_grad()

        self.global_step += 1
          
        self.tracker.add_logs(split='train', total_loss=tr_loss)
        tr_loss = 0

      if isinstance(self.display, int) and step % self.display == 0 and step > 0:
        self.logger.info("E{:02d} GS: {:03d} {}".format(
          self.epoch, self.global_step, self.tracker.loss_str('iter')))
        
        self.tracker.reset_interval('iter')

    self.logger.info("E{:02d} (train) {} {}".format(self.epoch, 
      self.tracker.loss_str('epoch'), 
      self.tracker.metric_str('epoch', stage=self.stage)))

    self.update_writer('train')

  def eval(self, eval_loader, models, tokenizers, metric_key_prefix='eval', prediction_loss_only=False, epoch=0, **kwargs):
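    """Run the evaluation loop and, on rank 0, decode predictions, write them to disk and track ROUGE and KSM metrics."""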
    
    predict_results, all_inputs = self.eval_loop(
      self.stage, eval_loader, models, tokenizers,
      metric_key_prefix=metric_key_prefix, prediction_loss_only=prediction_loss_only)
    

    if self.rank() == 0:
      
      # Create readable prediction files below
      predictions = tokenizers[self.stage].batch_decode(
          predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
      )

      if all_inputs is not None:
        contexts = tokenizers[self.stage].batch_decode(
            all_inputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
      else:
        contexts = [''] * len(predictions)

      # Compute keyword similarity (KSM) scores
      ksm_scores = self.compute_ksm_scores(
        predict_results.predictions, predict_results.label_ids, contexts,
        models, tokenizers, separator='<SEP>')
      self.tracker.add_metrics(ksm_scores, metric_key_prefix, 'ksm')

      for pidx, (pred, context) in enumerate(zip(predictions, contexts)):
        output_prediction_file = os.path.join(self.cfg.sample_dirs[self.stage], "generated_predictions_e{}_{}.txt".format(epoch, pidx))

        text = "OUTPUT: {} \n\n INPUT: {}".format(pred, context)
        with open(output_prediction_file, "w") as writer:
          writer.write(text)

      metrics = predict_results.metrics
      self.tracker.add_metrics(metrics, metric_key_prefix, 'rouge')

      self.logger.info("E{:02d} (eval) {} {}".format(self.epoch,
        self.tracker.loss_str('epoch'),
        self.tracker.metric_str('epoch', stage=self.stage)))

    return predict_results

  def compute_ksm_scores(self, pred_tokens, labels, contexts, models, tokenizers, separator='<SEP>'):
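    """Compute keyword similarity (KSM) scores between generated and reference keywords, grouped by task."""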
    
    sim_scores = {}
    for pidx, (pred_toks, label, context) in enumerate(zip(pred_tokens, labels, contexts)):

      if self.stage in ['chart_text', 'seq']:
        task = None
        for t, prepend_str in TASK2PREPEND.items():
          if context.startswith(prepend_str):
            task = t
            break
        
        if task is None:
          raise ValueError(f"No task found in context: {context[:100]}")
      elif self.stage == 'caption':
        task = 'caption'
      else:
        raise ValueError(f"Stage not recognised for ksm: {self.stage}")

      if task == 'data':
        continue

      # Create keywords for the context and task
      if task in ['chart_text', 'series_name', 'axis', 'categorical']:
        task_str = tokenizers[self.stage].decode(
            pred_toks, skip_special_tokens=False, clean_up_tokenization_spaces=True)
        task_str = task_str.replace('<pad>', '').replace('</s>', '')
        task_keywords = [t.strip() for t in task_str.split(separator)]

        # Replace -100 with 0 so the labels can be decoded
        label[label == -100] = 0
        ref_str = tokenizers[self.stage].decode(
            label, skip_special_tokens=False, clean_up_tokenization_spaces=True)
        ref_str = ref_str.replace('<pad>', '').replace('</s>', '')
        reference_keywords = [t.strip() for t in ref_str.split(separator)]

      elif task in ['caption']:
        decoded_str = tokenizers[self.stage].decode(
            pred_toks, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        task_keywords = [kw[0] for kw in yake_text(decoded_str)]
        reference_keywords = [kw[0] for kw in yake_text(context)]

      # Tokenize reference and task keywords
      reference_tok = tokenizers['ksm'](reference_keywords, max_length=128, 
        padding="max_length", truncation=True, return_tensors="pt")
      task_tok = tokenizers['ksm'](task_keywords, max_length=128, 
        padding="max_length", truncation=True, return_tensors="pt")

      reference_emb = embed_and_encode(
        reference_tok, models['ksm'], device=self.device)
      task_emb = embed_and_encode(
        task_tok, models['ksm'], device=self.device)
      
      # Average embeddings over the sequence dimension
      reference_emb = reference_emb.mean(1)
      task_emb = task_emb.mean(1)
      
      sim_score = calc_similarity(reference_emb, task_emb).mean().detach().cpu().item()
      
      if 'ksm_' + task not in sim_scores:
        sim_scores['ksm_' + task] = []

      sim_scores['ksm_' + task].append(sim_score)
    
    return sim_scores
  
  def eval_loop(self, cur_stage, loader, models, tokenizers, metric_key_prefix='eval', prediction_loss_only=False):
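    """Prediction loop modelled on HuggingFace's Trainer.evaluation_loop, additionally accumulating input ids for later decoding."""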

    self.tracker.reset_all()
    iterator = iter(loader)
    steps_in_epoch = len(loader)

    models[cur_stage].eval()

    # Initialize containers
    # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
    losses_host = None
    preds_host = None
    labels_host = None
    inputs_host = None

    # losses/preds/labels on CPU (final containers)
    all_losses = None
    all_preds = None
    all_labels = None
    all_inputs = None
    observed_num_examples = 0
    batch_size = self.bsz

    for step, inputs in enumerate(iterator):
      
      if isinstance(self.display, int) and step % self.display == 0 and step > 0:
        self.logger.info("Eval | E{:02d} Step {:04d}/{:04d} ".format(self.epoch, step, self.cfg.eval.max_steps))

      if isinstance(self.cfg.eval.max_steps, int) and step > self.cfg.eval.max_steps:
        break

      # Update the observed num examples
      observed_batch_size = find_batch_size(inputs)
      if observed_batch_size is not None:
          observed_num_examples += observed_batch_size
          # For batch samplers, batch_size is not known by the dataloader in advance.
          if batch_size is None:
              batch_size = observed_batch_size
              
      loss, logits, labels, inputs = self.prediction_step(
        models[cur_stage], tokenizer=tokenizers[cur_stage], 
        inputs=inputs, prediction_loss_only=prediction_loss_only)

      inputs_decode = inputs["input_ids"] 

      # Update containers on host
      if loss is not None:
        losses = self._nested_gather(loss.repeat(batch_size))
        losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)

      if labels is not None:
        labels = self._pad_across_processes(labels)
        labels = self._nested_gather(labels)
        labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)

      if inputs_decode is not None:
        inputs_decode = self._pad_across_processes(inputs_decode)
        inputs_decode = self._nested_gather(inputs_decode)
        inputs_host = (
            inputs_decode
            if inputs_host is None
            else nested_concat(inputs_host, inputs_decode, padding_index=-100)
        )
      
      if logits is not None:
        logits = self._pad_across_processes(logits)
        logits = self._nested_gather(logits)
        #if self.preprocess_logits_for_metrics is not None:
            #logits = self.preprocess_logits_for_metrics(logits, labels)
        preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
      
      # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
      if self.eval_accumulation_steps is not None and (step + 1) % self.eval_accumulation_steps == 0:
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode
                if all_inputs is None
                else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = (
                labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
            )

        # Set back to None to begin a new accumulation
        losses_host, preds_host, inputs_host, labels_host = None, None, None, None


    # Gather all remaining tensors and put them back on the CPU
    if losses_host is not None:
        losses = nested_numpify(losses_host)
        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
    if preds_host is not None:
        logits = nested_numpify(preds_host)
        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
    if inputs_host is not None:
        inputs_decode = nested_numpify(inputs_host)
        all_inputs = (
            inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
        )
    if labels_host is not None:
        labels = nested_numpify(labels_host)
        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

    num_samples = observed_num_examples
    # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
    # samplers has been rounded to a multiple of batch_size, so we truncate.
    if all_losses is not None:
        all_losses = all_losses[:num_samples]
    if all_preds is not None:
        all_preds = nested_truncate(all_preds, num_samples)
    if all_labels is not None:
        all_labels = nested_truncate(all_labels, num_samples)
    if all_inputs is not None:
        all_inputs = nested_truncate(all_inputs, num_samples)
  
    # Metrics!
    if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
      if self.include_inputs_for_metrics:
          metrics = self.compute_metrics(
              EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs),
              tokenizers[cur_stage]
          )
      else:
          metrics = self.compute_metrics(
            EvalPrediction(predictions=all_preds, label_ids=all_labels),
            tokenizers[cur_stage])
    else:
        metrics = {}
    
    metrics = denumpify_detensorize(metrics)

    if all_losses is not None:
        metrics["loss"] = all_losses.mean().item()

    
    return EvalLoopOutputwInputs(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs), all_inputs

  def prediction_step(self, model, tokenizer, inputs, prediction_loss_only=False, ignore_keys=None):
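    """Generate tokens with model.generate, pad generations and labels to max_length, and compute the label loss when available."""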
    has_labels = 'labels' in inputs
    model.eval()
    
    inputs = self._prepare_inputs(inputs)

    gen_kwargs = {
        "max_length": self._max_length,
        "num_beams": self._num_beams,
        "synced_gpus": False,
        "repetition_penalty": self._repetition_penalty,
        "temperature": self._gen_temperature,
    }

    if "attention_mask" in inputs:
        gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
    if "global_attention_mask" in inputs:
        gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
    
    # prepare generation inputs
    generation_inputs = inputs["input_ids"]

    if model.__class__.__name__ == 'DistributedDataParallel':
      generated_tokens = model.module.generate(
          generation_inputs,
          **gen_kwargs,
      )      
    else:
      generated_tokens = model.generate(
          generation_inputs,
          **gen_kwargs,
      )
    # in case the batch is shorter than max length, the output should be padded
    if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
        generated_tokens = self._pad_tensors_to_max_len(
          generated_tokens, gen_kwargs["max_length"], model, tokenizer)

    loss = None
    if has_labels:
      with torch.no_grad():
        with self.autocast_smart_context_manager():
          outputs = model(**inputs)
        
        if self.loss_fn is not None:
          logits = outputs.logits

          if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits).to(self.device)
          
          logits = torch.flatten(logits, start_dim=0, end_dim=1)
          labels = torch.flatten(inputs["labels"], start_dim=0, end_dim=1)

          loss = self.loss_fn(logits, labels).mean().detach()
        else:
            loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
          
    if prediction_loss_only:
        return (loss, None, None, None)

    if has_labels:
        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"], model, tokenizer)
    else:
        labels = None

    return (loss, generated_tokens, labels, inputs)
    

  def _pad_tensors_to_max_len(self, tensor, max_length, model, tokenizer):
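    """Right-pad tensor to max_length with the tokenizer's pad (or EOS) token id, falling back to the model config."""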
    if tokenizer is not None and hasattr(tokenizer, "pad_token_id"):
        # If PAD token is not defined at least EOS token has to be defined
        pad_token_id = (
            tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )
    else:
        if model.config.pad_token_id is not None:
            pad_token_id = model.config.pad_token_id
        else:
            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

    padded_tensor = pad_token_id * torch.ones(
        (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
    )
    padded_tensor[:, : tensor.shape[-1]] = tensor
    return padded_tensor