# -*- coding: utf-8 -*-
"""
Created on Mon Nov 9 15:03:49 2020

@author: DrLC

Train / evaluate an RNN (LSTM) sequence classifier on a pickled dataset.
Run with --eval to load a saved checkpoint and report test accuracy instead
of training.
"""

import argparse
import os

from rnn import RNNClassifier
from dataset import ClassifierDataset, classifier_collate_fn

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler


def trainEpochs(epochs, training_dataloader, dev_dataloader, classifier,
                criterion, optimizer, batch_size=32, batch_size_eval=64,
                print_each=100, saving_path='./', lrdecay=1):
    """Train ``classifier`` for ``epochs`` epochs.

    After each epoch: evaluates on the dev loader, saves a checkpoint named
    ``<epoch>.pt`` under ``saving_path``, and (if ``lrdecay < 1``) multiplies
    every learning rate by ``lrdecay``.

    Args:
        epochs: number of epochs to run.
        training_dataloader: yields (inputs, labels) training batches.
        dev_dataloader: loader passed to :func:`evaluate` after each epoch.
        classifier: model exposing ``.run(inputs, batch_size, eval_flag)``,
            ``.device`` and ``.state_dict()``.
        criterion: loss function applied to (logits, labels).
        optimizer: torch optimizer over the classifier's parameters.
        batch_size: training batch size (used only for progress accounting
            and as an argument to ``classifier.run``).
        batch_size_eval: batch size forwarded to :func:`evaluate`.
        print_each: print a running average loss every this many batches.
        saving_path: directory for per-epoch checkpoints.
        lrdecay: multiplicative per-epoch LR decay; 1 disables decay.
    """
    # Fix: floor division instead of int(float division); same value, clearer.
    n_batch = len(training_dataloader.dataset) // batch_size
    for _ep in range(1, 1 + epochs):
        classifier.train()
        print_loss_total = 0
        print('Start epoch ' + str(_ep) + '....')
        for _iter, batch in enumerate(training_dataloader):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            optimizer.zero_grad()
            # NOTE(review): training calls classifier.run(...) while
            # evaluate() calls classifier(...) directly — confirm the two
            # code paths are equivalent in RNNClassifier.
            logits = classifier.run(inputs, batch_size, False)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            print_loss_total += loss.item()
            if (_iter + 1) % print_each == 0:
                print_loss_avg = print_loss_total / print_each
                print_loss_total = 0
                print('\tEp %d %d/%d, loss = %.6f'
                      % (_ep, (_iter + 1), n_batch, print_loss_avg))
        evaluate(classifier, dev_dataloader, batch_size_eval)
        torch.save(classifier.state_dict(),
                   os.path.join(saving_path, str(_ep) + '.pt'))
        if lrdecay < 1:
            adjust_learning_rate(optimizer, lrdecay)


def adjust_learning_rate(optimizer, decay_rate=0.8):
    """Multiply the learning rate of every param group by ``decay_rate``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_rate


def evaluate(classifier, dataloader, batch_size=128):
    """Print the classification accuracy of ``classifier`` over ``dataloader``.

    Puts the model in eval mode and disables autograd for the forward passes.
    """
    classifier.eval()
    testnum = 0
    testcorrect = 0
    # Fix: no gradients are needed at evaluation time; without no_grad() the
    # forward passes build autograd graphs and waste memory.
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch
            labels = labels.to(classifier.device)
            outputs = classifier(inputs, batch_size, True)
            res = torch.argmax(outputs, dim=1) == labels
            # Fix: accumulate a plain Python int, not a 0-dim tensor.
            testcorrect += torch.sum(res).item()
            testnum += len(labels)
    print('eval_acc: %.2f%%' % (float(testcorrect) * 100.0 / testnum))


def _str2bool(v):
    """Parse a command-line boolean flag value.

    Fix for the ``type=bool`` argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy, so ``--eval False`` used to
    enable eval mode. Accepts common truthy spellings; everything else is
    False. Backward compatible with previously-working invocations.
    """
    return str(v).lower() in ('true', '1', 'yes', 'y', 't')


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--epoch', type=int, default=30)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--bs_eval', type=int, default=100)
    parser.add_argument('--l2p', type=float, default=0)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--lrdecay', type=float, default=0.9)
    parser.add_argument('--n_dev', type=int, default=1000)
    parser.add_argument('--trainset', type=str, default='../data/train.pkl')
    parser.add_argument('--testset', type=str, default='../data/test.pkl')
    parser.add_argument('--save_dir', type=str, default='../model/oj_lstm')
    parser.add_argument('--tokenizer', type=str,
                        default="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986")
    parser.add_argument('--word2vec', type=str, default="../data/w2v.model")
    # Fix: type=bool made any non-empty value (including "False") truthy.
    parser.add_argument('--eval', type=_str2bool, default=False)

    opt = parser.parse_args()

    _eval = opt.eval
    _trainset_path = opt.trainset
    _testset_path = opt.testset
    _save = opt.save_dir
    if _eval:
        # In eval mode --save_dir must point at a directory containing
        # "model.pt" (the checkpoint to load).
        _save = os.path.join(_save, "model.pt")
    _tokenizer_path = opt.tokenizer
    _w2v_path = opt.word2vec
    _drop = opt.dropout
    _lr = opt.lr
    _l2p = opt.l2p
    _lrdecay = opt.lrdecay
    _bs = opt.bs
    _bs_eval = opt.bs_eval
    _ep = opt.epoch
    _n_dev = opt.n_dev

    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        # Fix: restrict visible devices BEFORE creating the CUDA device so
        # no CUDA context can be initialized with the wrong visibility.
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    # Fixed model hyper-parameters.
    model = "LSTM"
    hidden_size = 600
    n_layers = 2
    n_class = 104
    max_len = 512
    attn = True
    bidirection = True

    if _eval:
        test_set = ClassifierDataset(file_path=_testset_path,
                                     file_type='test', max_len=max_len)
        test_sampler = RandomSampler(test_set)
        test_dataloader = DataLoader(dataset=test_set, sampler=test_sampler,
                                     batch_size=_bs_eval, drop_last=False,
                                     collate_fn=classifier_collate_fn)
    else:
        training_set = ClassifierDataset(file_path=_trainset_path,
                                         file_type='train', n_dev=_n_dev,
                                         max_len=max_len)
        training_sampler = RandomSampler(training_set)
        training_dataloader = DataLoader(dataset=training_set,
                                         sampler=training_sampler,
                                         batch_size=_bs, drop_last=True,
                                         collate_fn=classifier_collate_fn)
        dev_set = ClassifierDataset(file_path=_trainset_path,
                                    file_type='dev', n_dev=_n_dev,
                                    max_len=max_len)
        dev_sampler = RandomSampler(dev_set)
        dev_dataloader = DataLoader(dataset=dev_set, sampler=dev_sampler,
                                    batch_size=_bs_eval, drop_last=True,
                                    collate_fn=classifier_collate_fn)

    classifier = RNNClassifier(num_class=n_class, hidden_dim=hidden_size,
                               n_layers=n_layers,
                               tokenizer_path=_tokenizer_path,
                               w2v_path=_w2v_path, max_len=max_len,
                               drop_prob=_drop, model=model,
                               brnn=bidirection, attn=attn,
                               device=device).to(device)

    if _eval:
        classifier.load_state_dict(torch.load(_save))
        classifier.eval()
        evaluate(classifier, test_dataloader, _bs_eval)
    else:
        classifier.train()
        optimizer = optim.Adam(classifier.parameters(), lr=_lr,
                               weight_decay=_l2p)
        criterion = nn.CrossEntropyLoss()
        trainEpochs(_ep, training_dataloader, dev_dataloader, classifier,
                    criterion, optimizer, saving_path=_save, batch_size=_bs,
                    batch_size_eval=_bs_eval, lrdecay=_lrdecay)