# CoC-code / code / src / model.py
# model.py
# Raw
import torch.nn as nn
import torch
import numpy as np
from transformers import AutoTokenizer, T5ForConditionalGeneration, RobertaForMaskedLM, BertForMaskedLM, AutoModel, \
    AutoImageProcessor, AutoConfig, BertModel, AutoModelForSequenceClassification, BertLayer, AutoModelForImageClassification
from torchtext.vocab import GloVe
from transformers import T5Config, T5ForConditionalGeneration, T5EncoderModel

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, BCELoss, BCEWithLogitsLoss
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqLMOutput,
)
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from torch.utils.checkpoint import checkpoint
from torchvision import models


class MultimodalEncoder(nn.Module):
    """Encode token ids with a pretrained T5 encoder.

    Building the encoder also writes the checkpoint's ``hidden_size`` onto the
    shared ``config`` object so that other modules can size their layers
    from it.
    """

    def __init__(self, config):
        super().__init__()
        self.text_encoder = T5EncoderModel.from_pretrained(config.model_path)
        # Publish the encoder width on the shared config object.
        config.hidden_size = AutoConfig.from_pretrained(config.model_path).hidden_size
        self.config = config
        self.prefix_length = config.prefix_length

    def forward(self, input_ids=None, attention_mask=None, image_ids=None):
        """Return the encoder's ``last_hidden_state`` for ``input_ids``.

        ``image_ids`` is accepted for call-site compatibility but is not used
        by this method.
        """
        encoded = self.text_encoder(input_ids, attention_mask)
        return encoded.last_hidden_state


class MMetaphorCOTModel(nn.Module):
    """T5 sequence-to-sequence model conditioned on gated text/image features.

    Token ids are encoded by ``MultimodalEncoder``; ``image_prefix`` is
    projected to ``config.prefix_length`` vectors of ``config.hidden_size``.
    The text features attend over those vectors, a sigmoid gate blends the
    attended result with the text features, and the blend is handed to
    ``T5ForConditionalGeneration`` as ``inputs_embeds``.

    Config attributes read: ``model_path``, ``vision_dim``, ``prefix_length``,
    ``dropout``, ``max_length``, ``label_list`` and ``hidden_size`` (the
    latter is written onto the config by ``MultimodalEncoder``).
    """

    # Marker strings removed from decoded generations, in stripping order.
    _MARKERS = ('<pad>', '</s>', '<extra_id_0>', '<extra_id_1>', '<extra_id_2>')

    def __init__(self, config):
        super(MMetaphorCOTModel, self).__init__()
        self.config = config
        # Must be built first: MultimodalEncoder sets config.hidden_size,
        # which every layer below depends on.
        self.mm_encoder = MultimodalEncoder(config)
        self.llm = T5ForConditionalGeneration.from_pretrained(config.model_path)
        # Projects one image vector into `prefix_length` hidden-size vectors.
        self.dense_layer = nn.Sequential(
            nn.Linear(self.config.vision_dim, self.config.hidden_size * self.config.prefix_length),
            nn.Dropout(self.config.dropout),
            nn.ReLU(inplace=True)
        )

        self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size,
                                                     vdim=config.hidden_size, num_heads=4, batch_first=True)

        self.gate_dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.sigmoid = nn.Sigmoid()
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)

    def _fuse(self, input_ids, input_masks, image_prefix):
        """Return the gated blend of text features and image-attended features."""
        text_features = self.mm_encoder(input_ids, input_masks, image_prefix)
        vision_features = self.dense_layer(image_prefix)
        vision_features = vision_features.view(-1, self.config.prefix_length, self.config.hidden_size)
        # Text positions are the queries; the projected image prefix supplies
        # keys and values, so the result has the text sequence length.
        vision_features, _ = self.mha_layer(text_features, vision_features, vision_features)
        gate = self.sigmoid(self.gate_dense(torch.cat([text_features, vision_features], dim=-1)))
        return (1 - gate) * text_features + gate * vision_features

    def _decode(self, sequences):
        """Decode generated id sequences and strip the marker strings."""
        cleaned = []
        for ids in sequences:
            text = self.tokenizer.decode(ids)
            for marker in self._MARKERS:
                text = text.replace(marker, '')
            cleaned.append(text.strip())
        return cleaned

    def forward(
            self,
            **kwargs
    ):
        """Compute the training loss for one batch.

        Keyword Args:
            input_ids: token ids passed to the text encoder.
            input_masks: attention mask passed to the text encoder.
            image_prefix: image features whose last dimension is
                ``config.vision_dim``.
            output_ids: target token ids. Pad positions are ignored by the
                loss; the tensor itself is left untouched.
            output_masks: decoder attention mask.

        Returns:
            The first element of the T5 output (the loss, since labels are
            supplied).

        Raises:
            KeyError: if any of the keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        image_prefix = kwargs['image_prefix']
        output_ids = kwargs['output_ids']
        output_masks = kwargs['output_masks']
        # Build the labels on a copy: overwriting pad ids in the caller's
        # tensor would corrupt the batch for any later use (e.g. decoding).
        labels = output_ids.masked_fill(output_ids == self.tokenizer.pad_token_id, -100)
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        # NOTE(review): input_masks is not forwarded to self.llm as
        # attention_mask, so padded input positions are not masked in this
        # second pass — confirm that this is intended.
        output = self.llm(inputs_embeds=embedded, decoder_input_ids=None,
                          decoder_attention_mask=output_masks, labels=labels)
        return output[0]

    def forward_entity(self, embedded, output_ids, output_masks, *args):
        """Compute the loss from already-fused embeddings.

        ``output_ids`` is passed to the T5 model as labels unchanged; no pad
        masking is applied here. Extra positional arguments are ignored.
        """
        output = self.llm(inputs_embeds=embedded, decoder_input_ids=None,
                          decoder_attention_mask=output_masks, labels=output_ids)
        return output[0]

    def generate(self, **kwargs):
        """Generate text for a batch and return the cleaned decoded strings.

        Keyword Args:
            input_ids, input_masks, image_prefix: same as in ``forward``.

        Returns:
            A list of strings, one per generated sequence, with the marker
            strings removed and surrounding whitespace stripped.

        Raises:
            KeyError: if any of the keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        image_prefix = kwargs['image_prefix']
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        output = self.llm.generate(inputs_embeds=embedded, max_length=self.config.max_length)
        return self._decode(output)

    def evaluate(self, step_one=False, **kwargs):
        """Generate text for a batch and map each result to a label index.

        Args:
            step_one: accepted for call-site compatibility; not used here.

        Keyword Args:
            input_ids, input_masks, image_prefix: same as in ``forward``.

        Returns:
            A list of integer indices into ``config.label_list``. A generation
            that matches no label maps to index 0.

        Raises:
            KeyError: if any of the keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        image_prefix = kwargs['image_prefix']
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        output = self.llm.generate(inputs_embeds=embedded, max_length=self.config.max_length)
        label_dict = {w: i for i, w in enumerate(self.config.label_list)}
        return [label_dict.get(text, 0) for text in self._decode(output)]


class ExtractImagePrefix(nn.Module):
    """Run images through a pretrained torchvision ResNet-50.

    ``config`` is accepted for a uniform constructor signature but is not
    read by this class.
    """

    def __init__(self, config):
        super().__init__()
        self.vision_encoder = models.resnet50(pretrained=True)

    def forward(self, **kwargs):
        """Return the ResNet-50 output for ``kwargs['image_ids']``.

        Raises:
            KeyError: if ``image_ids`` is not supplied.
        """
        images = kwargs['image_ids']
        return self.vision_encoder(images)


class MetaphorLLaVABackBone(nn.Module):
    """Fuse text and image features, then delegate to a pluggable engine.

    Token ids are encoded by ``MultimodalEncoder``; ``image_prefix`` is
    projected to ``config.prefix_length`` vectors of ``config.hidden_size``.
    The text features attend over those vectors and a sigmoid gate blends the
    result with the text features. The blend is passed to ``self.model``, the
    engine chosen by ``getEngine``.

    Config attributes read: ``model_path``, ``vision_dim``, ``prefix_length``,
    ``dropout``, ``num_heads`` and ``hidden_size`` (the latter is written onto
    the config by ``MultimodalEncoder``).
    """

    def __init__(self, config):
        super(MetaphorLLaVABackBone, self).__init__()
        self.config = config
        self.model = self.getEngine(config.model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
        # Not referenced by any method of this class; kept as an attribute.
        self.loss = nn.CrossEntropyLoss()
        # MultimodalEncoder sets config.hidden_size, which the layers below
        # depend on.
        self.mm_encoder = MultimodalEncoder(config)
        # Projects one image vector into `prefix_length` hidden-size vectors.
        self.dense_layer = nn.Sequential(
            nn.Linear(self.config.vision_dim, self.config.hidden_size * self.config.prefix_length),
            nn.Dropout(self.config.dropout),
            nn.ReLU(inplace=True)
        )
        self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size,
                                                     vdim=config.hidden_size, num_heads=self.config.num_heads,
                                                     batch_first=True)
        self.gate_dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.sigmoid = nn.Sigmoid()

    def _fuse(self, input_ids, input_masks, image_prefix):
        """Return the gated blend of text features and image-attended features."""
        text_features = self.mm_encoder(input_ids, input_masks, image_prefix)
        vision_features = self.dense_layer(image_prefix)
        vision_features = vision_features.view(-1, self.config.prefix_length, self.config.hidden_size)
        # Text positions are the queries; the projected image prefix supplies
        # keys and values, so the result has the text sequence length.
        vision_features, _ = self.mha_layer(text_features, vision_features, vision_features)
        gate = self.sigmoid(self.gate_dense(torch.cat([text_features, vision_features], dim=-1)))
        return (1 - gate) * text_features + gate * vision_features

    def forward(self, **kwargs):
        """Compute the engine's training output (its loss) for one batch.

        Keyword Args:
            input_ids, input_masks: text inputs for the encoder.
            output_ids: target token ids. Pad positions are replaced by -100
                on a copy; the caller's tensor is left untouched.
            output_masks: decoder attention mask.
            input_labels, mask_pos: forwarded to the engine unchanged.
            image_prefix: image features whose last dimension is
                ``config.vision_dim``.

        Returns:
            Whatever ``self.model`` returns for the fused embeddings.

        Raises:
            KeyError: if any of the keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        output_ids = kwargs['output_ids']
        output_masks = kwargs['output_masks']
        input_labels = kwargs['input_labels']
        image_prefix = kwargs['image_prefix']
        mask_pos = kwargs['mask_pos']
        # Build the labels on a copy: overwriting pad ids in the caller's
        # tensor would corrupt the batch for any later use (e.g. decoding).
        labels = output_ids.masked_fill(output_ids == self.tokenizer.pad_token_id, -100)
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        # self.model computes the loss, which is returned to the caller.
        return self.model(embedded, input_labels, labels, output_masks, mask_pos)

    def generate(self, **kwargs):
        """Return the engine's generations for the fused embeddings.

        Keyword Args:
            input_ids, input_masks, image_prefix: same as in ``forward``.

        Raises:
            KeyError: if any of the keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        image_prefix = kwargs['image_prefix']
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        return self.model.generate(embedded)

    def evaluate(self, **kwargs):
        """Return the engine's predictions for the fused embeddings.

        Keyword Args:
            input_ids, input_masks, image_prefix: same as in ``forward``.
            mask_pos: forwarded to the engine's ``evaluate``.

        Other keyword arguments (e.g. ``output_ids``, ``output_masks``,
        ``input_labels``) are accepted and ignored, so a training batch can
        be passed as-is.

        Raises:
            KeyError: if any of the four keyword arguments above is missing.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        image_prefix = kwargs['image_prefix']
        mask_pos = kwargs['mask_pos']
        embedded = self._fuse(input_ids, input_masks, image_prefix)
        return self.model.evaluate(embedded, mask_pos)

    def getEngine(self, model_name):
        """Build the generation engine that matches ``model_name``.

        Args:
            model_name: checkpoint name; only the exact names
                ``google/flan-t5-base`` and ``google/flan-t5-large`` are
                supported.

        Returns:
            A ``MetaphorLLaVAT5Base`` built from ``self.config``.

        Raises:
            NotImplementedError: for the recognised BERT names, for which no
                engine is implemented.
            ValueError: for any other name. Failing here avoids a later,
                less clear error from calling a ``None`` engine.
        """
        if model_name in ('google/flan-t5-base', 'google/flan-t5-large'):
            print('use T5ForConditionalGeneration model:{}'.format(model_name))
            return MetaphorLLaVAT5Base(self.config)
        if model_name in ('bert-base-uncased', 'bert-large-uncased'):
            print('use BertForMaskedLM model:{}'.format(model_name))
            raise NotImplementedError(
                'no BertForMaskedLM engine is implemented for model: {}'.format(model_name))
        raise ValueError('unsupported model name: {!r}'.format(model_name))


class MetaphorLLaVAT5Base(nn.Module):
    """T5 engine that consumes precomputed input embeddings.

    Wraps ``T5ForConditionalGeneration`` and its tokenizer, both loaded from
    ``config.model_path``. ``config.label_list`` is read by ``evaluate``.
    """

    def __init__(self, config):
        super(MetaphorLLaVAT5Base, self).__init__()
        self.config = config
        self.engine = T5ForConditionalGeneration.from_pretrained(config.model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)

    def _generate_text(self, input_embedded, max_length, markers):
        """Generate from embeddings, decode, and strip ``markers`` in order."""
        sequences = self.engine.generate(inputs_embeds=input_embedded, max_length=max_length)
        cleaned = []
        for ids in sequences:
            text = self.tokenizer.decode(ids)
            for marker in markers:
                text = text.replace(marker, '')
            cleaned.append(text.strip())
        return cleaned

    def forward(self, input_embedded, input_labels, output_ids, output_masks, mask_pos):
        """Return the T5 loss for ``input_embedded`` against ``output_ids``.

        ``output_ids`` is used as labels unchanged; no pad masking is applied
        in this method. ``input_labels`` and ``mask_pos`` are accepted for a
        uniform engine signature and are not used.
        """
        output = self.engine(inputs_embeds=input_embedded, decoder_input_ids=None,
                             decoder_attention_mask=output_masks, labels=output_ids)
        return output[0]

    def generate(self, input_embedded, max_length=50):
        """Generate text from embeddings and return the cleaned strings.

        Args:
            input_embedded: embeddings passed to the T5 model as
                ``inputs_embeds``.
            max_length: generation length limit; defaults to 50.

        Returns:
            A list of decoded strings with ``<pad>``, ``</s>`` and
            ``<extra_id_0>`` removed and surrounding whitespace stripped.
        """
        return self._generate_text(input_embedded, max_length, ('<pad>', '</s>', '<extra_id_0>'))

    def evaluate(self, input_embedded, mask_pos, max_length=50):
        """Generate text from embeddings and map each result to a label index.

        Args:
            input_embedded: embeddings passed to the T5 model as
                ``inputs_embeds``.
            mask_pos: accepted for a uniform engine signature; not used.
            max_length: generation length limit; defaults to 50.

        Returns:
            A list of integer indices into ``config.label_list``. A generation
            that matches no label maps to index 0.
        """
        label_dict = {w: i for i, w in enumerate(self.config.label_list)}
        # Only <pad> and </s> are stripped here (unlike generate), so a
        # generation containing <extra_id_0> matches no label.
        texts = self._generate_text(input_embedded, max_length, ('<pad>', '</s>'))
        return [label_dict.get(text, 0) for text in texts]