import os
import pickle
import time
import logging
from collections import defaultdict

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

from src.utils import prompt_for_connection, prompt_for_metaphor_label, log_config, get_parameter_number


class MMetaphorCOTTrainer:
    def __init__(self, model, config, train_loader, valid_loader, test_loader) -> None:
        self.model = model
        self.config = config
        self.train_loader, self.valid_loader, self.test_loader = train_loader, valid_loader, test_loader
        self.save_name = os.path.join(config.target_dir, config.save_name)
        self.final_score = 0
        self.final_res = ''
        self.scores, self.lines = [], []
        log_name = os.path.join(config.target_dir,
                                'COT_log' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + '.log')
        logging.basicConfig(filename=log_name, filemode="w",
                            format="\n%(levelname)s:%(message)s\n",
                            datefmt="%d-%m-%Y %H:%M:%S",  # '%m' (month), not '%M' (minutes)
                            level=logging.DEBUG)
        self.logger = logging
        log_config(self.config, self.logger)
        self.re_init()

    def train(self):
        best_score, best_iter = 0, -1
        self.logger.info("START TRAINING...")
        print(self.config.optimizer.state_dict()['param_groups'][0]['lr'])
        parameters_num = get_parameter_number(self.model)
        print("parameters num:{}, trainable parameters num:{}".format(
            parameters_num['Total'], parameters_num['Trainable']))
        for epoch in tqdm(range(self.config.epoch_size)):
            self.logger.info(f"TRAIN EPOCH: {epoch}")
            self.model.global_epoch = epoch
            self.global_epoch = epoch
            self.train_step()

            result = self.evaluate_step(mode='valid')
            self.logger.info('Epoch {}, Valid F1: {}, ACC: {}'.format(
                self.global_epoch, result['F1'], result['Acc']))
            print('Epoch {}, Valid F1: {}, ACC: {}'.format(self.global_epoch, result['F1'], result['Acc']))
            self.re_init()
            score = result['default']
            self.add_instance(result)
            res = self.get_best()

            if score > best_score:
                best_score, best_iter = score, epoch
                save_name = self.save_name.format(epoch, time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()))
                if not os.path.exists(self.config.target_dir):
                    os.makedirs(self.config.target_dir)
                torch.save({'epoch': epoch, 'model': self.model.cpu().state_dict()}, save_name)
                self.logger.info("***EPOCH: {}, best score update: {}".format(epoch, best_score))
                self.model.to(self.config.device)
            elif epoch - best_iter > self.config.patience:
                self.logger.info("***No improvement for {} epochs, early stopping...".format(self.config.patience))
                break
            self.model.to(self.config.device)
        # res = self.final_evaluate(best_score, best_iter)
        score = res['default']
        self.add_instance(res)
        # save_name = self.save_name.format(epoch)
        self.final_score, self.final_res = score, res
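    # The methods below implement the three-step chain-of-thought pipeline that
    # train_step and evaluate_step drive:
    #   step 1: the model generates a free-form description of the sample;
    #   step 2 (prepare_step_two): that output is folded into a prompt asking for
    #           the connection between modalities (prompt_for_connection);
    #   step 3 (prepare_step_three): the connection is folded into the final
    #           metaphor-label prompt (prompt_for_metaphor_label).
    # Each prepare_* method re-tokenizes the new prompts and moves the batch to
    # the configured device.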
    def prepare_step_two(self, outputs, data):
        # under DataParallel the tokenizer lives on model.module
        tokenizer = self.model.module.tokenizer if self.config.multi_gpu else self.model.tokenizer
        context_ids = data['context_ids']
        output_ids, output_masks = data['output_ids'], data['output_masks']
        contexts = [tokenizer.decode(ids) for ids in context_ids]
        # Strip tokenizer special tokens from the decoded text. The original token
        # strings were lost to markup stripping (the source had empty
        # `.replace('', '')` calls); '<pad>' and '</s>' are an assumption based on
        # the T5-style decode above.
        contexts = [context.replace('<pad>', '').replace('</s>', '').strip() for context in contexts]

        new_prompts = []
        new_contexts = []
        for context, output in zip(contexts, outputs):
            # prompt = prompt_for_metaphor_label2(context, output)
            new_context, prompt = prompt_for_connection(context, output)
            new_prompts.append(prompt)
            new_contexts.append(new_context)

        batch_inputs = tokenizer.batch_encode_plus(new_prompts, padding=True, return_tensors='pt',
                                                   truncation=True,  # max_length is only enforced with truncation
                                                   max_length=self.config.max_length)
        batch_inputs = batch_inputs.data
        batch_contexts = tokenizer.batch_encode_plus(new_contexts, padding=True, return_tensors='pt',
                                                     truncation=True,
                                                     max_length=self.config.max_length)
        batch_contexts = batch_contexts.data

        res = {
            'input_ids': batch_inputs['input_ids'],
            'input_masks': batch_inputs['attention_mask'],
            'output_ids': output_ids,
            'output_masks': output_masks,
            'context_ids': batch_contexts['input_ids'],
            'image_prefix': data['image_prefix'],
            # 'image_ids': data['image_ids'],
        }
        res = {k: v.to(self.config.device) for k, v in res.items()}
        return res

    def prepare_step_three(self, outputs, pre_data, data):
        tokenizer = self.model.module.tokenizer if self.config.multi_gpu else self.model.tokenizer
        context_ids = pre_data['context_ids']
        output_ids, output_masks = data['output_ids'], data['output_masks']
        contexts = [tokenizer.decode(ids) for ids in context_ids]
        # strip assumed special tokens (see prepare_step_two)
        contexts = [context.replace('<pad>', '').replace('</s>', '').strip() for context in contexts]

        new_prompts = []
        for context, output in zip(contexts, outputs):
            prompt = prompt_for_metaphor_label(context, output)
            new_prompts.append(prompt)

        batch_inputs = tokenizer.batch_encode_plus(new_prompts, padding=True, return_tensors='pt',
                                                   truncation=True,
                                                   max_length=self.config.max_length)
        batch_inputs = batch_inputs.data

        res = {
            'input_ids': batch_inputs['input_ids'],
            'input_masks': batch_inputs['attention_mask'],
            'output_ids': output_ids,
            'output_masks': output_masks,
            'image_prefix': data['image_prefix'],
            # 'image_ids': data['image_ids'],
        }
        res = {k: v.to(self.config.device) for k, v in res.items()}
        return res
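    # prepare_train_step builds the single-step baseline prompt: the raw context
    # plus a fixed question whose '[mask]' slot the model fills with the metaphor
    # label. It is kept alongside the three-step pipeline above but is not called
    # from train_step in this file.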
    def prepare_train_step(self, outputs, data):
        tokenizer = self.model.module.tokenizer if self.config.multi_gpu else self.model.tokenizer
        context_ids = data['context_ids']
        output_ids, output_masks = data['output_ids'], data['output_masks']
        contexts = [tokenizer.decode(ids) for ids in context_ids]
        contexts = [context.replace('<pad>', '').replace('</s>', '').strip() for context in contexts]

        new_prompts = []
        for context, output in zip(contexts, outputs):
            # prompt wording kept verbatim from the original
            prompt = context + 'Is this sample contains metaphor? Answer: [mask]'
            new_prompts.append(prompt)

        batch_inputs = tokenizer.batch_encode_plus(new_prompts, padding=True, return_tensors='pt',
                                                   truncation=True,
                                                   max_length=self.config.max_length)
        batch_inputs = batch_inputs.data

        res = {
            'input_ids': batch_inputs['input_ids'],
            'input_masks': batch_inputs['attention_mask'],
            'output_ids': output_ids,
            'output_masks': output_masks,
            'image_prefix': data['image_prefix'],
            # 'image_ids': data['image_ids'],
        }
        res = {k: v.to(self.config.device) for k, v in res.items()}
        return res

    def train_step(self):
        self.model.train()
        train_data = tqdm(self.train_loader, total=self.train_loader.data_length, position=0)
        losses = []
        for i, data in enumerate(train_data):
            # step 1: free-form generation
            step_one_inferred_output = self.model.module.generate(step_one=True, train=False, **data) \
                if self.config.multi_gpu \
                else self.model.generate(step_one=True, train=False, **data)
            self.logger.info(f"Step_one_inferred_output:{step_one_inferred_output}")

            # step 2: ask for the connection
            step_one_inferred_data = self.prepare_step_two(step_one_inferred_output, data)
            step_two_inferred_output = self.model.generate(**step_one_inferred_data)
            # self.logger.info(f"Step_two_inferred_output:{step_two_inferred_output}")

            # step 3: final label prompt; the loss is computed on this step only
            step_two_inferred_data = self.prepare_step_three(step_two_inferred_output, step_one_inferred_data, data)
            final_data = step_two_inferred_data

            loss = self.model(**final_data)
            losses.append(loss.item())
            loss.backward()
            if (i + 1) % self.config.gradient_accumulation_steps == 0:
                nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.config.optimizer.step()
                self.config.scheduler.step()
                self.model.zero_grad()

            description = "Epoch {}, loss:{:.4f}".format(self.global_epoch, np.mean(losses))
            train_data.set_description(description)

    def evaluate_step(self, dataLoader=None, mode='valid'):
        self.model.eval()
        dataLoader = self.valid_loader if dataLoader is None else dataLoader
        dataiter = dataLoader
        for i, data in tqdm(enumerate(dataiter), total=dataLoader.data_length, position=0):
            with torch.no_grad():
                step_one_inferred_output = self.model.module.generate(step_one=True, train=False, **data) \
                    if self.config.multi_gpu \
                    else self.model.generate(step_one=True, train=False, **data)
                self.logger.info(f"Step_one_inferred_output: {step_one_inferred_output}")
                step_one_inferred_data = self.prepare_step_two(step_one_inferred_output, data)
                step_two_inferred_output = self.model.generate(**step_one_inferred_data)
                # self.logger.info(f"Step_two_inferred_output: {step_two_inferred_output}")
                step_two_inferred_data = self.prepare_step_three(step_two_inferred_output,
                                                                 step_one_inferred_data, data)
                output = self.model.evaluate(**step_two_inferred_data)
                self.add_output(data, output)
        result = self.report_score(mode=mode)
        return result

    def inferrence(self, path=''):
        self.model.load_state_dict(torch.load(path, map_location=self.config.device)['model'])
        self.model.eval()
        res = self.evaluate_step(self.test_loader, mode='test')
        self.add_instance(res)
        # self.report_score()
        return res

    def add_instance(self, res):
        self.lines.append(res)

    def get_best(self):
        best_id = np.argmax([w['default'] for w in self.lines])
        res = self.lines[best_id]
        return res
    def re_init(self):
        self.preds, self.golds = defaultdict(list), defaultdict(list)
        self.keys = ['total']

    def add_output(self, data, output):
        gold = data['input_labels']
        # print('output:', output)
        if self.config.model_path.startswith('google'):
            self.preds['total'] += output
        else:
            self.preds['total'] += output.cpu()
        self.golds['total'] += gold.tolist()
        self.golds['id'] += data['ids']

    def report_score(self, mode='valid'):
        res = {}
        res['Acc'] = accuracy_score(self.golds['total'], self.preds['total'])
        res['F1'] = f1_score(self.golds['total'], self.preds['total'], labels=[0, 1], average='macro')
        res['default'] = res['F1']
        res['mode'] = mode
        return res
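# MetaphorLLaVATrainer follows the same train/evaluate/report skeleton as
# MMetaphorCOTTrainer, but drives the model in a single forward pass per batch
# (no multi-step prompting) and scales its reported metrics to percentages.
# ImageTrainer, further below, is a utility trainer that caches per-sample image
# prefixes back into the meme_*.pkl data files.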
class MetaphorLLaVATrainer:
    def __init__(self, model, config, train_loader, valid_loader, test_loader) -> None:
        self.model = model
        self.config = config
        self.train_loader, self.valid_loader, self.test_loader = train_loader, valid_loader, test_loader
        self.save_name = os.path.join(config.target_dir, config.save_name)
        self.final_score = 0
        self.final_res = ''
        log_name = os.path.join(config.target_dir,
                                'COT_log' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + '.log')
        logging.basicConfig(filename=log_name, filemode="w",
                            format="\n%(levelname)s:%(message)s\n",
                            datefmt="%d-%m-%Y %H:%M:%S",
                            level=logging.DEBUG)
        self.logger = logging
        log_config(self.config, self.logger)
        self.scores, self.lines = [], []
        self.re_init()

    def train(self):
        best_score, best_iter = 0, -1
        print("train epoch size:", self.config.epoch_size, ' lr:', self.config.bert_lr)
        parameters_num = get_parameter_number(self.model)
        print("parameters num:{}, trainable parameters num:{}".format(
            parameters_num['Total'], parameters_num['Trainable']))
        for epoch in tqdm(range(self.config.epoch_size)):
            self.model.global_epoch = epoch
            self.global_epoch = epoch
            self.train_step()

            result = self.evaluate_step(mode='valid')
            print('Epoch {}, Valid F1: {}, ACC: {}'.format(self.global_epoch, result['F1'], result['Acc']))
            self.re_init()
            score = result['default']
            self.add_instance(result)
            res = self.get_best()

            if score > best_score:
                best_score, best_iter = score, epoch
                save_name = self.save_name.format(epoch, time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()))
                if not os.path.exists(self.config.target_dir):
                    os.makedirs(self.config.target_dir)
                torch.save({'epoch': epoch, 'model': self.model.cpu().state_dict()}, save_name)
                self.model.to(self.config.device)
            elif epoch - best_iter > self.config.patience:
                print("No improvement for {} epochs, early stopping...".format(self.config.patience))
                break
            self.model.to(self.config.device)
        # res = self.final_evaluate(best_iter)
        score = res['default']
        self.add_instance(res)
        # save_name = self.save_name.format(epoch)
        self.final_score, self.final_res = score, res

    def train_step(self):
        self.model.train()
        train_data = tqdm(self.train_loader, position=0)
        losses = []
        for i, data in enumerate(train_data):
            loss = self.model(**data)
            losses.append(loss.item())
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            description = "Epoch {}, loss:{:.4f}".format(self.global_epoch, np.mean(losses))
            train_data.set_description(description)
            # step the optimizer every batch; this trainer does not use gradient accumulation
            self.config.optimizer.step()
            self.config.scheduler.step()
            self.model.zero_grad()

    def evaluate_step(self, dataLoader=None, mode='valid'):
        self.model.eval()
        dataLoader = self.valid_loader if dataLoader is None else dataLoader
        dataiter = dataLoader
        for i, data in tqdm(enumerate(dataiter), total=dataLoader.data_length, position=0):
            with torch.no_grad():
                output = self.model.evaluate(**data)
                self.add_output(data, output)
        result = self.report_score(mode=mode)
        return result

    def inferrence(self, path=''):
        self.model.load_state_dict(torch.load(path, map_location=self.config.device)['model'])
        self.model.eval()
        res = self.evaluate_step(self.test_loader, mode='test')
        self.add_instance(res)
        # self.report_score()
        return res

    def add_instance(self, res):
        self.lines.append(res)

    def get_best(self):
        best_id = np.argmax([w['default'] for w in self.lines])
        res = self.lines[best_id]
        return res

    def re_init(self):
        self.preds, self.golds = defaultdict(list), defaultdict(list)
        self.ids = defaultdict(list)
        self.keys = ['total']

    def add_output(self, data, output):
        gold = data['input_labels']
        # print('output:', output)
        if self.config.model_path.startswith('google'):
            self.preds['total'] += output
        else:
            self.preds['total'] += output.cpu()
        self.golds['total'] += gold.tolist()
        self.preds['id'] += data['id']

    def report_score(self, mode='valid'):
        res = {}
        res['Acc'] = accuracy_score(self.golds['total'], self.preds['total'])
        res['F1'] = f1_score(self.golds['total'], self.preds['total'], labels=[0, 1], average='macro')
        res['default'] = res['F1']
        res['mode'] = mode
        # report percentages rounded to three decimals
        for k, v in res.items():
            if isinstance(v, float):
                res[k] = round(v * 100, 3)
        return res


class ImageTrainer:
    def __init__(self, model, config, train_loader, valid_loader, test_loader) -> None:
        self.model = model
        self.config = config
        self.train_loader, self.valid_loader, self.test_loader = train_loader, valid_loader, test_loader
        self.save_name = os.path.join(config.target_dir, config.save_name)
        self.final_score = 0
        self.final_res = ''
        self.scores, self.lines = [], []

    def train(self):
        for epoch in tqdm(range(self.config.epoch_size)):
            self.model.global_epoch = epoch
            self.global_epoch = epoch
            self.train_step()
            result = self.evaluate_step(mode='valid')

    def train_step(self):
        self.model.train()
        train_data = tqdm(self.train_loader, ncols=10)
        losses = []
        for data in train_data:
            loss = self.model(**data)
            losses.append(loss.item())
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            description = "Epoch {}, loss:{:.4f}".format(self.global_epoch, np.mean(losses))
            train_data.set_description(description)
            self.config.optimizer.step()
            self.config.scheduler.step()
            self.model.zero_grad()

    def evaluate_step(self, dataLoader=None, mode='valid'):
        self.model.eval()
        dataLoader = self.valid_loader if dataLoader is None else dataLoader
        dataiter = dataLoader
        ip = {}
        for i, data in tqdm(enumerate(dataiter), total=dataLoader.data_length):
            with torch.no_grad():
                output = self.model(**data)
            # collect each sample's image prefix by its id
            for j, sample_id in enumerate(data['ids']):
                ip[sample_id] = output[j]
        if mode == 'train':
            name = './data/meme_train.pkl'
        elif mode in ('dev', 'valid'):  # train() passes mode='valid'
            name = './data/meme_dev.pkl'
        else:
            name = './data/meme_test.pkl'
        # write the computed prefixes back into the corresponding pickle
        with open(name, 'rb') as f:
            data = pickle.load(f)
        for sample_id in data['id']:
            print(ip[sample_id])
            data['image_prefix'].append(ip[sample_id])
        with open(name, 'wb') as f:
            pickle.dump(data, f)
        return
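# ------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes hypothetical
# helpers build_config(), build_model(config), and build_loaders(config) that
# return the config namespace (with optimizer and scheduler already attached, as
# the trainers expect), the model, and the three data loaders; adapt the names
# to the project's actual entry point.
#
# if __name__ == '__main__':
#     config = build_config()                                           # hypothetical
#     model = build_model(config).to(config.device)                     # hypothetical
#     train_loader, valid_loader, test_loader = build_loaders(config)   # hypothetical
#     trainer = MMetaphorCOTTrainer(model, config, train_loader, valid_loader, test_loader)
#     trainer.train()
#     print('best score:', trainer.final_score)
# ------------------------------------------------------------------------------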