# -*- coding: utf-8 -*-
"""
Created on Mon Nov 9 15:03:49 2020

@author: DrLC
"""

import argparse
import os

from dataset import CloneDataset, clone_collate_fn

import tqdm
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from sklearn.metrics import f1_score


def trainEpochs(epochs, training_dataloader, dev_dataloader, classifier,
                criterion, optimizer, batch_size=32, batch_size_eval=64,
                print_each=1000, saving_path='./', lrdecay=1):
    """Train for `epochs` epochs; after each epoch, evaluate on the dev set,
    checkpoint the model, and (optionally) decay the learning rate."""
    n_batch = int(len(training_dataloader.dataset) / batch_size)
    for _ep in range(1, 1 + epochs):
        classifier.train()
        print_loss_total = 0
        print('Start epoch ' + str(_ep) + '....')
        for _iter, batch in enumerate(training_dataloader):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            optimizer.zero_grad()
            logits = classifier.run(inputs, batch_size, False)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            print_loss_total += loss.item()
            if (_iter + 1) % print_each == 0:
                # Report the running average loss over the last `print_each` batches.
                print_loss_avg = print_loss_total / print_each
                print_loss_total = 0
                print('\tEp %d %d/%d, loss = %.6f'
                      % (_ep, (_iter + 1), n_batch, print_loss_avg), flush=True)
        print(" eval...")
        evaluate(classifier, dev_dataloader, batch_size_eval)
        # Save one checkpoint per epoch: <saving_path>/<epoch>.pt
        torch.save(classifier.state_dict(),
                   os.path.join(saving_path, str(_ep) + '.pt'))
        if lrdecay < 1:
            adjust_learning_rate(optimizer, lrdecay)


def adjust_learning_rate(optimizer, decay_rate=0.8):
    # Exponential decay: multiply every parameter group's lr by decay_rate.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_rate


def evaluate(classifier, dataloader, batch_size=128):
    """Compute accuracy and binary F1 over `dataloader`."""
    classifier.eval()
    testnum = 0
    testcorrect = 0
    y_true = []
    y_pred = []
    with torch.no_grad():  # no gradients needed during evaluation
        for batch in tqdm.tqdm(dataloader):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            outputs = classifier(inputs, batch_size, True)
            preds = torch.argmax(outputs, dim=1)
            testcorrect += torch.sum(preds == labels).item()
            testnum += len(labels)
            y_true += list(labels.cpu().numpy())
            y_pred += list(preds.cpu().numpy())
    print('Eval_acc: %.2f%%' % (float(testcorrect) * 100.0 / testnum))
    print('Eval_f1: %.2f%%'
          % (f1_score(y_true, y_pred, average='binary') * 100), flush=True)
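
# Interfaces assumed by this script (inferred from the call sites in this
# file; not verified against dataset.py / rnn.py):
#   - CloneDataset items collate via clone_collate_fn into (inputs, labels),
#     where labels is a LongTensor of 0/1 clone labels (n_class = 2 below);
#   - RNNClone exposes a .device attribute, plus .run(inputs, batch_size,
#     eval_flag) and __call__(inputs, batch_size, eval_flag), both returning
#     [batch, n_class] logits.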

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--bs', type=int, default=128)
    parser.add_argument('--bs_eval', type=int, default=256)
    parser.add_argument('--l2p', type=float, default=0)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--lrdecay', type=float, default=0.99)
    parser.add_argument('--data_path', type=str,
                        default="/var/data/lushuai/bertvsbert/data/ojclone/ojclone_norm.jsonl")
    parser.add_argument('--trainset', type=str,
                        default="/var/data/lushuai/bertvsbert/data/ojclone/train.txt")
    parser.add_argument('--validset', type=str,
                        default="/var/data/lushuai/bertvsbert/data/ojclone/valid.txt")
    parser.add_argument('--testset', type=str,
                        default="/var/data/lushuai/bertvsbert/data/ojclone/test.txt")
    parser.add_argument('--save_dir', type=str, default='../model/ojclone_lstm')
    parser.add_argument('--tokenizer', type=str,
                        default="/var/data/zhanghz/codebert-base-mlm")
    parser.add_argument('--word2vec', type=str, default="../data/ojclone_w2v.model")
    parser.add_argument('--valid_ratio', type=float, default=0.2)
    # `type=bool` would treat any non-empty string (even "False") as True,
    # so expose --eval as a flag instead.
    parser.add_argument('--eval', action='store_true')
    opt = parser.parse_args()

    _eval = opt.eval
    _src_path = opt.data_path
    _trainset_path = opt.trainset
    _validset_path = opt.validset
    _valid_ratio = opt.valid_ratio
    _testset_path = opt.testset
    _save = opt.save_dir
    if _eval:
        _save = os.path.join(_save, "model.pt")
    _tokenizer_path = opt.tokenizer
    _w2v_path = opt.word2vec
    _drop = opt.dropout
    _lr = opt.lr
    _l2p = opt.l2p
    _lrdecay = opt.lrdecay
    _bs = opt.bs
    _bs_eval = opt.bs_eval
    _ep = opt.epoch

    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        # Restrict CUDA to the requested GPU(s) before any CUDA context is created.
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    from rnn import RNNClone

    # Model hyperparameters.
    model = "LSTM"
    hidden_size = 600
    n_layers = 2
    n_class = 2
    max_len = 400
    attn = True
    bidirection = True

    if _eval:
        test_set = CloneDataset(src_path=_src_path, file_path=_testset_path,
                                file_type='test', max_len=max_len,
                                keep_ratio=_valid_ratio)
        test_sampler = RandomSampler(test_set)
        test_dataloader = DataLoader(dataset=test_set, sampler=test_sampler,
                                     batch_size=_bs_eval, drop_last=False,
                                     collate_fn=clone_collate_fn)
    else:
        training_set = CloneDataset(src_path=_src_path, file_path=_trainset_path,
                                    file_type='train', max_len=max_len)
        training_sampler = RandomSampler(training_set)
        training_dataloader = DataLoader(dataset=training_set,
                                         sampler=training_sampler,
                                         batch_size=_bs, drop_last=True,
                                         collate_fn=clone_collate_fn)
        dev_set = CloneDataset(src_path=_src_path, file_path=_validset_path,
                               file_type='valid', max_len=max_len,
                               keep_ratio=_valid_ratio)
        dev_sampler = RandomSampler(dev_set)
        dev_dataloader = DataLoader(dataset=dev_set, sampler=dev_sampler,
                                    batch_size=_bs_eval, drop_last=True,
                                    collate_fn=clone_collate_fn)

    classifier = RNNClone(num_class=n_class, hidden_dim=hidden_size,
                          n_layers=n_layers, tokenizer_path=_tokenizer_path,
                          w2v_path=_w2v_path, max_len=max_len, drop_prob=_drop,
                          model=model, brnn=bidirection, attn=attn,
                          device=device).to(device)

    if _eval:
        # map_location lets a GPU-saved checkpoint load on CPU and vice versa.
        classifier.load_state_dict(torch.load(_save, map_location=device))
        classifier.eval()
        evaluate(classifier, test_dataloader, _bs_eval)
    else:
        classifier.train()
        optimizer = optim.Adam(classifier.parameters(), lr=_lr, weight_decay=_l2p)
        criterion = nn.CrossEntropyLoss()
        trainEpochs(_ep, training_dataloader, dev_dataloader, classifier,
                    criterion, optimizer, saving_path=_save, batch_size=_bs,
                    batch_size_eval=_bs_eval, lrdecay=_lrdecay)
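
# Example invocations (a sketch; "train_lstm.py" is a placeholder for this
# file's actual name, and the default --data_path/--trainset/... paths are
# assumed to exist):
#   training:   python train_lstm.py --gpu 0 --bs 128 --epoch 20
#   evaluation: python train_lstm.py --gpu 0 --eval
# Note: training saves one checkpoint per epoch as <save_dir>/<epoch>.pt,
# while --eval loads <save_dir>/model.pt, so copy or rename the chosen
# checkpoint to model.pt before evaluating.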