# CoC-code / code / main.py
import argparse
import yaml
import torch
from attrdict import AttrDict
import pandas as pd
from torch import nn
from transformers import AutoConfig

from src.utils import set_seed, load_params_LLM, frozen
from src.loader import MetaphorDataLoader, MetaphorLLaVADataLoader
from src.model import MMetaphorCOTModel, ExtractImagePrefix, MetaphorLLaVABackBone
from src.engine import ImageTrainer, MMetaphorCOTTrainer, MetaphorLLaVATrainer
import os

# Expose only GPUs 0 and 1 to this process. This assignment overwrites any
# CUDA_VISIBLE_DEVICES value inherited from the calling environment.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'


class Template:
    """Entry-point driver: builds the run configuration and launches a run.

    The configuration is the YAML file named by ``args.config`` with every
    parsed command-line argument copied on top of it (command-line values win
    on conflict). ``forward`` then selects a data loader / model / trainer
    combination and runs evaluation, feature extraction, inference or training.
    """

    def __init__(self, args):
        """Merge the YAML config with the CLI arguments and derive run settings.

        Args:
            args: Parsed ``argparse`` namespace. ``args.config`` is the path of
                the YAML file. After merging, the configuration must provide
                ``data_name``, ``seed``, ``cuda_index`` and ``model_size``.
        """
        # NOTE(review): FullLoader is used as in the original; the config file
        # is assumed to be trusted. Prefer yaml.safe_load if that ever changes.
        with open(args.config, 'r', encoding='utf-8') as config_file:
            config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader))

        # Command-line arguments override values loaded from the YAML file.
        for k, v in vars(args).items():
            setattr(config, k, v)
        config.dataname = config.data_name
        set_seed(config.seed)

        config.device = torch.device('cuda:{}'.format(config.cuda_index) if torch.cuda.is_available() else 'cpu')
        # The three remaining '{}' placeholders are left unformatted here.
        name_prefix = '_'.join(map(str, (config.model_size, config.dataname)))
        config.save_name = name_prefix + '_{}_{}_{}.pth.tar'
        self.config = config

    def _load_data(self, loader_cls):
        """Set the train/valid/test loaders and replace ``self.config`` with the one the loader returns."""
        (self.trainLoader, self.validLoader, self.testLoader), self.config = loader_cls(
            self.config).get_data()

    def _build_trainer(self):
        """Create the loaders, the model and the trainer for the configured mode.

        ``use_llava`` takes precedence over ``reasoning``.

        Returns:
            The trainer instance for the selected mode.

        Raises:
            ValueError: If ``use_llava`` is false and ``reasoning`` is neither
                ``'mprompt'`` nor ``'extract'``.
        """
        if self.config.use_llava:
            self._load_data(MetaphorLLaVADataLoader)
            self.model = MetaphorLLaVABackBone(config=self.config).to(self.config.device)
            self.model = frozen(self.model)
            self.config = load_params_LLM(self.config, self.model, self.trainLoader)

            trainer = MetaphorLLaVATrainer(self.model, self.config, self.trainLoader, self.validLoader,
                                           self.testLoader)
            print("Choosing llava prompt mode.")
            return trainer

        if self.config.reasoning == 'mprompt':
            self._load_data(MetaphorDataLoader)
            self.model = MMetaphorCOTModel(config=self.config).to(self.config.device)
            if self.config.multi_gpu:
                self.model = nn.DataParallel(self.model)
                self.model = self.model.cuda()
            self.model = frozen(self.model)

            self.config = load_params_LLM(self.config, self.model, self.trainLoader)
            trainer = MMetaphorCOTTrainer(self.model, self.config, self.trainLoader, self.validLoader,
                                          self.testLoader)
            print("Choosing multimodal prompt mode.")
            return trainer

        if self.config.reasoning == 'extract':
            self._load_data(MetaphorDataLoader)
            self.model = ExtractImagePrefix(config=self.config).to(self.config.device)
            self.config = load_params_LLM(self.config, self.model, self.trainLoader)
            print("Choosing multimodal prompt mode.")
            return ImageTrainer(self.model, self.config, self.trainLoader, self.validLoader,
                                self.testLoader)

        # Raising a bare string is itself a TypeError in Python 3, which would
        # hide this message; raise a real exception instead.
        raise ValueError(
            'Should choose a correct reasoning mode, got {!r}'.format(self.config.reasoning))

    def forward(self):
        """Run the configured mode.

        After building the trainer, exactly one of the following happens, checked
        in this order:

        1. ``zero_shot``: evaluate on the test loader and print the result.
        2. ``reasoning == 'extract'``: evaluate on the test, dev and train
           loaders and save each result under ``./data/preprocessed/``.
        3. ``inferrence``: run ``trainer.inferrence`` on a hard-coded checkpoint
           path and print the result.
        4. Otherwise: train, then print ``trainer.lines`` as a DataFrame.

        Raises:
            ValueError: If no valid mode is configured (see ``_build_trainer``).
        """
        print(f"Running on the {self.config.data_name} data.")
        trainer = self._build_trainer()

        if self.config.zero_shot:
            print("Zero-shot mode for evaluation.")
            r = trainer.evaluate_step(self.testLoader, 'test')
            print(r)
            return

        if self.config.reasoning == 'extract':
            print("Extract image feature.")
            test_ = trainer.evaluate_step(self.testLoader, 'test')
            dev_ = trainer.evaluate_step(self.validLoader, 'dev')
            train_ = trainer.evaluate_step(self.trainLoader, 'train')
            torch.save(train_, './data/preprocessed/image_meme_prefix_train.pt')
            torch.save(test_, './data/preprocessed/image_meme_prefix_test.pt')
            torch.save(dev_, './data/preprocessed/image_meme_prefix_dev.pt')
            return

        if self.config.inferrence:
            # NOTE(review): checkpoint path is hard-coded to one specific run.
            res = trainer.inferrence(
                    '/aworkspace/data/save/base_meme_2_2023-12-07-15-43-07.pth.tar')
            print(res)
            return

        trainer.train()
        lines = trainer.lines

        df = pd.DataFrame(lines)
        print(df.to_string())


def build_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser exposing the device index, reasoning
        mode, zero-shot flag, dataset name, config path and the ``multi_gpu``,
        ``special`` and ``use_sim`` switches.
    """
    parser = argparse.ArgumentParser()
    # type=int keeps the value an int whether it comes from the default or the CLI.
    parser.add_argument('-c', '--cuda_index', type=int, default=1,
                        help='index of the CUDA device to run on')
    parser.add_argument('-r', '--reasoning', default='mprompt',
                        choices=['mprompt', 'extract'],
                        help='multimodal prompt or extract image feature')
    parser.add_argument('-z', '--zero_shot', action='store_true', default=False,
                        help='running under zero-shot mode or fine-tune mode')
    # The default has to be one of the choices: argparse does not validate
    # defaults against `choices`, so an invalid default would pass silently.
    parser.add_argument('-d', '--data_name', default='meme',
                        choices=['met', 'meme'],
                        help='dataset name')
    parser.add_argument('-f', '--config', default='./config/config.yaml', help='config file')
    parser.add_argument('-m', '--multi_gpu', action='store_true', default=False, help='use multi gpu')
    parser.add_argument('-s', '--special', action='store_true', default=False, help='special mode')
    parser.add_argument('-us', '--use_sim', action='store_true', default=False, help='special mode')
    return parser


def main():
    """Parse the command line, build the ``Template`` and run it."""
    args = build_parser().parse_args()
    template = Template(args)
    template.forward()


if __name__ == '__main__':
    main()