# CodeBERT-Attack/rnn/run_rnn.py
# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 15:03:49 2020

@author: DrLC
"""

import argparse
import os
from rnn import RNNClassifier
from dataset import ClassifierDataset, classifier_collate_fn

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler

def trainEpochs(epochs, training_dataloader, dev_dataloader, classifier,
                criterion, optimizer,
                batch_size=32, batch_size_eval=64, print_each=100,
                saving_path='./', lrdecay=1):
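    """Train `classifier` for `epochs` epochs. After each epoch: evaluate on
    the dev set, save a checkpoint named <epoch>.pt under `saving_path`, and,
    if lrdecay < 1, multiply the learning rate by lrdecay."""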
    
    n_batch = len(training_dataloader)  # batches per epoch (drop_last=True)
    
    for _ep in range(1, 1+epochs):
        classifier.train()
        print_loss_total = 0
        print('Start epoch ' + str(_ep) + '....')
        for _iter, batch in enumerate(training_dataloader):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            optimizer.zero_grad()
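            # Forward pass; the trailing boolean mirrors the flag passed in
            # evaluate() (True there), presumably an eval-mode switch.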
            logits = classifier.run(inputs, batch_size, False)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            print_loss_total += loss.item()
            if (_iter + 1) % print_each == 0: 
                print_loss_avg = print_loss_total / print_each
                print_loss_total = 0
                print('\tEp %d %d/%d, loss = %.6f' \
                      % (_ep, (_iter + 1), n_batch, print_loss_avg))
        
        evaluate(classifier, dev_dataloader, batch_size_eval)
        torch.save(classifier.state_dict(),
                   os.path.join(saving_path, str(_ep)+'.pt'))
            
        if lrdecay < 1:
            adjust_learning_rate(optimizer, lrdecay)
 

def adjust_learning_rate(optimizer, decay_rate=0.8):
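    """Multiply the learning rate of every optimizer parameter group by decay_rate."""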
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_rate
                
def evaluate(classifier, dataloader, batch_size=128):
    """Compute and print classification accuracy over `dataloader`."""
    classifier.eval()
    testnum = 0
    testcorrect = 0

    # Disable gradient tracking during evaluation to save memory and time.
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch
            labels = labels.to(classifier.device)
            outputs = classifier(inputs, batch_size, True)
            res = torch.argmax(outputs, dim=1) == labels
            testcorrect += torch.sum(res).item()
            testnum += len(labels)

    print('eval_acc:  %.2f%%' % (float(testcorrect) * 100.0 / testnum))
    
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--epoch', type=int, default=30)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--bs_eval', type=int, default=100)
    parser.add_argument('--l2p', type=float, default=0)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--lrdecay', type=float, default=0.9)
    parser.add_argument('--n_dev', type=int, default=1000)
    parser.add_argument('--trainset', type=str, default='../data/train.pkl')
    parser.add_argument('--testset', type=str, default='../data/test.pkl')
    parser.add_argument('--save_dir', type=str, default='../model/oj_lstm')
    parser.add_argument('--tokenizer', type=str,
                        default="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986")
    parser.add_argument('--word2vec', type=str, default="../data/w2v.model")
    # NOTE: argparse's type=bool treats any non-empty string (even "False")
    # as True, so expose --eval as a flag instead.
    parser.add_argument('--eval', action='store_true')
    
    opt = parser.parse_args()
    
    _eval = opt.eval
    
    _trainset_path = opt.trainset
    _testset_path = opt.testset
    _save = opt.save_dir
    if _eval:
        _save = os.path.join(_save, "model.pt")
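    else:
        # Create the checkpoint directory up front so per-epoch saves succeed.
        os.makedirs(_save, exist_ok=True)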
    _tokenizer_path = opt.tokenizer
    _w2v_path = opt.word2vec
    
    _drop = opt.dropout
    _lr = opt.lr
    _l2p = opt.l2p
    _lrdecay = opt.lrdecay
    
    _bs = opt.bs
    _bs_eval = opt.bs_eval
    _ep = opt.epoch
    _n_dev = opt.n_dev
    
    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        # Restrict visible GPUs before the CUDA context is created.
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

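    # Fixed model hyper-parameters; n_class = 104 matches the 104-class
    # OJ (POJ-104) dataset referenced by the default paths.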
    model = "LSTM"
    hidden_size = 600
    n_layers = 2
    n_class = 104
    max_len = 512
    attn = True
    bidirection = True

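    # Eval mode builds only the test loader; training mode splits the training
    # file into train/dev loaders, holding out n_dev examples for the dev set.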
    if _eval:
        test_set = ClassifierDataset(file_path=_testset_path,
                                     file_type='test',
                                     max_len=max_len)
        test_sampler = RandomSampler(test_set)
        test_dataloader = DataLoader(dataset=test_set,
                                     sampler=test_sampler,
                                     batch_size=_bs_eval,
                                     drop_last=False,
                                     collate_fn=classifier_collate_fn)
    else:
        training_set = ClassifierDataset(file_path=_trainset_path,
                                         file_type='train',
                                         n_dev=_n_dev,
                                         max_len=max_len)
        training_sampler = RandomSampler(training_set)
        training_dataloader = DataLoader(dataset=training_set,
                                         sampler=training_sampler,
                                         batch_size=_bs,
                                         drop_last=True,
                                         collate_fn=classifier_collate_fn)
        dev_set = ClassifierDataset(file_path=_trainset_path,
                                    file_type='dev',
                                    n_dev=_n_dev,
                                    max_len=max_len)
        dev_sampler = RandomSampler(dev_set)
        dev_dataloader = DataLoader(dataset=dev_set,
                                    sampler=dev_sampler,
                                    batch_size=_bs_eval,
                                    drop_last=True,
                                    collate_fn=classifier_collate_fn)

    classifier = RNNClassifier(num_class=n_class,
                               hidden_dim=hidden_size,
                               n_layers=n_layers,
                               tokenizer_path=_tokenizer_path,
                               w2v_path=_w2v_path,
                               max_len=max_len,
                               drop_prob=_drop,
                               model=model,
                               brnn=bidirection,
                               attn=attn,
                               device=device).to(device)
    
    if _eval:
        classifier.load_state_dict(torch.load(_save, map_location=device))
        classifier.eval()
        evaluate(classifier, test_dataloader, _bs_eval)
    else:
        classifier.train()
        optimizer = optim.Adam(classifier.parameters(), lr=_lr, weight_decay=_l2p) 
        criterion = nn.CrossEntropyLoss()
        
        trainEpochs(_ep, training_dataloader, dev_dataloader, classifier, criterion, optimizer,
                    saving_path=_save, batch_size=_bs, batch_size_eval=_bs_eval, lrdecay=_lrdecay)