# honeyplotnet/dataset/base.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import numpy as np
import torch
import random
import re

from .constants import TEXT_CUTOFF_POINTS

def get_text_window(all_text, context_start, tgt_token='[MASK1]', window_size=10, widen_rate=0.5, max_widen_len=2, min_sent_len=10):
  '''
  Return a window of sentences from `all_text` centred on `tgt_token`.

  The text is split on '. '. `tgt_token` is inserted as its own "sentence"
  at the first split whose cumulative character count reaches
  `context_start`; if that never happens the token is placed in the middle
  of the kept sentences instead. Exactly one `tgt_token` is inserted.

  all_text           : full document text.
  context_start      : character offset of the context inside `all_text`.
  tgt_token          : marker inserted at the context position.
  window_size        : measured as number of full sentences.
  widen_rate         : rate that window size increases or decreases randomly
  max_widen_len      : largest random change (in sentences) of the window size
  min_sent_len       : minimum size of sentences in characters; shorter
                       splits are dropped.

  Returns the selected sentences joined with '. ', with numeric bracketed
  markers such as '[12]' removed.
  '''
  assert 0.0 <= widen_rate <= 1.0
  assert max_widen_len >= 0

  sentence_split = []
  text_len = 0
  context_sent_idx = None
  for s in all_text.split('. '):
    sent_len = len(s)
    # NOTE: the '. ' separators are not counted, so this running length is
    # slightly smaller than the true character offset.
    text_len += sent_len

    if sent_len >= min_sent_len:
      sentence_split.append(s.strip())

    # Track when context starts within sentences
    if context_sent_idx is None and text_len >= context_start:
      context_sent_idx = len(sentence_split)
      sentence_split.append(tgt_token)

  # Fall back to the middle ONLY when the context position was not found.
  # (Previously this also ran whenever context_start was inside the text,
  # which inserted a second token and discarded the located position.)
  if context_sent_idx is None:
    mid_pt = len(sentence_split) // 2
    sentence_split.insert(mid_pt, tgt_token)
    context_sent_idx = mid_pt

  # Calculate window size flexibly
  if random.random() < widen_rate:
    # increase or decrease ?
    inc_dec = 1 if random.random() < 0.5 else -1
    # randomly pick a change size
    change_size = random.randint(0, max_widen_len)
    window_size += inc_dec * change_size

  # A shrunken window must still select something.
  window_size = max(window_size, 1)

  # ceil(window_size / 2) sentences before the token, floor after.
  start_idx = context_sent_idx - (window_size + 1) // 2
  end_idx = context_sent_idx + window_size // 2

  # Slide the window back inside the list when it overhangs either end.
  if start_idx < 0:
    end_idx -= start_idx
    start_idx = 0
  elif end_idx > len(sentence_split):
    start_idx -= end_idx - len(sentence_split)
    end_idx = len(sentence_split)
  # ensure limits
  start_idx = max(start_idx, 0)
  end_idx = min(end_idx, len(sentence_split))

  output = '. '.join(sentence_split[start_idx:end_idx])
  # Strip numeric bracketed markers, e.g. '[3]'.
  output = re.sub(r'\[\d+\]', '', output)
  return output

def stack_dict(list_batch):
    """Merge a list of per-sample dicts into a single dict of lists.

    `None` entries in `list_batch` are skipped. For every key seen, the
    values are collected in the order the samples appear.
    """
    collated_batch = {}
    valid_samples = (sample for sample in list_batch if sample is not None)
    for sample in valid_samples:
        for key, value in sample.items():
            collated_batch.setdefault(key, []).append(value)
    return collated_batch

def stack_tensor_dict(dict_tensors):
    """Stack each list of tensors in `dict_tensors` along a new leading dim.

    Returns a new dict with the same keys, where every value is
    `torch.stack(values, dim=0)`. The input dict is left untouched.

    (The previous version rebound the dict itself to the first stacked
    tensor, so it failed on a second key and returned a bare tensor for a
    single key.)
    """
    return {
        key: torch.stack(tensors, dim=0)
        for key, tensors in dict_tensors.items()
    }

def shift_right_vec(x, start_values=0.0):
    """Shift a 1-D, 2-D or 3-D tensor one step to the right.

    A float32 start frame filled with `start_values` is prepended and the
    last step is dropped, so the output has the same shape as `x`.
    The shift runs along dim 0 for 1-D input and along dim 1 otherwise.

    Raises ValueError for any other number of dimensions.
    """
    rank = len(x.shape)
    if rank not in (1, 2, 3):
        raise ValueError("Input has invalid shape")

    # Shape of the single prepended frame: length 1 on the shifted axis.
    frame_shape = [1] * rank
    if rank >= 2:
        frame_shape[0] = x.size(0)
    if rank == 3:
        frame_shape[-1] = x.size(-1)

    start_frame = torch.ones(frame_shape, device=x.device, dtype=torch.float32)
    start_frame = start_frame * start_values

    if rank == 1:
        return torch.cat([start_frame, x[:-1]], dim=0)
    if rank == 2:
        return torch.cat([start_frame, x[:, :-1]], dim=1)
    return torch.cat([start_frame, x[:, :-1, :]], dim=1)

def shift_tokens_right(input_ids, pad_token_id, dim=-1):
    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).

    For each row the last non-pad token is located by counting non-pad
    entries along `dim`; that token becomes position 0 and every other
    token moves one position to the right. `input_ids` is not modified.
    """
    non_pad_counts = input_ids.ne(pad_token_id).sum(dim=dim)
    last_token_idx = (non_pad_counts - 1).unsqueeze(-1)
    wrapped_token = input_ids.gather(1, last_token_idx).squeeze()

    shifted = input_ids.clone()
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = wrapped_token
    return shifted
  
def shift_tokens_right_pad(input_ids, pad_token_id, dim=-1):
    """Shift input ids one token to the right, and pad front.

    Position 0 of every row is set to `pad_token_id` and the last token of
    every row is dropped. `input_ids` is not modified. `dim` is accepted
    but not used by this function.
    """
    shifted = input_ids.clone()
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = pad_token_id
    return shifted

class BaseDataset(torch.utils.data.Dataset):
  '''
  Base dataset holding the tokenizer and the text-window / tokenization
  settings, plus helpers for cleaning document text, building caption
  labels and tokenizing source / target text.
  '''
  def __init__(self, tokenizer, window_size, widen_rate, max_widen_len, 
               min_sent_len, max_source_len, max_target_len, 
               pad_to_max_len, ignore_pad_token_for_loss, tgt_token='<mask_1>'):
    '''
    tokenizer                  : tokenizer object, or a dict of tokenizers from
                                 which the 'caption' entry is used.
    window_size, widen_rate,
    max_widen_len, min_sent_len: stored text-window settings.
    max_source_len             : max token length for tokenized inputs.
    max_target_len             : max token length for tokenized targets.
    pad_to_max_len             : pad to "max_length" when true, else no padding.
    ignore_pad_token_for_loss  : replace pad ids in labels with -100
                                 (only applied when padding to max length).
    tgt_token                  : marker token stored for text windows.
    '''
    super().__init__()

    if isinstance(tokenizer, dict) and 'caption' in tokenizer:
      tokenizer = tokenizer['caption']

    self.tokenizer = tokenizer
    self.window_size = window_size
    self.widen_rate = widen_rate
    self.max_widen_len = max_widen_len
    self.min_sent_len = min_sent_len
    self.max_source_len = max_source_len
    self.max_target_len = max_target_len
    self.pad_to_max_len = pad_to_max_len
    
    self.ignore_pad_token_for_loss = ignore_pad_token_for_loss
    self.padding = "max_length" if pad_to_max_len else False
    self.tgt_token = tgt_token


  def preprocess_all_text(self, all_text, min_cutoff_rate=0.8):
    '''
    Remove text beyond the conclusion.

    Method: for every phrase in TEXT_CUTOFF_POINTS find its first
    (case-insensitive) occurrence. Among occurrences located at or after
    `min_cutoff_rate` of the text length, cut the text at the earliest one.
    If none qualifies the text is returned unchanged.
    '''
    text_len = len(all_text)
    earliest_allowed = text_len * min_cutoff_rate
    lowered_text = all_text.lower()  # lowercase once, not once per phrase

    min_cutoff_idx = text_len
    for text in TEXT_CUTOFF_POINTS:
      cutoff_idx = lowered_text.find(text.lower())
      # find() returns -1 when absent, which never passes the lower bound
      # for non-empty text.
      if earliest_allowed <= cutoff_idx < min_cutoff_idx:
        min_cutoff_idx = cutoff_idx
    
    all_text = all_text[:min_cutoff_idx]
    assert len(all_text) >= earliest_allowed, "bug found: text >> {}".format(all_text)
    return all_text

  def get_caption_label(self, captions, all_text, context_start):
    '''
    Build the caption label from the 'title' and 'p' entries of `captions`.

    Every caption string is removed from `all_text`, and `context_start`
    is reduced by the caption's length. Newlines inside each caption are
    replaced by spaces before it is appended to the label.

    Returns (caption_label, all_text, context_start).
    Raises TypeError when a present entry is not a list.
    '''
    caption_label = ''
    for key in ['title','p']:
      cap = captions.get(key)
      if cap is None:
        continue
      if not isinstance(cap, list):
        # Previously `raise cap`, which raised an opaque TypeError because
        # the value is not an exception.
        raise TypeError(
          "caption field '{}' must be a list or None, got {}".format(key, type(cap).__name__))

      for c in cap:
        # Remove location of context
        # NOTE(review): the offset is reduced even if `c` does not occur in
        # `all_text` or occurs after `context_start` — confirm intended.
        all_text = all_text.replace(c, '')
        context_start -= len(c)
          
        # Flatten newlines inside each sentence of the caption.
        c = '. '.join([s.replace('\n', ' ') for s in c.strip().split('. ')])
        caption_label += c 
    return caption_label, all_text, context_start
  
  def preprocess_captions(self, inputs, targets):  
    '''
    Tokenize `inputs` as model inputs and `targets` as labels.

    Adds 'decoder_input_ids' (labels shifted right with the last non-pad
    token wrapped to the front), 'labels' and 'raw' (the untouched targets)
    to the tokenized inputs. Every tensor entry is squeezed on dim 0.

    When padding to max length and `ignore_pad_token_for_loss` is set, pad
    ids in 'labels' are replaced by -100.

    NOTE(review): presumably `targets` is a list holding one string, since
    the shift helper indexes two dimensions — confirm against callers.
    '''
    model_inputs, _ = self._tokenize(self.tokenizer, 
      inputs, max_source_len=self.max_source_len, 
      max_target_len=self.max_target_len, is_target=False)

    # Setup the tokenizer for targets
    with self.tokenizer.as_target_tokenizer():
      labels = self.tokenizer(targets, max_length=self.max_target_len, padding=self.padding, truncation=True)

    label_ids = torch.tensor(labels["input_ids"])
    model_inputs["decoder_input_ids"] = shift_tokens_right(label_ids, self.tokenizer.pad_token_id)
    
    if self.padding == "max_length" and self.ignore_pad_token_for_loss:
      # Mask element-wise on the tensor. The previous list comprehension
      # compared whole rows to the pad id, so nothing was ever replaced.
      label_ids = label_ids.masked_fill(label_ids == self.tokenizer.pad_token_id, -100)
    
    model_inputs["labels"] = label_ids

    for k in list(model_inputs.keys()):
      model_inputs[k] = model_inputs[k].squeeze(0)
    
    model_inputs['raw'] = targets
    return model_inputs
  
  def _tokenize(self, tokenizer, text, max_source_len, max_target_len=32, is_target=False, shift_right=False):
    '''
    Tokenize `text` and, optionally, also tokenize it as a target.

    Returns (inputs, labels). `inputs` is the tokenizer output as "pt"
    tensors. `labels` is None unless `is_target` is set. With `shift_right`,
    'decoder_input_ids' and 'decoder_attention_mask' are added to `inputs`.

    When padding to max length and `ignore_pad_token_for_loss` is set,
    labels["input_ids"] becomes a long tensor with pad ids replaced by -100;
    otherwise it stays as returned by the tokenizer.
    '''
    inputs = tokenizer(
      text, max_length=max_source_len, 
      padding=self.padding, truncation=True, return_tensors="pt")
    
    labels = None
    if is_target:
      with tokenizer.as_target_tokenizer():
        labels = tokenizer(text, max_length=max_target_len, padding=self.padding, truncation=True)

        if shift_right:
          inputs["decoder_input_ids"] = shift_tokens_right(torch.tensor(labels["input_ids"]).long(), tokenizer.pad_token_id, dim=-1)
          # NOTE(review): the attention mask is shifted using pad_token_id
          # as its "pad" value — confirm this is intended.
          inputs["decoder_attention_mask"] = shift_tokens_right(torch.tensor(labels["attention_mask"]).long(), tokenizer.pad_token_id, dim=-1)

      if self.padding == "max_length" and self.ignore_pad_token_for_loss:
        # Mask element-wise so nested (batched) id lists are handled too.
        label_ids = torch.tensor(labels["input_ids"]).long()
        labels["input_ids"] = label_ids.masked_fill(label_ids == tokenizer.pad_token_id, -100)
        
    return inputs, labels