# -*- coding: utf-8 -*-
"""
Created on Wed Jun 24 11:43:38 2020

@author: DrLC
"""

import torch
import torch.nn as nn
import numpy
from gensim.models import Word2Vec
from transformers import RobertaTokenizer


class RNNEncoder(nn.Module):
    """Recurrent encoder wrapping nn.LSTM / nn.GRU / nn.RNN.

    Expects input of shape [seq_len, batch, embedding_dim] (batch_first=False).
    """

    def __init__(self, embedding_dim, hidden_dim, n_layers,
                 drop_prob=0.5, model="LSTM", brnn=True):
        super(RNNEncoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.bidirectional = brnn
        if model.upper() == "LSTM":
            self.m = nn.LSTM(embedding_dim, hidden_dim, self.n_layers,
                             dropout=drop_prob, bidirectional=brnn)
        elif model.upper() == "GRU":
            self.m = nn.GRU(embedding_dim, hidden_dim, self.n_layers,
                            dropout=drop_prob, bidirectional=brnn)
        elif model.upper() in ["RNN_RELU", "RNN_TANH"]:
            if model.upper() == "RNN_RELU":
                self.m = nn.RNN(embedding_dim, hidden_dim, self.n_layers,
                                nonlinearity="relu", dropout=drop_prob,
                                bidirectional=brnn)
            else:
                self.m = nn.RNN(embedding_dim, hidden_dim, self.n_layers,
                                nonlinearity="tanh", dropout=drop_prob,
                                bidirectional=brnn)
        else:
            raise ValueError("Invalid RNN model: " + model.upper())

    def forward(self, input, hidden=None):
        return self.m(input, hidden)
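
# --- Hedged usage sketch (added for illustration; not called anywhere in this
# --- module). It only demonstrates the tensor shapes RNNEncoder expects: input is
# --- [seq_len, batch, embedding_dim] because the recurrent modules above use the
# --- default batch_first=False. All sizes below are arbitrary assumptions.
def _demo_rnn_encoder():
    enc = RNNEncoder(embedding_dim=8, hidden_dim=16, n_layers=2, model="GRU", brnn=True)
    x = torch.randn(5, 3, 8)         # [seq_len=5, batch=3, embedding_dim=8]
    out, h_n = enc(x)                # out: [5, 3, 32] (2 directions * hidden_dim=16)
    return out.shape, h_n.shape      # h_n: [n_layers * 2, 3, 16]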
class RNNClassifier(nn.Module):
    """RNN-based sequence classifier with Word2Vec embeddings and a RoBERTa tokenizer."""

    def __init__(self, num_class, hidden_dim, n_layers, tokenizer_path, w2v_path,
                 max_len=512, drop_prob=0.5, model="LSTM", brnn=True, attn=True,
                 device=None, verbose=False):
        super(RNNClassifier, self).__init__()
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_path)
        self.max_len = max_len
        # Build the vocabulary from the pre-trained Word2Vec model plus the
        # tokenizer's special tokens (w2v.wv.vocab is the gensim < 4.0 API).
        w2v = Word2Vec.load(w2v_path)
        self.embedding_size = w2v.wv.vectors.shape[1]
        self.idx2txt = list(w2v.wv.vocab.keys()) + self.tokenizer.all_special_tokens
        self.txt2idx = {self.idx2txt[i]: i for i in range(len(self.idx2txt))}
        assert self.txt2idx[self.idx2txt[0]] == 0
        assert self.txt2idx[self.idx2txt[len(self.idx2txt)-1]] == len(self.idx2txt)-1
        assert len(self.txt2idx) == len(self.idx2txt)
        # Special tokens get zero vectors appended after the Word2Vec embeddings.
        emb_matrix = numpy.concatenate([w2v.wv.vectors,
                                        numpy.zeros([len(self.tokenizer.all_special_tokens),
                                                     w2v.wv.vector_size])])
        emb_matrix = torch.FloatTensor(emb_matrix)
        self.embedding = nn.Embedding.from_pretrained(emb_matrix)
        self.encoder = RNNEncoder(embedding_dim=self.embedding_size,
                                  hidden_dim=hidden_dim,
                                  n_layers=n_layers,
                                  drop_prob=drop_prob,
                                  model=model,
                                  brnn=brnn)
        self.hidden_dim = hidden_dim * 2 if brnn else hidden_dim
        if attn:
            # Single-query attention over time steps; softmax runs along dim 0,
            # the sequence-length dimension.
            self.query = nn.Linear(self.hidden_dim, 1)
            self.attn_softmax = nn.Softmax(dim=0)
        else:
            self.query, self.attn_softmax = None, None
        self.n_channel = self.hidden_dim
        self.n_class = num_class
        self.classify = nn.Linear(self.n_channel, self.n_class)
        self.pred_softmax = nn.Softmax(dim=1)
        size = 0
        for p in self.parameters():
            size += p.nelement()
        if verbose:
            print('Total param size: {}'.format(size))
        if device is None:
            self.device = torch.device("cuda")
        else:
            self.device = device

    def get_hidden(self, inputs):
        emb = self.embedding(inputs)
        hidden_states, _ = self.encoder(emb)
        return hidden_states

    def get_mask(self, hidden_states, ls):
        # Boolean mask over time steps: True where t < ls[b]; returned shape [seq_len, batch].
        mask = (torch.arange(hidden_states.shape[0]).to(self.device)[None, :] < ls[:, None])
        return mask.permute([1, 0])

    def get_attention(self, hidden_states, mask):
        alpha_logits = self.query(hidden_states.reshape([-1, self.hidden_dim]))
        alpha_logits = alpha_logits.reshape(hidden_states.shape[:2])
        alpha_logits[~mask] = float("-inf")   # padded positions get zero attention weight
        alpha = self.attn_softmax(alpha_logits)
        return alpha

    def _forward(self, inputs, ls):
        hidden_states = self.get_hidden(inputs)
        mask = self.get_mask(hidden_states, ls)
        if self.query is not None:
            alpha = self.get_attention(hidden_states, mask)
        else:
            # Without attention, fall back to mean pooling over the valid positions.
            alpha = mask.float() / ls.reshape([hidden_states.shape[1], 1]).permute([1, 0]).float()
        # [l, bs, nch] => [bs, nch] => [bs, ncl]
        _alpha = torch.stack([alpha for _ in range(self.n_channel)], dim=2)
        logits = self.classify(torch.sum(hidden_states * _alpha, dim=0))
        return logits

    def tokenize(self, inputs, cut_and_pad=False, ret_id=False):
        rets = []
        lens = []
        if isinstance(inputs, str):
            inputs = [inputs]
        for sent in inputs:
            if cut_and_pad:
                tokens = self.tokenizer.tokenize(sent)[:self.max_len-2]
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
                lens.append(len(tokens))
                padding_length = self.max_len - len(tokens)
                tokens += [self.tokenizer.pad_token] * padding_length
            else:
                tokens = self.tokenizer.tokenize(sent)
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
                lens.append(len(tokens))
            if not ret_id:
                rets.append(tokens)
            else:
                # Map tokens to vocabulary ids; out-of-vocabulary tokens fall back
                # to the tokenizer's unk token.
                ids = []
                for t in tokens:
                    if t in self.txt2idx.keys():
                        ids.append(self.txt2idx[t])
                    else:
                        ids.append(self.txt2idx[self.tokenizer.unk_token])
                rets.append(ids)
        if not ret_id:
            return rets
        return rets, lens

    def _run_batch(self, batch, lens, eval_mode):
        # Trim the batch to its longest unpadded sequence before encoding.
        batch_max_length = lens.max().item()
        inputs = batch[:, :batch_max_length]
        inputs = inputs.to(self.device)
        lens = lens.to(self.device)
        inputs = inputs.permute([1, 0])   # [batch, seq_len] => [seq_len, batch]
        if eval_mode:
            self.eval()
            with torch.no_grad():
                logits = self._forward(inputs, lens)
        else:
            self.train()
            logits = self._forward(inputs, lens)
        return logits

    def forward(self, inputs, bs=16, eval_mode=True):
        logits = self.run(inputs, bs, eval_mode)
        prob = self.pred_softmax(logits)
        return prob

    def run(self, inputs, batch_size=16, eval_mode=True):
        input_ids, input_lens = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        outputs = None
        batch_num = (len(input_ids) - 1) // batch_size + 1
        for step in range(batch_num):
            batch = torch.tensor(input_ids[step*batch_size: (step+1)*batch_size])
            lens = torch.tensor(input_lens[step*batch_size: (step+1)*batch_size])
            if outputs is None:
                outputs = self._run_batch(batch, lens, eval_mode)
            else:
                outputs = torch.cat((outputs, self._run_batch(batch, lens, eval_mode)), 0)
        return outputs
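
# --- Hedged illustration (added; not used by the model). A tiny, self-contained
# --- re-enactment of the masked attention pooling performed in get_attention /
# --- _forward above: padded time steps are set to -inf before the softmax over the
# --- sequence dimension, so they get zero weight in the weighted sum. Broadcasting
# --- via unsqueeze replaces the torch.stack used above; all sizes are arbitrary.
def _demo_masked_attention_pooling():
    seq_len, batch, n_channel = 4, 2, 3
    hidden_states = torch.randn(seq_len, batch, n_channel)
    lengths = torch.tensor([4, 2])                        # second sequence has 2 pad steps
    mask = (torch.arange(seq_len)[None, :] < lengths[:, None]).permute([1, 0])  # [seq_len, batch]
    alpha_logits = torch.randn(seq_len, batch)            # stand-in for self.query(...)
    alpha_logits[~mask] = float("-inf")
    alpha = torch.softmax(alpha_logits, dim=0)            # weights sum to 1 over valid steps
    pooled = torch.sum(hidden_states * alpha.unsqueeze(2), dim=0)   # [batch, n_channel]
    return pooled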
class RNNClone(RNNClassifier):
    """Siamese variant for clone detection: sequences are encoded in interleaved
    pairs and classified from the absolute difference of the two pooled encodings."""

    def __init__(self, num_class, hidden_dim, n_layers, tokenizer_path, w2v_path,
                 max_len=512, drop_prob=0.5, model="LSTM", brnn=True, attn=True,
                 device=None, verbose=False):
        super(RNNClone, self).__init__(num_class, hidden_dim, n_layers, tokenizer_path,
                                       w2v_path, max_len, drop_prob, model, brnn, attn,
                                       device, verbose)
        # Defined but not used in _forward below.
        self.linear_merge = nn.Linear(self.n_channel * 2, self.n_channel)

    def _forward(self, inputs, ls):
        hidden_states = self.get_hidden(inputs)
        mask = self.get_mask(hidden_states, ls)
        if self.query is not None:
            alpha = self.get_attention(hidden_states, mask)
        else:
            alpha = mask.float() / ls.reshape([hidden_states.shape[1], 1]).permute([1, 0]).float()
        # [l, 2*bs, nch] => [bs, 2*nch] => [bs, nch] * 2 => [bs, ncl]
        _alpha = torch.stack([alpha for _ in range(self.n_channel)], dim=2)
        encoding_ = torch.sum(hidden_states * _alpha, dim=0).reshape([-1, self.n_channel * 2])
        encoding1 = encoding_[:, : encoding_.shape[1]//2]
        encoding2 = encoding_[:, encoding_.shape[1]//2 :]
        assert encoding1.shape == encoding2.shape
        encoding = torch.abs(encoding1 - encoding2)
        logits = self.classify(encoding)
        return logits


if __name__ == "__main__":

    vocab_size = 300
    embed_size = 256
    hidden_size = 128
    n_layer = 2
    n_class = 104
    drop_prob = 0.5
    model = "LSTM"
    bidirection = True
    w2v_path = "../data/w2v.model"
    tokenizer_path = "/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986"
    device = torch.device("cpu")

    model = RNNClassifier(num_class=n_class,
                          hidden_dim=hidden_size,
                          n_layers=n_layer,
                          tokenizer_path=tokenizer_path,
                          w2v_path=w2v_path,
                          max_len=512,
                          drop_prob=drop_prob,
                          model=model,
                          brnn=bidirection,
                          attn=False,
                          device=device)

    # Four sample programs repeated 7 times (28 inputs), so model.run has to
    # split the work across multiple batches.
    inputs = ["int main ( ) { int n , i ; n = 1 ; return 0; }",
              "int main ( ) { int <mask>, i ; <mask> = 1 ; return 0 }",
              "void main ( ) { double x ; }",
              "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }"] * 7

    tokens1 = model.tokenize(inputs)
    for s in tokens1:
        for t in s:
            print(t, end=" ")
        print()
    print()
    tokens2, _ = model.tokenize(inputs, False, True)
    for s in tokens2:
        for t in s:
            print(model.idx2txt[t], end=" ")
        print()
    print()
    # Token and id sequences should line up one-to-one.
    assert len(tokens1) == len(tokens2)
    assert all(len(s1) == len(s2) for s1, s2 in zip(tokens1, tokens2))

    l1 = model.run(inputs, 16, True)
    print(l1.shape)
    print(l1)
    l2 = model.run(inputs, 32, True)
    print(l2.shape)
    print(l2)
    l3 = model.run(inputs, 16, False)
    print(l3.shape)
    print(l3)
    l4 = model.run(inputs, 32, False)
    print(l4.shape)
    print(l4)
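
    # --- Added hedged sketch: exercise the RNNClone head as well. This is an
    # --- illustrative assumption about usage, not part of the original test: based on
    # --- the reshape in RNNClone._forward, the inputs must be an interleaved list of
    # --- pairs (a1, b1, a2, b2, ...), and num_class=2 is a guess for a clone /
    # --- non-clone decision. It reuses the tokenizer_path / w2v_path from above.
    clone_model = RNNClone(num_class=2,
                           hidden_dim=hidden_size,
                           n_layers=n_layer,
                           tokenizer_path=tokenizer_path,
                           w2v_path=w2v_path,
                           max_len=512,
                           drop_prob=drop_prob,
                           model="LSTM",
                           brnn=bidirection,
                           attn=False,
                           device=device)
    clone_inputs = inputs[:4]   # two interleaved pairs => two clone predictions
    clone_probs = clone_model(clone_inputs, bs=4)
    print(clone_probs.shape)    # expected: [2, 2]
    print(clone_probs)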