#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time : 2023/12/5 20:34
# @Author : yebulk
import os
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
import re
import json
import argparse
import random
from transformers import BertTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, \
T5ForConditionalGeneration, AutoModelForSeq2SeqLM
from BartModel import BartForMultimodalGeneration
from utils_data import img_shape, load_data_std, load_data_img, MetaphorDatasetStd, MeatphorDatasetImg, split_chinese_string
from utils_prompt import *
from utils_evaluate import get_scores
from rich.table import Column, Table
from rich import box
from rich.console import Console
from bert_score import score
console = Console(record=True)
import nltk
nltk.download('punkt')
import evaluate
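
# Command-line arguments: data/model paths, training hyperparameters, data splits,
# image-feature type and prompt format.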
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='../Data/')
    parser.add_argument('--output_dir', type=str, default='experiments-Ours-notexts')
parser.add_argument('--model', type=str, default='../modals/bart-large-chinese')
# parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
parser.add_argument('--epoch', type=int, default=8)
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--bs', type=int, default=16)
parser.add_argument('--input_len', type=int, default=16)
parser.add_argument('--output_len', type=int, default=16)
parser.add_argument('--eval_bs', type=int, default=16)
parser.add_argument('--eval_acc', type=int, default=None, help='evaluate accumulation step')
parser.add_argument('--train_split', type=str, default='train', choices=['train', 'debug_train_set', 'minitrain'])
parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
parser.add_argument('--user_msg', type=str, default="source", choices=['target', 'source'], help='experiment type in the save_dir')
parser.add_argument('--img_type', type=str, default='vit', choices=['detr', 'clip', 'resnet', 'vit'],
help='type of image features')
parser.add_argument('--eval_le', type=str, default=None, help='generated rationale for the dev set')
parser.add_argument('--test_le', type=str, default=None, help='generated rationale for the test set')
parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of model for evaluation')
parser.add_argument('--caption_file', type=str, default='data/instruct_captions.json')
parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
parser.add_argument('--prompt_format', type=str, default="Stage_two", help='prompt format template',
choices=['Stage_One', 'Stage_two'])
parser.add_argument('--seed', type=int, default=42, help='random seed')
args = parser.parse_args()
return args
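
# Builds the tokenizer, datasets and Seq2SeqTrainer for the (multimodal) BART model,
# then trains and/or evaluates it and writes predictions for the test and eval splits.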
def BARTTrainer(
dataframe, args,
):
torch.manual_seed(args.seed) # pytorch random seed
np.random.seed(args.seed) # numpy random seed
torch.backends.cudnn.deterministic = True
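    # when --evaluate_dir is given, load that checkpoint instead of the base model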
if args.evaluate_dir is not None:
args.model = args.evaluate_dir
tokenizer = BertTokenizer.from_pretrained(args.model)
console.log(f"""[Model]: {args.model}...\n""")
console.log(f"[Data]: Reading data...\n")
problems = dataframe['problems']
qids = dataframe['qids']
train_qids = qids['train']
test_qids = qids['test']
val_qids = qids['val']
if args.evaluate_dir is not None:
save_dir = args.evaluate_dir
else:
model_name = args.model.replace("/", "-")
gpu_count = torch.cuda.device_count()
save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
if not os.path.exists(save_dir):
os.mkdir(save_dir)
print(save_dir)
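    # build multimodal datasets (text + image features) when --img_type is set,
    # otherwise fall back to text-only datasets with a standard seq2seq model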
if args.img_type is not None:
patch_size = img_shape[args.img_type]
model = BartForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size)
name_maps = dataframe['name_maps']
image_features = dataframe['image_features']
train_set = MeatphorDatasetImg(
problems,
train_qids,
name_maps,
tokenizer,
args.input_len,
args.output_len,
args,
image_features,
) # 4886
eval_set = MeatphorDatasetImg(
problems,
val_qids,
name_maps,
tokenizer,
args.input_len,
args.output_len,
args,
image_features,
args.eval_le,
) # 610
test_set = MeatphorDatasetImg(
problems,
test_qids,
name_maps,
tokenizer,
args.input_len,
args.output_len,
args,
image_features,
args.test_le,
)# 612
else:
model = AutoModelForSeq2SeqLM.from_pretrained(args.model)
train_set = MetaphorDatasetStd(
problems,
train_qids,
tokenizer,
args.input_len,
args.output_len,
args,
)
eval_set = MetaphorDatasetStd(
problems,
val_qids,
tokenizer,
args.input_len,
args.output_len,
args,
args.eval_le,
)
test_set = MetaphorDatasetStd(
problems,
test_qids,
tokenizer,
args.input_len,
args.output_len,
args,
args.test_le,
)
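    # dynamically pads inputs and labels to the longest sequence in each batch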
datacollator = DataCollatorForSeq2Seq(tokenizer)
print("model parameters: ", model.num_parameters())
#
# # accuracy for answer inference
#
# def compute_metrics_acc(eval_preds):
# if args.use_generate:
# preds, targets = eval_preds
# if isinstance(preds, tuple):
# preds = preds[0]
# else:
# preds = eval_preds.predictions[0]
# targets = eval_preds.label_ids
# preds = preds.argmax(axis=2)
# preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
# targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
# correct = 0
# assert len(preds) == len(targets)
# for idx, pred in enumerate(preds):
# reference = targets[idx]
# reference = extract_ans(reference)
# extract_pred = extract_ans(pred)
# best_option = extract_pred
# if reference == best_option:
# correct += 1
# return {'accuracy': 1.0 * correct / len(targets)}
#
# # rougel for rationale generation
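    # the path below is a machine-specific local copy of the HuggingFace "rouge" metric;
    # adjust it (or use evaluate.load("rouge")) when running on another machine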
metric = evaluate.load("/mnt/sdb1/jsy/yjw/prompt/my-cot-metaphor/metrics/rouge")
#
#
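    # ROUGE computed over Chinese text segmented by split_chinese_string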
def compute_metrics_rougel(eval_preds):
if args.use_generate:
preds, targets = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
else:
preds = eval_preds.predictions[0]
# print("preds", preds)
targets = eval_preds.label_ids
preds = preds.argmax(axis=2)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
decoded_preds = split_chinese_string(preds)
decoded_labels = split_chinese_string(targets)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
result = {k: round(v * 100, 4) for k, v in result.items()}
        # preds have already been decoded to strings above, so measure generation
        # length in characters instead of comparing strings against pad_token_id
        result["gen_len"] = np.mean([len(pred) for pred in preds])
return result
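    # BERTScore with bert-base-chinese; its mean F1 also drives checkpoint selection
    # via metric_for_best_model="F1" below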
def compute_metrics_Bertscore(eval_preds):
if args.use_generate:
preds, targets = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
else:
preds = eval_preds.predictions[0]
targets = eval_preds.label_ids
preds = preds.argmax(axis=2)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
cands = preds
refs = targets
P, R, F1 = score(cands, refs, model_type="bert-base-chinese", lang="zh", verbose=True)
print(F1)
print(f"System level F1 score: {F1.mean():.3f}")
        # collect the mean precision, recall and F1 into a dict
result_dict = {
"P": P.mean().item(),
"R": R.mean().item(),
"F1": F1.mean().item()
}
return result_dict
# #
# # only use the last model for evaluation to save time
if args.final_eval:
training_args = Seq2SeqTrainingArguments(
save_dir,
do_train=True if args.evaluate_dir is None else False,
do_eval=False,
evaluation_strategy="no",
logging_strategy="steps",
save_strategy="epoch",
save_total_limit=2,
learning_rate=args.lr,
eval_accumulation_steps=args.eval_acc,
per_device_train_batch_size=args.bs,
per_device_eval_batch_size=args.eval_bs,
weight_decay=0.01,
num_train_epochs=args.epoch,
predict_with_generate=args.use_generate,
generation_max_length=args.output_len,
report_to="none",
)
# evaluate at each epoch
else:
training_args = Seq2SeqTrainingArguments(
save_dir,
do_train=True if args.evaluate_dir is None else False,
do_eval=True,
evaluation_strategy="epoch",
logging_strategy="steps",
save_strategy="epoch",
save_total_limit=2,
learning_rate=args.lr,
eval_accumulation_steps=args.eval_acc,
per_device_train_batch_size=args.bs,
per_device_eval_batch_size=args.eval_bs,
weight_decay=0.01,
num_train_epochs=args.epoch,
metric_for_best_model="F1",
# metric_for_best_model="accuracy" if args.prompt_format == "QCMG-A" or args.prompt_format == "QCM-A" else "rougeL",
predict_with_generate=args.use_generate,
generation_max_length=args.output_len,
load_best_model_at_end=True,
report_to="none",
)
# #
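    # compute_metrics controls evaluation logging and, together with metric_for_best_model,
    # checkpoint selection; the ROUGE and accuracy variants are kept below as commented alternatives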
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_set,
eval_dataset=eval_set,
data_collator=datacollator,
tokenizer=tokenizer,
compute_metrics=compute_metrics_Bertscore
# compute_metrics=compute_metrics_rougel
# compute_metrics=compute_metrics_acc if args.prompt_format == "QCMG-A" or args.prompt_format == "QCM-A" else compute_metrics_rougel
)
# #
if args.evaluate_dir is None:
trainer.train()
trainer.save_model(save_dir)
print("save_dir", save_dir)
metrics = trainer.evaluate(eval_dataset=test_set, max_length=args.output_len)
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
# #
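    # run prediction on the test split and dump decoded predictions, references and scores to JSON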
predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
if trainer.is_world_process_zero():
if args.use_generate:
preds, targets = predict_results.predictions, predict_results.label_ids
else:
preds = predict_results.predictions[0]
targets = predict_results.label_ids
preds = preds.argmax(axis=2)
preds = tokenizer.batch_decode(
preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
targets = tokenizer.batch_decode(
targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results_ans = {}
results_rationale = {}
results_reference = {}
num_fail = 0
for idx, qid in enumerate(test_qids):
pred = preds[int(idx)]
ref = targets[int(idx)]
results_rationale[str(qid)] = pred
results_reference[str(qid)] = ref
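        # the problems.json path below follows the original ScienceQA-style layout;
        # adjust it if the metaphor data stores its problem file elsewhere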
scores = get_scores(results_ans, results_rationale, results_reference,
os.path.join(args.data_root, "scienceqa/problems.json"))
preds = [pred.strip() for pred in preds]
output_data = {
"num_fail": num_fail,
"scores": scores,
"preds": preds,
"labels": targets}
output_prediction_file = os.path.join(save_dir, "predictions_ans_test.json")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
writer.write(json.dumps(output_data, indent=4, ensure_ascii=False))
# generate the rationale for the eval set
# if args.prompt_format == "QCM-LE" or args.prompt_format == "QCM-E":
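    # release GPU memory from the test pass before predicting on the eval set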
torch.cuda.empty_cache()
del predict_results, preds, targets
predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
if trainer.is_world_process_zero():
if args.use_generate:
preds, targets = predict_results.predictions, predict_results.label_ids
else:
preds = predict_results.predictions[0]
targets = predict_results.label_ids
preds = preds.argmax(axis=2)
preds = tokenizer.batch_decode(
preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
targets = tokenizer.batch_decode(
targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
preds = [pred.strip() for pred in preds]
output_data = {"preds": preds,
"labels": targets}
output_prediction_file = os.path.join(save_dir, "predictions_ans_eval.json")
with open(output_prediction_file, "w", encoding="utf-8") as writer:
writer.write(json.dumps(output_data, indent=4, ensure_ascii=False))
if __name__ == '__main__':
# training logger to log training progress
training_logger = Table(
Column("Epoch", justify="center"),
Column("Steps", justify="center"),
Column("Loss", justify="center"),
title="Training Status",
pad_edge=False,
box=box.ASCII,
)
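    # note: training_logger is set up for progress display but is not written to elsewhere in this script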
args = parse_args()
print("args", args)
print('====Input Arguments====')
print(json.dumps(vars(args), indent=2, sort_keys=False))
random.seed(args.seed)
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
if args.img_type is not None:
        problems, qids, name_maps, image_features = load_data_img(args)  # problems, split question ids, image name maps and features
dataframe = {'problems': problems, 'qids': qids, 'name_maps': name_maps, 'image_features': image_features}
else:
        problems, qids = load_data_std(args)  # problems and split question ids
dataframe = {'problems': problems, 'qids': qids}
BARTTrainer(
dataframe=dataframe,
args=args
)