# honeyplotnet/models/continuous/encoder.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import random

import torch
import torch.nn as nn

from .helper import (
  get_chart_type_dict,
  pad_vector
)

from ..constant import SCALE_DIMS, REG_DIMS, UNIQ_CHART_HEADS, CHART_TO_HEAD_IDX

from .coder import Coder

class Encoder(Coder):
  def __init__(self,
    enc_conv,
    enc_tf,
    out_dim, 
    max_series_blocks,
    max_cont_blocks,
    chart_type_conditional=False,
    use_pos_embs=True, 
    norm_mode='minmax',
    device='cuda:0', 
    debug=False,
    ):
    super().__init__()

    self.chart_type_conditional = chart_type_conditional
    self.max_cont_blocks = max_cont_blocks
    self.max_series_blocks = max_series_blocks
    self.out_dim = out_dim
    self.norm_mode = norm_mode
    self.use_pos_embs = use_pos_embs
    self.device = device
    self.debug = debug
    self.dtype_float = torch.float32

    self.enc_conv  = enc_conv
    self.enc_tf    = enc_tf

    self.cont_encoder = nn.ModuleDict()
    for name in ['scale', 'continuous']:
      self.cont_encoder[name] = nn.ModuleDict()

    for head_name, head_dim in SCALE_DIMS[self.norm_mode].items():
      self.cont_encoder['scale'][head_name] = nn.Linear(head_dim, self.out_dim)

    for head_name, head_dim in REG_DIMS.items():
      self.cont_encoder['continuous'][head_name] = nn.Linear(head_dim, self.out_dim)

    self.pos_emb_cont = None
    self.pos_emb_scale = None
    if self.use_pos_embs:
      self.pos_emb_cont = nn.Parameter(
        torch.zeros(1, max_series_blocks, max_cont_blocks, self.out_dim))

      self.pos_emb_scale = nn.Parameter(
        torch.zeros(1, max_series_blocks, self.out_dim))

    if self.chart_type_conditional:
      self.ct_emb = nn.Embedding(len(UNIQ_CHART_HEADS), self.out_dim)
    
    self.apply(self._init_weights)
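
  # The constructor builds one linear projection per head listed in
  # SCALE_DIMS[norm_mode] (scale statistics) and REG_DIMS (continuous values),
  # optional learned position embeddings of shape
  # (1, max_series_blocks, max_cont_blocks, out_dim) for points and
  # (1, max_series_blocks, out_dim) for series, and an optional chart-type
  # embedding table used when chart_type_conditional=True.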


  def forward(self, inputs):

    # Shape the raw chart data into a single embedding sequence
    encoder_input, _, loss = self.prepare_encoder_input(inputs)

    # Pass the sequence through the transformer encoder, if one is configured
    encoder_output = self.enc_tf(encoder_input) if self.enc_tf is not None else encoder_input

    # Downsample the encoded sequence with the convolutional encoder
    encoder_output = self.enc_conv(encoder_output)

    return encoder_output, loss
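
  # Rough data flow for forward():
  #   inputs dict -> prepare_encoder_input -> (bsz, max_cont_blocks, out_dim)
  #   -> enc_tf (optional; assumed shape-preserving) -> enc_conv (downsamples the
  #   sequence length). The returned loss is the dummy term that keeps unused
  #   projection heads in the autograd graph.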

  def get_chart_type_emb(self, ct_idx=None, chart_type=None, enc_ver=False):
    if ct_idx is None and chart_type is None:
      raise ValueError("Need to provide either chart type indices or chart type strings")

    if chart_type is not None:
      ct_idx = [CHART_TO_HEAD_IDX[ct] for ct in chart_type]
      ct_idx = torch.tensor(ct_idx, dtype=torch.long, device=self.device).view(-1, 1)

    return self.ct_emb(ct_idx)
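
  # Illustrative note (the chart-type strings below are assumptions; real keys
  # come from CHART_TO_HEAD_IDX): get_chart_type_emb(chart_type=['line', 'vbar'])
  # would return a (2, 1, out_dim) tensor, since the indices are reshaped with
  # .view(-1, 1). A directly supplied ct_idx is expected to already be a
  # LongTensor on self.device.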
  

  def prepare_encoder_input(self, inputs):

    x = inputs['chart_data'] if 'chart_data' in inputs else inputs

    chart_type = x['chart_type']
    chart_type_dict = x.get('chart_type_dict')

    if chart_type_dict is None:
      chart_type_dict = get_chart_type_dict(chart_type)

    scale_embd, scale_mask, s_loss = self.preencode_scale(
      inputs_embeds=x['scale']['inputs_embeds'],
      attention_mask=x['scale']['attention_mask'],
      chart_type_dict=chart_type_dict
      )

    cont_embd, cont_mask, c_loss  = self.preencode_continuous(
      inputs_embeds=x['continuous']['inputs_embeds'],
      attention_mask=x['continuous']['attention_mask'],
      chart_type_dict=chart_type_dict,
      flatten_series=True
    )

    loss = s_loss + c_loss

    encoder_input = [scale_embd, cont_embd]
    encoder_mask  = [scale_mask, cont_mask]

    # Prepend the chart type embedding before encoding
    if self.chart_type_conditional:
      ct_vec = self.get_chart_type_emb(chart_type=chart_type, enc_ver=True)
      ct_mask = torch.ones_like(ct_vec).mean(-1)
      encoder_input = [ct_vec] + encoder_input
      encoder_mask = [ct_mask] + encoder_mask

    encoder_input = torch.cat(encoder_input, dim=1)
    encoder_mask  = torch.cat(encoder_mask, dim=1)

    # Pad or crop so the sequence fits the maximum block length
    bsz, seq_len, dim = encoder_input.shape
    if seq_len >= self.max_cont_blocks:
      encoder_input = encoder_input[:, :self.max_cont_blocks, :]
      encoder_mask  = encoder_mask[:, :self.max_cont_blocks]
    else:
      pad_len = self.max_cont_blocks - seq_len
      
      inp_pad  = torch.zeros((bsz, pad_len, dim), device=self.device, dtype=self.dtype_float)
      mask_pad = torch.zeros((bsz, pad_len), device=self.device, dtype=self.dtype_float)

      encoder_input = torch.cat([encoder_input, inp_pad], dim=1)
      encoder_mask  = torch.cat([encoder_mask, mask_pad], dim=1)

    return encoder_input, encoder_mask, loss
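
  # Expected structure of `inputs` (as read above): either the chart-data dict itself
  # or a dict with a 'chart_data' key, containing
  #   'chart_type'      : chart-type strings, one per batch element
  #   'chart_type_dict' : optional mapping of head name -> batch indices
  #   'scale'           : {'inputs_embeds': ..., 'attention_mask': ...}
  #   'continuous'      : {'inputs_embeds': ..., 'attention_mask': ...}
  # The returned mask is currently discarded by forward().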

  def preencode_continuous(self, inputs_embeds, attention_mask, chart_type_dict, pad_len=None, pad_dim=1, flatten_series=True):
    if isinstance(attention_mask, list):
      attention_mask = torch.stack(attention_mask, dim=0)
    
    assert isinstance(attention_mask, torch.Tensor)
    cont_mask = attention_mask.to(self.device)
    bsz, row_len, col_len = cont_mask.shape

    # Allocate the output tensor for the continuous embeddings
    cont_embd = torch.zeros((bsz, row_len, col_len, self.out_dim), dtype=self.dtype_float, device=self.device)

    # chart_type_dict maps each head name to the batch indices that use that chart type
    for head_name, ind in chart_type_dict.items():

      # Gather the batch rows that belong to this chart type
      cont_x = torch.stack([l for idx, l in enumerate(inputs_embeds) if idx in ind], dim=0).to(self.device)
      mask = cont_mask[ind, :]

      # Encode each series (row) and zero out masked positions
      for s_idx in range(row_len):
        data_x    = cont_x[:, s_idx, :]
        data_mask = mask[:, s_idx, :].unsqueeze(-1)
        cont_embd[ind, s_idx, :] = self.cont_encoder['continuous'][head_name](data_x) * data_mask
    
    # Run every unused head on a zero input so all parameters receive a gradient
    # (keeps DistributedDataParallel / torchrun from flagging unused parameters)
    total_loss = torch.tensor(0.0, device=self.device, dtype=self.dtype_float)
    for head_name, head in self.cont_encoder['continuous'].items():
      if head_name not in chart_type_dict:
        blank = torch.zeros([1, 1, REG_DIMS[head_name]], device=self.device)
        total_loss += (head(blank) * 0).sum()

    # Randomly subsample series / points that exceed the block limits,
    # then add the learned position embeddings
    series_count, pt_count = cont_embd.size(1), cont_embd.size(2)

    if series_count > self.max_series_blocks:
      randints = torch.tensor(
        sorted(random.sample(range(series_count), self.max_series_blocks)), dtype=torch.long, device=self.device)
      cont_embd = torch.index_select(cont_embd, dim=1, index=randints)
      cont_mask = torch.index_select(cont_mask, dim=1, index=randints)

    if pt_count > self.max_cont_blocks:
      randints = torch.tensor(
        sorted(random.sample(range(pt_count), self.max_cont_blocks)), dtype=torch.long, device=self.device)
      cont_embd = torch.index_select(cont_embd, dim=2, index=randints)
      cont_mask = torch.index_select(cont_mask, dim=2, index=randints)

    # Re-read the (possibly reduced) sizes so the position embedding slice matches
    series_count, pt_count = cont_embd.size(1), cont_embd.size(2)
    if self.pos_emb_cont is not None:
      cont_embd = cont_embd + self.pos_emb_cont[:, :series_count, :pt_count, :]

    # Optionally flatten the (series, point) axes into a single sequence axis
    if flatten_series:
      cont_embd = torch.flatten(cont_embd, start_dim=1, end_dim=2)
      cont_mask = torch.flatten(cont_mask, start_dim=1, end_dim=2)

    # Optionally right-pad to a fixed length
    if pad_len is not None and isinstance(pad_len, int):
      cont_embd = pad_vector(cont_embd, pad_len, pad_dim, self.device, self.dtype_float)
      cont_mask = pad_vector(cont_mask, pad_len, pad_dim, self.device, self.dtype_float)

    return cont_embd, cont_mask, total_loss
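
  # Shape summary for preencode_continuous (derived from the code above):
  #   inputs_embeds : list (len bsz) of (num_series, num_points, REG_DIMS[head]) tensors
  #   attention_mask: (bsz, num_series, num_points)
  #   returns       : (bsz, num_series * num_points, out_dim) embeddings plus the matching
  #                   flattened mask when flatten_series=True, with the series / point
  #                   counts capped at max_series_blocks / max_cont_blocks by random
  #                   subsampling.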


  def preencode_scale(self, inputs_embeds, chart_type_dict, attention_mask=None, pad_len=None, pad_dim=1):

    bsz = len(inputs_embeds)

    # Create an all-ones attention mask if one was not provided
    if attention_mask is not None and attention_mask[0] is not None:
      scale_mask = torch.stack(attention_mask, dim=0).to(self.device)
    else:
      # Each element of inputs_embeds is assumed to be (series_len, head_dim)
      scale_mask = torch.ones((bsz, inputs_embeds[0].size(0)), dtype=self.dtype_float, device=self.device)

    emb_len = min(scale_mask.size(1), self.max_series_blocks)
    scale_embd = torch.zeros((bsz, emb_len, self.out_dim), dtype=self.dtype_float, device=self.device)

    # chart_type_dict maps each head name to the batch indices that use that chart type
    for head_name, ind in chart_type_dict.items():
      scale_x = torch.stack([v for idx, v in enumerate(inputs_embeds) if idx in ind], dim=0).to(self.device)

      scale_enc = self.cont_encoder['scale'][head_name](scale_x)

      # Crop to the allocated block length before adding position embeddings
      scale_enc = scale_enc[:, :emb_len, :]
      if self.pos_emb_scale is not None:
        scale_enc = scale_enc + self.pos_emb_scale[:, :scale_enc.size(1), :]

      # Zero out masked positions
      scale_m = scale_mask[ind, :emb_len].unsqueeze(-1)
      scale_enc = scale_enc * scale_m
      scale_embd[ind, :, :] = scale_enc
    
    # Run every unused head on a zero input so all parameters receive a gradient (DDP-safe)
    total_loss = torch.tensor(0.0, device=self.device, dtype=self.dtype_float)
    for head_name, head in self.cont_encoder['scale'].items():
      if head_name not in chart_type_dict:
        blank = torch.zeros([1, 1, SCALE_DIMS[self.norm_mode][head_name]], device=self.device)
        total_loss += (head(blank) * 0).sum()

    # Optionally right-pad to a fixed length
    if pad_len is not None and isinstance(pad_len, int):
      scale_embd = pad_vector(scale_embd, pad_len, pad_dim, self.device, self.dtype_float)
      scale_mask = pad_vector(scale_mask, pad_len, pad_dim, self.device, self.dtype_float)

    return scale_embd, scale_mask, total_loss
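

# ---------------------------------------------------------------
# Minimal construction sketch (illustrative, not part of the library API).
# Assumptions: the Coder base class provides `_init_weights`, the constant
# tables (SCALE_DIMS, REG_DIMS, UNIQ_CHART_HEADS) are importable, and
# nn.Identity stands in for the real convolutional / transformer encoders.
# Hyper-parameter values are arbitrary examples. Run as a module
# (e.g. `python -m honeyplotnet.models.continuous.encoder`) so the relative
# imports resolve.
# ---------------------------------------------------------------
if __name__ == "__main__":
  encoder = Encoder(
    enc_conv=nn.Identity(),   # placeholder for the downsampling conv encoder
    enc_tf=None,              # forward() falls back to the identity when enc_tf is None
    out_dim=256,
    max_series_blocks=8,
    max_cont_blocks=64,
    chart_type_conditional=False,
    device='cpu',
  )
  n_params = sum(p.numel() for p in encoder.parameters())
  print(f"Encoder built with {n_params} parameters")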