# CodeBERT-Attack / rnn / rnn.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 24 11:43:38 2020

@author: DrLC
"""

import torch
import torch.nn as nn
import numpy
from gensim.models import Word2Vec
from transformers import RobertaTokenizer

class RNNEncoder(nn.Module):
    """Wrap one of PyTorch's recurrent layers, selected by name.

    Args:
        embedding_dim: Size of each input feature vector.
        hidden_dim: Hidden size of the recurrent layer (per direction).
        n_layers: Number of stacked recurrent layers.
        drop_prob: Dropout probability forwarded to the recurrent layer.
        model: One of "LSTM", "GRU", "RNN_RELU" or "RNN_TANH"
            (case-insensitive).
        brnn: Whether the recurrent layer is bidirectional.

    Raises:
        AssertionError: If ``model`` is not one of the supported names.
    """

    def __init__(self, embedding_dim, hidden_dim, n_layers,
                 drop_prob=0.5, model="LSTM", brnn=True):

        super(RNNEncoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.bidirectional = brnn

        kind = model.upper()
        shared = dict(dropout=drop_prob, bidirectional=brnn)
        if kind == "LSTM":
            self.m = nn.LSTM(embedding_dim, hidden_dim, self.n_layers, **shared)
        elif kind == "GRU":
            self.m = nn.GRU(embedding_dim, hidden_dim, self.n_layers, **shared)
        elif kind in ("RNN_RELU", "RNN_TANH"):
            nonlinearity = "relu" if kind == "RNN_RELU" else "tanh"
            self.m = nn.RNN(embedding_dim, hidden_dim, self.n_layers,
                            nonlinearity=nonlinearity, **shared)
        else:
            # Raised explicitly rather than via `assert False`: an assert is
            # stripped under `python -O`, which would leave `self.m` undefined
            # and defer the failure to the first forward pass.
            raise AssertionError("Invalid " + kind)

    def forward(self, input, hidden=None):
        """Run the wrapped recurrent layer and return its result unchanged.

        Args:
            input: Input tensor passed straight to the recurrent layer.
            hidden: Optional initial hidden state; ``None`` lets the layer
                use its own default.

        Returns:
            Whatever the wrapped layer returns (an ``(output, state)`` pair).
        """
        return self.m(input, hidden)


class RNNClassifier(nn.Module):
    """Text classifier: word2vec embeddings -> RNN encoder -> pooling -> linear.

    Raw strings are tokenized with a RoBERTa tokenizer, mapped to rows of an
    embedding matrix built from a gensim Word2Vec model (plus one all-zero row
    per tokenizer special token), encoded by an ``RNNEncoder`` and pooled over
    time either with a learned attention or with a length-normalised mean.

    Internally tensors are sequence-first: ``[length, batch, ...]``.

    Args:
        num_class: Number of output classes.
        hidden_dim: Hidden size of the encoder (per direction).
        n_layers: Number of stacked recurrent layers.
        tokenizer_path: Path/name given to ``RobertaTokenizer.from_pretrained``.
        w2v_path: Path given to ``Word2Vec.load``.
        max_len: Fixed token length (including the two boundary tokens) used
            when cutting and padding.
        drop_prob: Dropout probability of the encoder.
        model: Recurrent layer type understood by ``RNNEncoder``.
        brnn: Whether the encoder is bidirectional.
        attn: Use attention pooling if true, mean pooling otherwise.
        device: Device that batches are moved to; defaults to CUDA when
            available, otherwise CPU.
        verbose: Print the total parameter count after construction.
    """

    def __init__(self, num_class, hidden_dim, n_layers, tokenizer_path, w2v_path,
                 max_len=512, drop_prob=0.5, model="LSTM", brnn=True, attn=True,
                 device=None, verbose=False):

        super(RNNClassifier, self).__init__()
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_path)
        self.max_len = max_len
        w2v = Word2Vec.load(w2v_path)
        self.embedding_size = w2v.wv.vectors.shape[1]

        # The vocabulary must be listed in *index* order so that idx2txt[i]
        # names row i of `wv.vectors`.  `wv.vocab.keys()` (gensim 3.x) is in
        # dict insertion order, which is not guaranteed to match the row
        # order, and the attribute no longer exists in gensim >= 4.
        # NOTE(review): this changes token ids relative to the old
        # `vocab.keys()` ordering; checkpoints trained with the old mapping
        # should be re-validated.
        vocab = getattr(w2v.wv, "index_to_key", None)    # gensim >= 4
        if vocab is None:
            vocab = w2v.wv.index2word                    # gensim 3.x
        special_tokens = list(self.tokenizer.all_special_tokens)
        self.idx2txt = list(vocab) + special_tokens
        self.txt2idx = {tok: i for i, tok in enumerate(self.idx2txt)}
        # A token appearing twice (e.g. a special token that is also in the
        # word2vec vocabulary) would make the two tables inconsistent.
        # Raised explicitly so the check survives `python -O`.
        if len(self.txt2idx) != len(self.idx2txt):
            raise AssertionError("duplicate tokens in the combined vocabulary")

        # Special tokens get all-zero embedding rows appended after the
        # word2vec vectors.
        emb_matrix = numpy.concatenate([w2v.wv.vectors,
                                        numpy.zeros([len(special_tokens),
                                                     w2v.wv.vector_size])])
        emb_matrix = torch.FloatTensor(emb_matrix)
        # from_pretrained freezes the weights by default.
        self.embedding = nn.Embedding.from_pretrained(emb_matrix)
        self.encoder = RNNEncoder(embedding_dim=self.embedding_size,
                                  hidden_dim=hidden_dim,
                                  n_layers=n_layers,
                                  drop_prob=drop_prob,
                                  model=model,
                                  brnn=brnn)
        # Width of the encoder output: both directions are concatenated.
        self.hidden_dim = hidden_dim * 2 if brnn else hidden_dim
        if attn:
            self.query = nn.Linear(self.hidden_dim, 1)
            # Normalise over the time axis (dim 0 of [length, batch]).
            self.attn_softmax = nn.Softmax(dim=0)
        else:
            self.query, self.attn_softmax = None, None
        self.n_channel = self.hidden_dim
        self.n_class = num_class
        self.classify = nn.Linear(self.n_channel, self.n_class)
        self.pred_softmax = nn.Softmax(dim=1)
        if verbose:
            size = sum(p.nelement() for p in self.parameters())
            print('Total param size: {}'.format(size))
        if device is None:
            # Fall back to CPU instead of failing on machines without CUDA.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

    def get_hidden(self, inputs):
        """Embed token ids and return the encoder's per-step outputs."""

        emb = self.embedding(inputs)
        hidden_states, _ = self.encoder(emb)
        return hidden_states

    def get_mask(self, hidden_states, ls):
        """Return a boolean ``[length, batch]`` mask of non-padding steps.

        Entry ``[t, b]`` is true iff ``t < ls[b]``.
        """

        steps = torch.arange(hidden_states.shape[0], device=ls.device)
        return steps[:, None] < ls[None, :]

    def get_attention(self, hidden_states, mask):
        """Return attention weights of shape ``[length, batch]``.

        Masked-out steps receive weight 0; the weights of each sequence sum
        to 1 over the time axis.
        """

        alpha_logits = self.query(hidden_states.reshape([-1, self.hidden_dim]))
        alpha_logits = alpha_logits.reshape(hidden_states.shape[:2])
        # Out-of-place fill keeps the linear layer's output untouched.
        alpha_logits = alpha_logits.masked_fill(~mask, float("-inf"))
        return self.attn_softmax(alpha_logits)

    def _forward(self, inputs, ls):
        """Compute class logits for a ``[length, batch]`` tensor of token ids.

        Args:
            inputs: Token ids, sequence-first.
            ls: Per-sequence number of valid (non-padding) tokens.

        Returns:
            Logits of shape ``[batch, n_class]``.
        """

        hidden_states = self.get_hidden(inputs)
        mask = self.get_mask(hidden_states, ls)
        if self.query is not None:
            alpha = self.get_attention(hidden_states, mask)
        else:
            # Mean pooling over the valid steps only.
            alpha = mask.float() / ls.reshape(1, -1).float()
        # [l, bs, nch] => [bs, nch] => [bs, ncl]
        # Broadcasting over the channel axis avoids materialising n_channel
        # copies of the weights.
        pooled = torch.sum(hidden_states * alpha.unsqueeze(2), dim=0)
        return self.classify(pooled)

    def tokenize(self, inputs, cut_and_pad=False, ret_id=False):
        """Tokenize one string or a list of strings.

        Every sequence is wrapped in the tokenizer's cls/sep tokens.

        Args:
            inputs: A string or an iterable of strings.
            cut_and_pad: Truncate to ``max_len - 2`` tokens before wrapping
                and pad with the pad token up to ``max_len``.
            ret_id: Return vocabulary ids instead of token strings; tokens
                missing from the vocabulary map to the unk token's id.

        Returns:
            The list of token lists if ``ret_id`` is false; otherwise a pair
            ``(ids, lens)`` where ``lens`` holds each sequence's length
            before padding.
        """

        if isinstance(inputs, str):
            inputs = [inputs]
        cls_token = self.tokenizer.cls_token
        sep_token = self.tokenizer.sep_token
        unk_id = self.txt2idx[self.tokenizer.unk_token] if ret_id else None

        rets = []
        lens = []
        for sent in inputs:
            tokens = self.tokenizer.tokenize(sent)
            if cut_and_pad:
                # Leave room for the two boundary tokens.
                tokens = tokens[:self.max_len - 2]
            tokens = [cls_token] + tokens + [sep_token]
            lens.append(len(tokens))
            if cut_and_pad:
                tokens += [self.tokenizer.pad_token] * (self.max_len - len(tokens))
            if ret_id:
                rets.append([self.txt2idx.get(t, unk_id) for t in tokens])
            else:
                rets.append(tokens)
        if not ret_id:
            return rets
        return rets, lens

    def _run_batch(self, batch, lens, eval_mode):
        """Run one padded ``[batch, max_len]`` id tensor through the model.

        The batch is trimmed to its longest sequence, moved to
        ``self.device`` and transposed to sequence-first.  The module is
        switched to eval mode (with gradients disabled) or to train mode
        according to ``eval_mode``; the mode is left as set afterwards.
        """

        batch_max_length = lens.max().item()
        inputs = batch[:, :batch_max_length]
        inputs = inputs.to(self.device)
        lens = lens.to(self.device)
        inputs = inputs.permute([1, 0])
        if eval_mode:
            self.eval()
            with torch.no_grad():
                logits = self._forward(inputs, lens)
        else:
            self.train()
            logits = self._forward(inputs, lens)
        return logits

    def forward(self, inputs, bs=16, eval_mode=True):
        """Return class probabilities for raw input strings."""

        logits = self.run(inputs, bs, eval_mode)
        prob = self.pred_softmax(logits)
        return prob

    def run(self, inputs, batch_size=16, eval_mode=True):
        """Return class logits for raw input strings, processed in batches.

        Returns:
            A tensor with one row of logits per processed batch entry, or
            ``None`` when ``inputs`` is empty.
        """

        input_ids, input_lens = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        chunks = []
        for start in range(0, len(input_ids), batch_size):
            stop = start + batch_size
            batch = torch.tensor(input_ids[start:stop])
            lens = torch.tensor(input_lens[start:stop])
            chunks.append(self._run_batch(batch, lens, eval_mode))
        if not chunks:
            return None
        # Concatenate once instead of re-copying the accumulator every step.
        return torch.cat(chunks, 0)
    
class RNNClone(RNNClassifier):
    """Pairwise variant of ``RNNClassifier``.

    Consecutive sequences in a batch (rows ``2i`` and ``2i + 1``) are treated
    as one pair: each is encoded and pooled exactly as in the parent class,
    and the element-wise absolute difference of the two pooled vectors is
    classified.  The batch therefore has to contain an even number of
    sequences.
    """

    def __init__(self, num_class, hidden_dim, n_layers, tokenizer_path, w2v_path,
                 max_len=512, drop_prob=0.5, model="LSTM", brnn=True, attn=True,
                 device=None, verbose=False):

        super().__init__(num_class=num_class,
                         hidden_dim=hidden_dim,
                         n_layers=n_layers,
                         tokenizer_path=tokenizer_path,
                         w2v_path=w2v_path,
                         max_len=max_len,
                         drop_prob=drop_prob,
                         model=model,
                         brnn=brnn,
                         attn=attn,
                         device=device,
                         verbose=verbose)
        # Registered as a sub-module but not referenced by `_forward` below.
        self.linear_merge = nn.Linear(self.n_channel * 2, self.n_channel)

    def _forward(self, inputs, ls):
        """Return one row of logits per *pair* of input sequences.

        Args:
            inputs: Sequence-first token ids holding ``2 * bs`` sequences.
            ls: Number of valid tokens of every sequence.
        """

        states = self.get_hidden(inputs)
        valid = self.get_mask(states, ls)
        if self.query is None:
            # Length-normalised mean over the valid steps.
            seq_lens = ls.reshape([states.shape[1], 1]).permute([1, 0])
            weights = valid.float() / seq_lens.float()
        else:
            weights = self.get_attention(states, valid)
        # [l, 2*bs, nch] => [bs, 2*nch] => [bs, nch] * 2 => [bs, ncl]
        pooled = (states * weights.unsqueeze(2)).sum(dim=0)
        paired = pooled.reshape([-1, self.n_channel * 2])
        half = paired.shape[1] // 2
        left, right = paired[:, :half], paired[:, half:]
        assert left.shape == right.shape
        return self.classify((left - right).abs())
    
if __name__ == "__main__":

    # Manual smoke test: build an RNNClone with mean pooling on CPU, dump the
    # tokenization of a few C snippets, then print logits for two batch sizes
    # in both eval and train mode.
    vocab_size = 300    # not referenced below
    embed_size = 256    # not referenced below
    hidden_size, n_layer, n_class = 128, 2, 2
    drop_prob = 0.5
    rnn_type = "LSTM"
    bidirection = True

    w2v_path = "../data/bcb_w2v.model"
    tokenizer_path = "/var/data/zhanghz/codebert-base-mlm"

    device = torch.device("cpu")

    model = RNNClone(num_class=n_class,
                     hidden_dim=hidden_size,
                     n_layers=n_layer,
                     tokenizer_path=tokenizer_path,
                     w2v_path=w2v_path,
                     max_len=512,
                     drop_prob=drop_prob,
                     model=rnn_type,
                     brnn=bidirection,
                     attn=False,
                     device=device)

    # Four snippets repeated five times: 20 inputs, i.e. 10 pairs.
    snippets = [
        "int main ( ) { int n , i ; n = 1 ; return 0; }",
        "int main ( ) { int <mask>, i ; <mask> = 1 ; return 0 ; }",
        "void main ( ) { double x ; }",
        "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }",
    ]
    inputs = snippets * 5

    # Token strings, one input per line, every token followed by a space.
    for sent_tokens in model.tokenize(inputs):
        print("".join("{} ".format(tok) for tok in sent_tokens))
    print()

    # Same dump after a round trip through the vocabulary ids.
    id_rows, _ = model.tokenize(inputs, False, True)
    for row in id_rows:
        print("".join("{} ".format(model.idx2txt[idx]) for idx in row))
    print()

    for batch_size, eval_mode in ((16, True), (32, True), (16, False), (32, False)):
        logits = model.run(inputs, batch_size, eval_mode)
        print(logits.shape)
        print(logits)