# honeyplotnet/dataset/text.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

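"""Text dataset for PMC figure data.

Builds text-to-text training pairs from figure annotations: the input is a
window of document text around a figure (or the figure caption), and the
target is one task sampled at random per example: caption, series names,
categorical values, or axis titles.
"""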

import random

import torch

from utils import TASK2PREPEND

from .data import PmcDataDataset
from .base import get_text_window

class PmcTextDataset(PmcDataDataset):
  def __init__(self, data, **kwargs):
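    """Wrap the base dataset and keep only examples usable for ``tasks``.

    ``tasks``, ``tokenizer`` and the length/window attributes used below are
    assumed to be set by ``PmcDataDataset.__init__`` from ``kwargs``.
    """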
    super().__init__(data, **kwargs)

    assert self.tasks is not None and len(self.tasks) > 0
    self.data = self.filter_data_by_tasks(self._data)

  def __getitem__(self, index):
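    """Return one tokenized (context, target) example.

    A task is sampled at random from the targets available for this entry;
    malformed entries are recorded in ``del_list`` and skipped.
    """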
      
    while True:
      d = self.data[index % len(self.data)]

      try:
        fig_id = d['fig_id']
        context_start = d['fig_index'][fig_id][0]
      except (KeyError, IndexError):
        # Malformed entry: record it for deletion and try the next index.
        self.del_list.append(index)
        index += 1
        continue
      
      all_text = d['all_text']
      captions = d['captions']

      all_text = self.preprocess_all_text(all_text)

      caption_label, all_text, context_start = self.get_caption_label(
        captions, all_text, context_start)
      
      ######################
      # Prepare targets (randomly pick a task)
      ######################
      inputs = {}
      outputs = {}

      if 'caption' in self.tasks:
        _, outputs['caption'] = self._tokenize(self.tokenizer, 
          caption_label, max_source_len=self.max_source_len, 
          max_target_len=self.max_target_len, is_target=True)

      series_name = self.get_series_names(d)

      if series_name and 'series_name' in self.tasks:
        series_name = self.sep_token.join(series_name)
        outputs['series_name'] = self.tokenize_tgt_flatten(series_name)

      categorical_data = self.get_categorical_values(d)

      if categorical_data and 'categorical' in self.tasks:
        categorical_data = self.sep_token.join(categorical_data)
        outputs['categorical'] = self.tokenize_tgt_flatten(categorical_data)

      axis_data = self.get_axis_names(d)

      if axis_data and 'axis' in self.tasks:
        axis_data = self.sep_token.join(axis_data)
        outputs['axis'] = self.tokenize_tgt_flatten(axis_data)

      # Pick a task at random from the targets built above
      task = random.choice(list(outputs.keys()))

      ######################
      # Prepare inputs (context)
      ######################

      # Tokenize context (document text for the caption task, caption otherwise)
      if task == 'caption':
        context = get_text_window(
          all_text, context_start,
          tgt_token=self.tgt_token,
          window_size=self.window_size,
          widen_rate=self.widen_rate,
          max_widen_len=self.max_widen_len,
          min_sent_len=self.min_sent_len)

        # Prepend the task prompt to the document context
        context = TASK2PREPEND[task] + context
      else:
        # Prepend the task prompt to the caption
        context = TASK2PREPEND[task] + caption_label

      inputs, _ = self._tokenize(self.tokenizer, 
        context, max_source_len=self.max_source_len, 
        max_target_len=self.max_target_len, is_target=False)
      
      # Attach the sampled task's target ids as labels
      inputs['labels'] = outputs[task]['input_ids']

      # Drop the leading batch dimension from each tensor
      for k in list(inputs.keys()):
        inputs[k] = inputs[k].squeeze(0)
      return inputs

  def filter_data_by_tasks(self, data):
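    """Keep only examples that have a target for at least one configured task.

    Caption text exists for every example, so no filtering is applied when
    the 'caption' (or 'data') task is requested.
    """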
    new_data = []

    if 'caption' in self.tasks or 'data' in self.tasks:
      return data
    
    # Keep examples that provide at least one of the requested target types
    for d in data:
      series_name = self.get_series_names(d)
      categorical_data = self.get_categorical_values(d)
      axis_data = self.get_axis_names(d)
      if ((series_name and 'series_name' in self.tasks)
          or (categorical_data and 'categorical' in self.tasks)
          or (axis_data and 'axis' in self.tasks)):
        new_data.append(d)
    return new_data

  def tokenize_tgt_flatten(self, text, flatten=False):
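    """Tokenize a target string and mask padding ids with -100.

    Example (assuming ``pad_token_id == 0``): ids ``[42, 7, 0, 0]`` become
    ``[42, 7, -100, -100]``, so padded positions are ignored by the loss.

    With ``flatten=True`` the tokenizer output is treated as a batch (a list
    of id lists) and concatenated into a single flat tensor.
    """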
    with self.tokenizer.as_target_tokenizer():
      tokens = self.tokenizer(
        text, max_length=self.max_target_len,
        padding="max_length", truncation=True)

    pad_id = self.tokenizer.pad_token_id

    if flatten:
      # Batched input: tokens['input_ids'] is a list of lists; flatten first
      ids = [tok for row in tokens['input_ids'] for tok in row]
    else:
      ids = tokens['input_ids']

    # Replace padding ids with -100 so the cross-entropy loss ignores them
    input_ids = torch.tensor(
      [tok if tok != pad_id else -100 for tok in ids], dtype=torch.long)

    return {'input_ids': input_ids}
  
  def get_axis_names(self, d):
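    """Collect axis-title strings for a figure.

    Role ids come from the ``task3`` text-role annotations; the matching
    strings live in the ``task2`` text blocks (keys follow the annotation
    schema of the source data).
    """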
    axis_title_ids = []
    for b in d['task3']['output']['text_roles']:
      if b['role'] == 'axis_title':
        axis_title_ids.append(b['id'])

    axis_title_text = []
    for b in d['task2']['output']['text_blocks']:
      if b['id'] in axis_title_ids and 'unnamed' not in b['text']:
        axis_title_text.append(b['text'])
    
    return axis_title_text

  def get_series_names(self, d):
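    """Return cleaned names of the data series, skipping 'unnamed' entries."""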
    all_series_name = []
    for series_data in d['task6']['output']['data series']:
      name = series_data['name']
      if 'unnamed' in name:
        continue
      v = name.replace('\n', ' ').strip()
      if len(v):
        all_series_name.append(v)
    return all_series_name
  
  def get_categorical_values(self, d):
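    """Collect categorical (string) cell values from the first data series."""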

    all_categorical_data = []

    # Categorical values repeat across series, so only scan the first one
    series_data = d['task6']['output']['data series'][0]

    for data in series_data['data']:
      for v in data.values():
        if isinstance(v, str) and 'unnamed' not in v:
          cat_name = v.replace('\n', ' ').strip()
          if len(cat_name):
            all_categorical_data.append(cat_name)
    return all_categorical_data
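

# Minimal usage sketch (hypothetical: the constructor kwargs below are
# inferred from the attributes this class uses, not a verified signature):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('t5-base')
#   dataset = PmcTextDataset(
#     data, tokenizer=tokenizer, tasks=['caption', 'series_name', 'axis'],
#     max_source_len=512, max_target_len=128)
#   example = dataset[0]  # dict of 1-D tensors, including 'labels'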