# ---------------------------------------------------------------
# Copyright (c) ________________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

import copy

from transformers import (
    T5Config, BigBirdPegasusConfig, PegasusConfig,
    T5ForConditionalGeneration, BigBirdPegasusForConditionalGeneration, PegasusForConditionalGeneration,
    T5PreTrainedModel, BigBirdPegasusPreTrainedModel, PegasusPreTrainedModel,
    AutoTokenizer,
)

from transformers.models.t5.modeling_t5 import T5Stack
from transformers.models.pegasus.modeling_pegasus import PegasusDecoder
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusDecoder

from models.constant import UNIQ_CHART_HEADS

SUPPORTED_MODELS = [
    'google/t5-v1_1-large',
    'google/pegasus-pubmed',
    'google/bigbird-pegasus-large-pubmed'
]

def init_seq_model(cfg, device_id, load_opt=True):
    """Build the double-decoder seq model, its tokenizer, and (optionally) the optimizer/scheduler."""

    sep_token = cfg.data.dataset.chart_data.sep_token
    # 2 reserved tokens (pad, eos) + chart-head tokens + the larger of the two VQ codebooks
    codebook_size = 2 + len(UNIQ_CHART_HEADS) + \
        max(cfg.model.continuous_data.vq.n_emb1, cfg.model.continuous_data.vq.n_emb2)

    # one header token plus the two VQ code segments
    code_seq_len = 1 + cfg.model.continuous_data.vq.emb_len1 + cfg.model.continuous_data.vq.emb_len2

    decoder2_num_layers = cfg.model.seq.decoder2_num_layers

    # token id 0 is reserved for tokenizer.pad_token_id, 1 for tokenizer.eos_token_id
    model_cfg = cfg.model.seq.hf_model
    assert model_cfg.name in SUPPORTED_MODELS, "Unsupported model: {}".format(model_cfg.name)

    cfg_kwargs = {"cache_dir": cfg.cache_dir, 
                  "codebook_size": codebook_size, 
                  "code_seq_len": code_seq_len,
                  "decoder2_num_layers": decoder2_num_layers}
    
    tok_kwargs = {"cache_dir": cfg.cache_dir}

    tokenizer = AutoTokenizer.from_pretrained(model_cfg.name, **tok_kwargs)
    num_added_toks = tokenizer.add_tokens([sep_token], special_tokens=True)

    assert tokenizer.pad_token_id == 0, tokenizer.pad_token_id
    assert tokenizer.eos_token_id == 1, tokenizer.eos_token_id

    if model_cfg.name == 'google/t5-v1_1-large':
        config_class = DoubleDecoderT5Config
        base_model_class = T5ForConditionalGeneration
        model_class = DoubleDecoderT5
    elif model_cfg.name == 'google/pegasus-pubmed':
        config_class = DoubleDecoderPegasusConfig
        base_model_class = PegasusForConditionalGeneration
        model_class = DoubleDecoderPegasus
    elif model_cfg.name == 'google/bigbird-pegasus-large-pubmed':
        config_class = DoubleDecoderBigBirdConfig
        base_model_class = BigBirdPegasusForConditionalGeneration
        model_class = DoubleDecoderBigBird

    hf_config = config_class.from_pretrained(model_cfg.name, **cfg_kwargs)
    model_kwargs = {"cache_dir": cfg.cache_dir, "config": hf_config, "from_tf": False}
    base_model = base_model_class.from_pretrained(model_cfg.name, **model_kwargs)
    base_model.resize_token_embeddings(len(tokenizer))
    assert base_model.config.decoder_start_token_id is not None, "Make sure that `config.decoder_start_token_id` is correctly defined"

    model = model_class(base_model, hf_config)
    base_num_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad) / 1e6
    total_num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

    if cfg.rank in [0, 'cpu']:
        print("Number of trainable parameters (M) | {} | Base: {:.4f} Total: {:.4f}".format(
            model_cfg.name, base_num_params, total_num_params))

    opt = None
    scheduler = None
    opt_mode = int(cfg.model.seq.opt_mode)
    if load_opt:
        if device_id != 'cpu':
            model.cuda(device_id)

        # 0: train all parameters
        # 1: text only (freeze decoder2 and the data head)
        # 2: data only (freeze the pre-trained base model)
        params = []
        assert opt_mode in [0, 1, 2]
        if opt_mode == 0:
            params = list(filter(lambda p: p.requires_grad, model.parameters()))
        else:
            for name, param in model.named_parameters():
                if (opt_mode == 1 and 'model.' in name) or \
                    (opt_mode == 2 and 'model.' not in name):
                    param.requires_grad = True
                    params.append(param)
                else:
                    param.requires_grad = False

        lr = cfg.train.optim.learning_rate
        betas = cfg.train.optim.betas
        if cfg.train.optim.type == 'AdamW':
            opt = torch.optim.AdamW(params, lr=lr, betas=betas)
        elif cfg.train.optim.type == 'Adam':
            opt = torch.optim.Adam(params, lr=lr, betas=betas)
        else:
            raise NotImplementedError()

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, verbose=False)

    return model, tokenizer, opt, scheduler
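
# Example usage (illustrative sketch, not executed here): assumes an
# OmegaConf/Hydra-style `cfg` exposing the fields read above and a single GPU.
#
#   model, tokenizer, opt, scheduler = init_seq_model(cfg, device_id=0)
#   model.set_output('text')
#   text_ids = model.generate(input_ids=batch['input_ids'], max_length=256)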

class DoubleDecoderT5Config(T5Config):
    def __init__(
        self,
        codebook_size: int = 128,
        code_seq_len: int = 29,
        decoder2_num_layers: int = 4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size 
        self.code_seq_len = code_seq_len
        self.decoder2_num_layers = decoder2_num_layers

class DoubleDecoderPegasusConfig(PegasusConfig):
    def __init__(
        self,
        codebook_size: int = 128,
        code_seq_len: int = 29,
        decoder2_num_layers: int = 4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size 
        self.code_seq_len = code_seq_len
        self.decoder2_num_layers = decoder2_num_layers

class DoubleDecoderBigBirdConfig(BigBirdPegasusConfig):
    def __init__(
        self,
        codebook_size: int = 128,
        code_seq_len: int = 29,
        decoder2_num_layers: int = 4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size 
        self.code_seq_len = code_seq_len
        self.decoder2_num_layers = decoder2_num_layers

class DoubleDecoderBase:
    """Mixin that adds a second decoder (plus a code head) to a pre-trained seq2seq model."""
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"decoder2\.weight", r"data_head\.weight"]

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()
    
    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(self, **kwargs):
        return self.model.prepare_inputs_for_generation(**kwargs)
    
    def prepare_decoder_input_ids_from_labels(self, labels):
        return self.model.prepare_decoder_input_ids_from_labels(labels)

    def set_output(self, mode):
        assert mode in ['both', 'text', 'data']
        self.output_mode = mode
    
    def generate_codes(self, encoder_hidden_states, greedy=False):
        """Autoregressively decode code tokens with the second decoder."""
        bsz = encoder_hidden_states.shape[0]
        # seed decoding with token 0 (the reserved pad/start id of the code vocabulary)
        input_ids = torch.zeros([bsz, 1], dtype=torch.long, device=self.device)

        for _ in range(self.config.code_seq_len):
            outputs = self.decoder2(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden_states,
            )
            code_logits = self.data_head(outputs[0])

            if greedy:
                tokens = code_logits.argmax(-1)
            else:
                tokens = Categorical(logits=code_logits).sample()

            # append the newly sampled token and keep decoding
            new_input_ids = tokens[:, -1:]
            input_ids = torch.cat([input_ids, new_input_ids], dim=-1)

        # drop the start token
        code_tokens = input_ids[:, 1:]
        return code_tokens
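
    # Example (sketch, hypothetical variable names): sampling codes directly
    # from cached encoder states, bypassing `generate`:
    #
    #   enc_out = model.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)
    #   codes = model.generate_codes(enc_out.last_hidden_state, greedy=True)
    #   # codes: LongTensor of shape [batch, config.code_seq_len]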
    
    def generate(self, **kwargs):
        if self.output_mode == 'text':
            input_ids = kwargs.pop('input_ids')
            text_tokens = self.model.generate(input_ids, **kwargs)
            return text_tokens

        elif self.output_mode == 'data':
            kwargs['output_hidden_states'] = True
            outputs = self.model.get_encoder()(**kwargs)
            code_tokens = self.generate_codes(outputs.last_hidden_state)
            return code_tokens
        else:
            raise ValueError("generate() supports only the 'text' or 'data' output modes")
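
    # Example (sketch): one wrapper, two generation targets, selected by the
    # active output mode:
    #
    #   model.set_output('text')
    #   text_ids = model.generate(input_ids=input_ids, num_beams=4)
    #   model.set_output('data')
    #   code_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask)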
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder2_input_ids=None,
        decoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        return_dict=None,
        **kwargs
    ):  
        # call the parent class's forward method
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # Pass through second decoder
        dec2_out = self.decoder2(
            input_ids=decoder2_input_ids,
            attention_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=outputs.encoder_last_hidden_state,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        code_logits = self.data_head(dec2_out[0])

        if self.output_mode == 'text':
            return outputs
        elif self.output_mode == 'data':
            outputs.logits = code_logits
            return outputs
        elif self.output_mode == 'both':
            return (outputs.logits, code_logits)
        raise ValueError("Unknown output mode: {}".format(self.output_mode))

class DoubleDecoderT5(DoubleDecoderBase, T5PreTrainedModel):
    def __init__(self, model, config):
        super(T5PreTrainedModel, self).__init__(config)
        self.model = model
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.decoder2_num_layers
        data_shared = nn.Embedding(config.codebook_size, config.d_model)
        self.decoder2 = T5Stack(decoder_config, data_shared)
        self.data_head = nn.Linear(config.d_model, config.codebook_size)
        self.output_mode = 'both'

class DoubleDecoderPegasus(DoubleDecoderBase, PegasusPreTrainedModel):
    def __init__(self, model, config):
        super(PegasusPreTrainedModel, self).__init__(config)
        self.model = model
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        # Pegasus reads the decoder depth from `decoder_layers`, not `num_layers`
        decoder_config.decoder_layers = config.decoder2_num_layers
        data_shared = nn.Embedding(config.codebook_size, config.d_model)
        self.decoder2 = PegasusDecoder(decoder_config, data_shared)
        self.data_head = nn.Linear(config.d_model, config.codebook_size)
        self.output_mode = 'both'

class DoubleDecoderBigBird(DoubleDecoderBase, BigBirdPegasusPreTrainedModel):
    def __init__(self, model, config):
        super(BigBirdPegasusPreTrainedModel, self).__init__(config)
        self.model = model
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        # BigBirdPegasus reads the decoder depth from `decoder_layers`, not `num_layers`
        decoder_config.decoder_layers = config.decoder2_num_layers
        data_shared = nn.Embedding(config.codebook_size, config.d_model)
        self.decoder2 = BigBirdPegasusDecoder(decoder_config, data_shared) 
        self.data_head = nn.Linear(config.d_model, config.codebook_size)
        self.output_mode = 'both'
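
# Example (sketch): wiring a pre-trained T5 checkpoint into the double-decoder
# wrapper by hand; this mirrors what init_seq_model does internally.
#
#   hf_config = DoubleDecoderT5Config.from_pretrained(
#       'google/t5-v1_1-large', codebook_size=128, code_seq_len=29)
#   base = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-large', config=hf_config)
#   model = DoubleDecoderT5(base, hf_config)
#   model.set_output('data')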