# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 15:03:49 2020

@author: DrLC
"""

import argparse
import os
from dataset import CloneDataset, clone_collate_fn

import tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from sklearn.metrics import f1_score
    
def trainEpochs(epochs, training_dataloader, dev_dataloader, classifier,
                criterion, optimizer,
                batch_size=32, batch_size_eval=64, print_each=1000,
                saving_path='./', lrdecay=1):
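    """Train `classifier` for `epochs` epochs, evaluating on the dev set and
    checkpointing to `saving_path/<epoch>.pt` after each epoch.

    If `lrdecay` < 1, the learning rate is multiplied by `lrdecay` once per
    epoch (see adjust_learning_rate).
    """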
    
    n_batch = len(training_dataloader)
    
    for _ep in range(1, 1+epochs):
        classifier.train()
        print_loss_total = 0
        print('Start epoch ' + str(_ep) + '....')
        for _iter, batch in enumerate(training_dataloader):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            optimizer.zero_grad()
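            # RNNClone.run (defined in rnn.py) is assumed to return pair logits;
            # the trailing boolean presumably toggles eval-time behavior inside
            # the model (False here, True in evaluate() below).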
            logits = classifier.run(inputs, batch_size, False)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            print_loss_total += loss.item()
            if (_iter + 1) % print_each == 0: 
                print_loss_avg = print_loss_total / print_each
                print_loss_total = 0
                print('\tEp %d %d/%d, loss = %.6f' \
                      % (_ep, (_iter + 1), n_batch, print_loss_avg), flush=True)
        
        print("  eval...")
        evaluate(classifier, dev_dataloader, batch_size_eval)
        torch.save(classifier.state_dict(),
                   os.path.join(saving_path, str(_ep)+'.pt'))
            
        if lrdecay < 1:
            adjust_learning_rate(optimizer, lrdecay)
 

def adjust_learning_rate(optimizer, decay_rate=0.8):
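    """Multiply the learning rate of every parameter group by `decay_rate`."""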
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_rate
                
def evaluate(classifier, dataloader, batch_size=128):
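    """Report accuracy and binary F1 of `classifier` on `dataloader`."""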
    
    classifier.eval()
    testnum = 0
    testcorrect = 0
    
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for _iter, batch in enumerate(tqdm.tqdm(dataloader)):
            inputs, labels = batch
            labels = labels.to(classifier.device)
            outputs = classifier(inputs, batch_size, True)
            preds = torch.argmax(outputs, dim=1)
            res = preds == labels
            testcorrect += torch.sum(res).item()
            testnum += len(labels)
            y_true += list(labels.cpu().numpy())
            y_pred += list(preds.cpu().numpy())

    print('Eval_acc:  %.2f%%' % (float(testcorrect) * 100.0 / testnum))
    print('Eval_f1:   %.2f%%' % (f1_score(y_true, y_pred, average='binary') * 100), flush=True)
    
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--bs', type=int, default=128)
    parser.add_argument('--bs_eval', type=int, default=256)
    parser.add_argument('--l2p', type=float, default=0)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--lrdecay', type=float, default=0.99)
    parser.add_argument('--data_path', type=str, default="/var/data/lushuai/bertvsbert/data/ojclone/ojclone_norm.jsonl")
    parser.add_argument('--trainset', type=str, default="/var/data/lushuai/bertvsbert/data/ojclone/train.txt")
    parser.add_argument('--validset', type=str, default="/var/data/lushuai/bertvsbert/data/ojclone/valid.txt")
    parser.add_argument('--testset', type=str, default="/var/data/lushuai/bertvsbert/data/ojclone/test.txt")
    parser.add_argument('--save_dir', type=str, default='../model/ojclone_lstm')
    parser.add_argument('--tokenizer', type=str,
                        default="/var/data/zhanghz/codebert-base-mlm")
    parser.add_argument('--word2vec', type=str, default="../data/ojclone_w2v.model")
    parser.add_argument('--valid_ratio', type=float, default=0.2)
    parser.add_argument('--eval', action='store_true')
    
    opt = parser.parse_args()
    
    _eval = opt.eval
    
    _src_path = opt.data_path
    _trainset_path = opt.trainset
    _validset_path = opt.validset
    _valid_ratio = opt.valid_ratio
    _testset_path = opt.testset
    _save = opt.save_dir
    if _eval:
        _save = os.path.join(_save, "model.pt")
    else:
        os.makedirs(_save, exist_ok=True)
    _tokenizer_path = opt.tokenizer
    _w2v_path = opt.word2vec
    
    _drop = opt.dropout
    _lr = opt.lr
    _l2p = opt.l2p
    _lrdecay = opt.lrdecay
    
    _bs = opt.bs
    _bs_eval = opt.bs_eval
    _ep = opt.epoch
    
    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        # set CUDA_VISIBLE_DEVICES before any CUDA work so the chosen GPU is used
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    from rnn import RNNClone

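    # Fixed hyper-parameters for the RNN clone classifier
    # (a 2-layer bidirectional LSTM with attention, binary output).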
    model = "LSTM"
    hidden_size = 600
    n_layers = 2
    n_class = 2
    max_len = 400
    attn = True
    bidirection = True

    if _eval:
        test_set = CloneDataset(src_path=_src_path,
                                file_path=_testset_path,
                                file_type='test',
                                max_len=max_len,
                                keep_ratio=_valid_ratio)
        test_sampler = RandomSampler(test_set)
        test_dataloader = DataLoader(dataset=test_set,
                                     sampler=test_sampler,
                                     batch_size=_bs_eval,
                                     drop_last=False,
                                     collate_fn=clone_collate_fn)
    else:
        training_set = CloneDataset(src_path=_src_path,
                                    file_path=_trainset_path,
                                    file_type='train',
                                    max_len=max_len)
        training_sampler = RandomSampler(training_set)
        training_dataloader = DataLoader(dataset=training_set,
                                         sampler=training_sampler,
                                         batch_size=_bs,
                                         drop_last=True,
                                         collate_fn=clone_collate_fn)
        dev_set = CloneDataset(src_path=_src_path,
                               file_path=_validset_path,
                               file_type='valid',
                               max_len=max_len,
                               keep_ratio=_valid_ratio)
        dev_sampler = RandomSampler(dev_set)
        dev_dataloader = DataLoader(dataset=dev_set,
                                    sampler=dev_sampler,
                                    batch_size=_bs_eval,
                                    drop_last=True,
                                    collate_fn=clone_collate_fn)

    classifier = RNNClone(num_class=n_class,
                          hidden_dim=hidden_size,
                          n_layers=n_layers,
                          tokenizer_path=_tokenizer_path,
                          w2v_path=_w2v_path,
                          max_len=max_len,
                          drop_prob=_drop,
                          model=model,
                          brnn=bidirection,
                          attn=attn,
                          device=device).to(device)
    
    if _eval:
        classifier.load_state_dict(torch.load(_save, map_location=device))
        classifier.eval()
        evaluate(classifier, test_dataloader, _bs_eval)
    else:
        classifier.train()
        optimizer = optim.Adam(classifier.parameters(), lr=_lr, weight_decay=_l2p) 
        criterion = nn.CrossEntropyLoss()

        trainEpochs(_ep, training_dataloader, dev_dataloader, classifier, criterion, optimizer,
                    saving_path=_save, batch_size=_bs, batch_size_eval=_bs_eval, lrdecay=_lrdecay)