# honeyplotnet/models/build.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .conv import Conv1dEncoder, Conv1dDecoder 
from .continuous import ContinuousModel, Discriminator
from .gpt import GPTNoEmbed
from .mh_dropout import MHDropoutNetRandom1D
from models.vq import VectorQuantizer

from .seq_model import init_seq_model

from transformers import (
  AutoConfig, AutoTokenizer, 
  AutoModelForSeq2SeqLM, T5Config,
)

from transformers.models.t5.modeling_t5 import T5Stack

def init_model(cfg, mode, stage, device_id):
  '''Initialize all models according to config file.

  Depending on ``cfg`` this builds:
    * ``'ksm'`` (model + tokenizer) when ``cfg.eval.ksm.active`` is set;
    * ``'continuous'`` (model + optimizer), always;
    * ``'disc'`` (model + optimizer) when ``cfg.model.continuous_data.disc.use``;
    * ``'seq'`` (model, tokenizer, optimizer, scheduler) when ``'seq'`` is in
      ``cfg.model.active`` or ``stage == 'seq'``.

  Afterwards every model is wrapped with ``to_distributed`` when
  ``cfg.torch_dist.use`` is set. Otherwise each non-None model is moved to
  ``cuda:<device_id>`` when neither ``cfg.device_id`` nor ``device_id``
  selects the CPU.

  Args:
    cfg: experiment config object.
    mode: not read by this function (kept for the caller-facing signature).
    stage: current stage name; ``'seq'`` also passes ``load_opt=True`` to
      ``init_seq_model``.
    device_id: device index used for DDP wrapping and GPU placement.

  Returns:
    ``(models, toks, opts, schs)``: four dicts keyed by the model names above.
  '''
  distributed = cfg.torch_dist.use
  requested = cfg.model.active
  seq_stage = stage == 'seq'

  models, toks, opts, schs = {}, {}, {}, {}

  if cfg.eval.ksm.active:
    models['ksm'], toks['ksm'] = init_ksm_model(cfg)

  # The continuous plot-data model is built unconditionally.
  models['continuous'], opts['continuous'] = init_plot_data_model(cfg)

  if cfg.model.continuous_data.disc.use:
    models['disc'], opts['disc'] = init_disc_model(cfg, device_id)

  if 'seq' in requested or seq_stage:
    # NOTE(review): the seq model receives cfg.device_id while the other
    # builders and the placement code below use the device_id argument --
    # confirm the two always agree.
    (models['seq'], toks['seq'],
     opts['seq'], schs['seq']) = init_seq_model(cfg, cfg.device_id, load_opt=seq_stage)

  if distributed:
    # Prepare models for distributed training.
    for name in models:
      models[name] = to_distributed(cfg=cfg, model=models[name], device_id=device_id)
      if cfg.rank == 0:
        print(f"Creating distributed model: {name}")
  elif cfg.device_id not in (None, 'cpu') and device_id != 'cpu':
    device = f'cuda:{device_id}'
    for model in models.values():
      if model is not None:
        model.to(device)

  return models, toks, opts, schs

def to_distributed(cfg, model, device_id):
  '''Place ``model`` on ``cuda:<device_id>`` and wrap it in DDP.

  ``find_unused_parameters`` mirrors ``cfg.debug``. A ``None`` model is passed
  through as ``None`` without touching ``cfg``.
  '''
  if model is not None:
    device = f'cuda:{device_id}'
    model.to(device)
    wrapped = DDP(model, device_ids=[device_id], find_unused_parameters=cfg.debug)
    return wrapped
  return None

def init_plot_data_model(cfg):
  '''Build the continuous plot-data model (``ContinuousModel``) and its optimizer.

  Components, all sized from ``cfg.model.continuous_data``:
    * encoder: ``Conv1dEncoder`` followed by a 1x1 ``nn.Conv1d`` projection to
      ``vq.emb_len1`` channels, plus a transformer when
      ``encoder.transformer.use`` is set;
    * decoder: ``Conv1dDecoder`` with separate column/row transformers and 1x1
      ``nn.Conv1d`` projections to ``max_blocks.points`` /
      ``max_blocks.series`` channels;
    * a first ``VectorQuantizer`` (``vq.n_emb1`` / ``vq.emb_dim1``);
    * when ``mhd.use`` is set: a second ``VectorQuantizer``, two extra encoder
      projections and an ``MHDropoutNetRandom1D`` block.

  The config is not modified: the conv ``channels`` lists are extended on
  copies, so calling this repeatedly with the same ``cfg`` builds the same
  architecture each time.

  Args:
    cfg: experiment config object.

  Returns:
    ``(model, opt)`` where ``opt`` is a ``torch.optim.AdamW`` / ``Adam`` over
    the parameters with ``requires_grad`` set, or ``None`` when
    ``cfg.train.optim.type`` names neither.
  '''
  use_fp16      = cfg.fp16.use
  cd_cfg        = cfg.model.continuous_data
  encoder_cfg   = cd_cfg.encoder
  decoder_cfg   = cd_cfg.decoder
  vq_cfg        = cd_cfg.vq
  mhd_cfg       = cd_cfg.mhd
  chart_cfg     = cfg.data.dataset.chart_data

  max_blocks    = encoder_cfg.max_blocks

  #### Encoder
  # Build the conv kwargs on a shallow copy so the config's own 'channels'
  # list is left untouched. Prepending in place would add one more input entry
  # every time a model is built from the same cfg.
  enc_conv_kwargs = dict(encoder_cfg.conv)
  enc_conv_kwargs['channels'] = [max_blocks.points] + list(enc_conv_kwargs['channels'])
  last_chn_enc = enc_conv_kwargs['channels'][-1]

  enc_conv = Conv1dEncoder(**enc_conv_kwargs)
  enc_proj1 = nn.Conv1d(last_chn_enc, vq_cfg.emb_len1, 1)

  enc_tf = None
  if encoder_cfg.transformer.use:
    enc_tf = init_transformer(encoder_cfg.transformer,
      block_size=max_blocks.points + 1, emb_dim=vq_cfg.emb_dim1, use_pos_embs=False)

  #### Decoder
  # y_hat > dec_conv > dec_tf > proj
  # int(chart_type_conditional) extra input channels (one when the flag is True).
  dec_inp_channels = int(decoder_cfg.chart_type_conditional) + vq_cfg.emb_len1
  dec_conv_kwargs = dict(decoder_cfg.conv)  # copy for the same reason as the encoder
  dec_conv_kwargs['channels'] = [dec_inp_channels] + list(dec_conv_kwargs['channels'])
  last_chn_dec = dec_conv_kwargs['channels'][-1]

  dec_conv = Conv1dDecoder(**dec_conv_kwargs)
  dec_proj_col = nn.Conv1d(last_chn_dec, max_blocks.points, 1)
  dec_proj_row = nn.Conv1d(last_chn_dec, max_blocks.series, 1)

  dec_tf_col = init_transformer(decoder_cfg.transformer,
    block_size=max_blocks.points + 1, emb_dim=vq_cfg.emb_dim1)
  dec_tf_row = init_transformer(decoder_cfg.transformer,
    block_size=max_blocks.series + 1, emb_dim=vq_cfg.emb_dim1)

  #### Vector quantizers
  # Settings shared by both quantizers; only n_emb / emb_dim differ.
  vq_common = {
    'beta': vq_cfg.beta,
    'tiled': vq_cfg.tiled,
    'ema_update': vq_cfg.ema_update,
    'random_restart': vq_cfg.random_restart,
  }
  vq_layer1 = VectorQuantizer(n_emb=vq_cfg.n_emb1, emb_dim=vq_cfg.emb_dim1, **vq_common)

  ################################
  # 3. MH Dropout Block
  ################################
  # These stay None unless the MH dropout path is enabled.
  enc_proj2 = None
  enc_proj3 = None
  vq_layer2 = None
  mhd_layer = None
  if mhd_cfg.use:
    vq_layer2 = VectorQuantizer(n_emb=vq_cfg.n_emb2, emb_dim=vq_cfg.emb_dim2, **vq_common)
    enc_proj2 = nn.Conv1d(last_chn_enc, vq_cfg.emb_len2, 1)
    enc_proj3 = nn.Linear(vq_cfg.emb_dim1, vq_cfg.emb_dim2)

    mhd_inp_dim = vq_cfg.emb_dim2
    hidden_dim = mhd_cfg.bottleneck_dim if mhd_cfg.bottleneck else mhd_inp_dim

    mhd_layer = MHDropoutNetRandom1D(
      inp_dim=mhd_inp_dim,
      hidden_dim=hidden_dim,
      out_dim=vq_cfg.emb_dim1,
      dist_reduce=mhd_cfg.dist_reduce,
      # NOTE(review): loss_reduce is fed from dist_reduce. This looks like a
      # copy-paste slip, but it is kept because the config may not define a
      # separate loss_reduce setting -- confirm before changing.
      loss_reduce=mhd_cfg.dist_reduce,
      loss_reduce_dims=mhd_cfg.loss_reduce_dims,
      norm=mhd_cfg.norm,
      dist_loss=mhd_cfg.dist_loss,
      gamma=mhd_cfg.gamma,
      dropout_rate=mhd_cfg.dropout_rate,
      decoder_cfg=mhd_cfg.decoder,
      bottleneck=mhd_cfg.bottleneck,
    )

  data_model_kwargs = {
    'enc_conv': enc_conv,
    'enc_proj1': enc_proj1,
    'enc_proj2': enc_proj2,
    'enc_proj3': enc_proj3,
    'enc_tf': enc_tf,
    'dec_conv': dec_conv,
    'dec_tf_col': dec_tf_col,
    'dec_tf_row': dec_tf_row,
    'dec_proj_col': dec_proj_col,
    'dec_proj_row': dec_proj_row,
    'vq_layer1': vq_layer1,
    'vq_layer2': vq_layer2,
    'mhd_layer': mhd_layer,
    'use_mhd': mhd_cfg.use,
    'hypothese_bsz': mhd_cfg.hypothese_bsz,
    'hypothese_count': mhd_cfg.hypothese_count,
    'emb_dim1': vq_cfg.emb_dim1,
    'emb_len1': vq_cfg.emb_len1,
    'emb_len2': vq_cfg.emb_len2,
    'conditional_encoder': encoder_cfg.chart_type_conditional,
    'conditional_decoder': decoder_cfg.chart_type_conditional,
    'use_pos_embs': cd_cfg.use_pos_embs,
    'max_series_blocks': max_blocks.series,
    'max_cont_blocks': max_blocks.points,
    'scale_mode': chart_cfg.scale_mode,
    'scale_eps': chart_cfg.scale_eps,
    'scale_floor': chart_cfg.scale_floor,
    'norm_mode': chart_cfg.norm_mode,
    'cont_loss_fn': cd_cfg.loss_fn.continuous,
    'scale_loss_fn': cd_cfg.loss_fn.scale,
    'fp16': use_fp16,
    'debug': cfg.debug,
    'device': f'cuda:{cfg.device_id}' if cfg.device_id != 'cpu' else 'cpu'
  }

  model = ContinuousModel(**data_model_kwargs)

  #### Optimizer (trainable parameters only)
  params = [p for p in model.parameters() if p.requires_grad]
  optim_cfg = cfg.train.optim
  lr = optim_cfg.learning_rate
  betas = optim_cfg.betas
  opt = None
  if optim_cfg.type == 'AdamW':
    opt = torch.optim.AdamW(params, lr=lr, betas=betas)
  elif optim_cfg.type == 'Adam':
    opt = torch.optim.Adam(params, lr=lr, betas=betas)

  return model, opt


def init_disc_model(cfg, device_id):
  '''Build the plot-data ``Discriminator`` and its optimizer.

  The discriminator's encoder is a ``Conv1dEncoder`` whose input channel
  count is ``max_blocks.points``. It is followed by a transformer from
  ``init_transformer`` built with ``use_pos_embs=False`` and block size
  ``max_blocks.points + 1``. The config is not modified: the conv
  ``channels`` list is extended on a copy.

  Args:
    cfg: experiment config object; the ``cfg.model.continuous_data.disc``
      section configures the discriminator.
    device_id: device index or ``'cpu'``; turned into the model's ``device``
      string (``'cuda:<id>'`` or ``'cpu'``).

  Returns:
    ``(model, opt)`` where ``opt`` is a ``torch.optim.AdamW`` / ``Adam`` over
    the parameters with ``requires_grad`` set, or ``None`` when
    ``cfg.train.optim.type`` names neither.
  '''
  use_fp16      = cfg.fp16.use
  cd_cfg        = cfg.model.continuous_data
  max_blocks    = cd_cfg.encoder.max_blocks

  vq_cfg        = cd_cfg.vq
  disc_cfg      = cd_cfg.disc

  #### Encoder
  # Extend 'channels' on a shallow copy. Prepending in place on the config
  # would add one more input entry each time a discriminator is built from
  # the same cfg.
  disc_conv_kwargs = dict(disc_cfg.conv)
  disc_conv_kwargs['channels'] = [max_blocks.points] + list(disc_conv_kwargs['channels'])

  enc_conv = Conv1dEncoder(**disc_conv_kwargs)

  enc_tf = init_transformer(disc_cfg.transformer,
    block_size=max_blocks.points + 1, emb_dim=vq_cfg.emb_dim1,
    use_pos_embs=False)

  kwargs = {
    'enc_conv': enc_conv,
    'enc_tf': enc_tf,
    'emb_dim1': vq_cfg.emb_dim1,
    'max_series_blocks': max_blocks.series,
    'max_cont_blocks': max_blocks.points,
    'norm_mode': cfg.data.dataset.chart_data.norm_mode,
    'disc_start': disc_cfg.disc_start,
    'disc_loss': disc_cfg.disc_loss,
    'disc_factor': disc_cfg.disc_factor,
    'disc_weight': disc_cfg.disc_weight,
    'disc_conditional': disc_cfg.disc_conditional,
    'use_pos_embs': disc_cfg.use_pos_embs,
    'device': f'cuda:{device_id}' if device_id != 'cpu' else 'cpu',
    'debug': cfg.debug,
    'fp16': use_fp16
  }

  model = Discriminator(**kwargs)

  #### Optimizer (trainable parameters only)
  params = [p for p in model.parameters() if p.requires_grad]
  optim_cfg = cfg.train.optim
  lr = optim_cfg.learning_rate
  betas = optim_cfg.betas
  opt = None
  if optim_cfg.type == 'AdamW':
    opt = torch.optim.AdamW(params, lr=lr, betas=betas)
  elif optim_cfg.type == 'Adam':
    opt = torch.optim.Adam(params, lr=lr, betas=betas)

  return model, opt

def init_ksm_model(cfg):
  '''Load the pretrained seq2seq model and tokenizer configured in ``cfg.eval.ksm``.

  The HF config, tokenizer and weights are all fetched by ``cfg.eval.ksm.name``
  using ``cfg.cache_dir`` as the cache. The model's token embeddings are then
  resized to ``len(tokenizer)``.

  Returns:
    ``(model, tokenizer)``.
  '''
  ksm_cfg = cfg.eval.ksm
  backbone = ksm_cfg.name

  hf_config = AutoConfig.from_pretrained(backbone, cache_dir=cfg.cache_dir)
  tokenizer = AutoTokenizer.from_pretrained(
    backbone, use_fast=ksm_cfg.use_fast, cache_dir=cfg.cache_dir)
  model = AutoModelForSeq2SeqLM.from_pretrained(
    backbone, from_tf=False, config=hf_config, cache_dir=cfg.cache_dir)

  # Keep the embedding table sized to the tokenizer's vocabulary.
  vocab_size = len(tokenizer)
  model.resize_token_embeddings(vocab_size)

  if cfg.rank == 0:
    print("KSM CFG | backbone={}".format(backbone))

  return model, tokenizer

def init_transformer(tf_cfg, block_size, emb_dim, use_pos_embs=True):
  '''Build the transformer selected by ``tf_cfg.name``.

  Supported names:
    * ``'gpt'``: ``GPTNoEmbed`` with ``tf_cfg.n_layer`` layers,
      ``tf_cfg.n_head`` heads and width ``emb_dim``; ``block_size`` and
      ``use_pos_embs`` are forwarded to it.
    * ``'t5_decoder'``: a decoder-only ``T5Stack`` (``vocab_size=0``,
      ``d_ff = int(4 * emb_dim)``, ``d_kv = tf_cfg.d_kv``). A value of 0 for
      ``tf_cfg.num_buckets`` / ``tf_cfg.max_distance`` means "derive from
      ``emb_dim``": ``int(emb_dim / 16)`` buckets and a max distance of
      ``int(emb_dim / 4)``. ``block_size`` and ``use_pos_embs`` are not used
      on this path.

  Args:
    tf_cfg: transformer config section (``name``, ``n_layer``, ``n_head`` and,
      for T5, ``d_kv``, ``num_buckets``, ``max_distance``).
    block_size: forwarded to ``GPTNoEmbed``.
    emb_dim: model width.
    use_pos_embs: forwarded to ``GPTNoEmbed``.

  Returns:
    The constructed module.

  Raises:
    NotImplementedError: if ``tf_cfg.name`` is not a supported name.
  '''
  name = tf_cfg.name

  if name == 'gpt':
    return GPTNoEmbed(
      block_size=block_size,
      n_layer=tf_cfg.n_layer,
      n_head=tf_cfg.n_head,
      n_embd=emb_dim,
      use_pos_embs=use_pos_embs
    )

  if name == 't5_decoder':
    # 0 in the config means "derive the value from emb_dim".
    num_buckets = tf_cfg.num_buckets if tf_cfg.num_buckets != 0 else int(emb_dim / 16)
    max_distance = tf_cfg.max_distance if tf_cfg.max_distance != 0 else int(emb_dim / 4)

    decoder_config = T5Config(
      vocab_size=0,
      num_layers=tf_cfg.n_layer,
      num_heads=tf_cfg.n_head,
      d_model=emb_dim,
      d_ff=int(emb_dim * 4),
      d_kv=tf_cfg.d_kv,
      relative_attention_num_buckets=num_buckets,
      relative_attention_max_distance=max_distance,
    )
    # Use the stack as a standalone decoder, not as half of an encoder-decoder.
    decoder_config.is_decoder = True
    decoder_config.is_encoder_decoder = False
    return T5Stack(decoder_config)

  raise NotImplementedError(
    f"Unknown transformer name {name!r}; expected 'gpt' or 't5_decoder'")