# -*- coding: utf-8 -*-
"""
Created on Mon Oct 19 11:27:11 2020

@author: DrLC
"""

from utils import align_subtokens, is_uid, is_special_id, is_java_uid, is_java_special_id

import torch
import copy
from torch.nn import CrossEntropyLoss


class CodeBERT_Attack_UID(object):

    def __init__(self, lang="c"):
        lang = lang.upper()
        if lang == "C":
            self.is_uid = is_uid
            self.is_special_id = is_special_id
        elif lang == 'JAVA':
            self.is_uid = is_java_uid
            self.is_special_id = is_java_special_id
        else:
            assert False

    def find_vulnerable_uids(self, cls, uid, ground_truth_label, other_uid=None, n_vul=5):
        masked_code = uid.gen_mask()
        masked_uids = list(masked_code.keys())
        if other_uid is None:
            batch_ = [masked_code[i] for i in masked_uids] + [uid.code]  # The last one is the original
        else:
            batch_ = []
            for i in masked_uids:
                batch_.append(masked_code[i])
                batch_.append(other_uid.code)
            batch_.append(uid.code)  # The last pair is the original
            batch_.append(other_uid.code)
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        score = logits[-1] - logits[:-1]  # Importance score of each UID
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        ret = {}
        for i in idxs:
            if score[i] > 0:
                if other_uid is not None:
                    ret[masked_uids[i]] = batch_[2 * i]
                else:
                    ret[masked_uids[i]] = batch_[i]
        return ret

    def generate_candidates(self, uid, mlm, vulnerable, idx2txt, bpe_indicator='Ġ',
                            n_candidate=10, smoothing=0.1, batch_size=32, criterion=None,
                            len_threshold=512, max_computational=1e8, other_uid=None):
        assert smoothing < 1
        specials = mlm.tokenizer.all_special_tokens
        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device
        # 1. Generate all candidate subtokens
        subtokens = mlm.tokenize(" ".join(uid.code), False, False)[0]
        mlm_logits = mlm.run(" ".join(uid.code))[0]
        align_sub2token, align_token2sub = align_subtokens(uid.code, subtokens, bpe_indicator, specials)
        all_candidate_subtoken_logits, all_candidate_subtoken_ids = mlm_logits.topk(n_candidate, dim=-1)
        # 2. Merge the subtokens to create the candidate tokens at each position
        candidates = []
        computational = 1
        # If the subtoken sequence is too long, use the simplified version,
        # considering only the very first positions.
        for i in range(len(uid.code)):
            if vulnerable == uid.code[i].strip():
                new_candidates = []
                if n_candidate ** len(align_token2sub[i]) >= max_computational:
                    break
                for j in align_token2sub[i]:
                    if j >= len_threshold:
                        break
                    if len(new_candidates) <= 0:
                        for idx, logit in zip(all_candidate_subtoken_ids[j], all_candidate_subtoken_logits[j]):
                            if idx2txt[idx][0] == bpe_indicator:
                                new_candidates.append((idx2txt[idx][1:], logit.item()))
                    else:
                        new_candidates_tmp = new_candidates
                        new_candidates = []
                        for idx, logit in zip(all_candidate_subtoken_ids[j], all_candidate_subtoken_logits[j]):
                            if idx2txt[idx][0] != bpe_indicator and idx2txt[idx] not in specials:
                                for _c, _l in new_candidates_tmp:
                                    new_candidates.append((_c + idx2txt[idx], logit.item() + _l))
                if len(new_candidates) <= 0:
                    continue
                computational *= len(new_candidates)
                new_candidates = {it[0]: it[1] for it in new_candidates}
                new_candidates[None] = min(new_candidates.values()) * smoothing
                candidates.append(new_candidates)
                if computational * len(new_candidates) >= max_computational:
                    break
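        # A hedged, illustrative example of the merge above (names and scores
        # are made up, not taken from a real vocabulary): if the vulnerable
        # token aligns to two subtoken slots and the MLM proposes
        # ("Ġbuf", -1.2) for the word-start slot and ("fer", -0.8) for the
        # continuation slot, the merged candidate is ("buffer", -2.0), i.e.
        # subtoken log-scores are summed. First-slot candidates must begin
        # with the BPE boundary indicator; continuation candidates must not.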
        # 3. Merge all tokens to create the candidate tokens over the whole snippet
        merged_candidates = {}
        if len(candidates) <= 0:
            return None, None
        for c in candidates[0].keys():
            if c is None:
                merged_candidates[None] = candidates[0][None]
            else:
                merged_candidates[c] = candidates[0][c]
        for _candidates in candidates[1:]:
            for c in _candidates.keys():
                None_logit = merged_candidates[None]
                if c is None:
                    merged_candidates[None] = _candidates[None]
                elif c in merged_candidates.keys():
                    merged_candidates[c] += _candidates[c]
                else:
                    merged_candidates[c] = None_logit + _candidates[c]
        merged_candidates.pop(None)
        merged_candidates = sorted(merged_candidates.items(), key=lambda it: it[1], reverse=True)
        # 4. Filter out the non-UID tokens
        candidates = []
        for _c, _ in merged_candidates[:batch_size]:
            c = _c.strip()
            if self.is_uid(c) and (not self.is_special_id(c)) and (c not in uid.sym2pos.keys()):
                candidates.append(c)
        if len(candidates) <= 0:
            return None, None
        # 5. Select top-k according to PPL by MLM
        targets = mlm.tokenize(candidates, cut_and_pad=True, ret_id=True)
        logits = mlm.run(candidates, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))
        ret = [candidates[i] for i in selected_idx]
        seq = []
        for i in ret:
            seq.append(copy.deepcopy(uid.code))
            for j in uid.sym2pos[vulnerable]:
                seq[-1][j] = i
            seq[-1] = " ".join(seq[-1])
        if other_uid is not None:
            seq.append(" ".join(other_uid.code))
        return ret, seq


class CodeBERT_Attack_CMP(object):

    def __init__(self):
        pass

    def find_vulnerable_cmps(self, cls, cmp, ground_truth_label, n_vul=10):
        masked_code = cmp.gen_mask()
        batch_ = masked_code + [cmp.code]  # The last one is the original
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        score = logits[-1] - logits[:-1]  # Importance score of each comparison
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        ret = []
        for i in idxs:
            if score[i] > 0:
                ret.append(int(i))
        return ret

    def generate_candidates(self, cmp, mlm, vulnerable, batch_size=32, n_candidate=10, criterion=None):
        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device
        # 1. Generate all candidate swapped sequences
        candidates = []
        seqs = []
        for idx in vulnerable:
            cmp.update_cmp(idx)
            candidates.append(" ".join(cmp.code))
            seqs.append(" ".join(cmp.code[cmp.cmptab[idx][0]: cmp.cmptab[idx][2]]))
            cmp.update_cmp(idx)
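        # Note: update_cmp(idx) is deliberately called twice per index. Under
        # the assumption that it toggles the comparison in place (e.g.
        # "a < b" -> "b > a"), the first call produces the swapped candidate
        # and the second call restores cmp.code for the next iteration.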
        # 2. Filter according to PPL by MLM
        targets = mlm.tokenize(candidates, cut_and_pad=True, ret_id=True)
        logits = mlm.run(candidates, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))
        ret = [vulnerable[int(i)] for i in selected_idx]
        seq = [seqs[int(i)] for i in selected_idx]
        cand = [candidates[int(i)] for i in selected_idx]
        return ret, seq, cand


class CodeBERT_Attack_WS(object):

    def __init__(self, ws=[" ", "\n", "\t"]):
        self.ws = copy.deepcopy(ws)

    def find_vulnerable_tokens(self, cls, ws, ground_truth_label, other_ws=None, n_vul=5, n_mask=15):
        masked_code = ws.gen_mask(n_mask)
        masked_idxs = list(masked_code.keys())
        if other_ws is None:
            batch_ = [masked_code[i] for i in masked_idxs] + [ws.code]  # The last one is the original
        else:
            batch_ = []
            for i in masked_idxs:
                batch_.append(masked_code[i])
                batch_.append(other_ws.code)
            batch_.append(ws.code)  # The last pair is the original
            batch_.append(other_ws.code)
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        score = logits[-1] - logits[:-1]  # Importance score of each token position
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        ret = {}
        for i in idxs:
            if score[i] > 0:
                if other_ws is not None:
                    ret[masked_idxs[i]] = batch_[2 * i]
                else:
                    ret[masked_idxs[i]] = batch_[i]
        return ret

    def generate_candidates(self, ws, mlm, vulnerable, n_candidate=10, batch_size=32,
                            criterion=None, len_threshold=512, other_ws=None):
        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device
        # 1. Generate all candidates (whitespace inserted)
        candidates, idxs, wss = [], [], []
        for v in vulnerable.keys():
            for i in self.ws:
                candidates.append(copy.deepcopy(ws.code))
                candidates[-1][v] = i + candidates[-1][v]
                candidates[-1] = " ".join(candidates[-1])
                idxs.append(v)
                wss.append(i)
        # 2. Select top-k according to PPL by MLM
        targets = mlm.tokenize(candidates, cut_and_pad=True, ret_id=True)
        logits = mlm.run(candidates, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))
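        # Note: torch.Tensor.topk defaults to largest=True, so this selection
        # keeps the candidates with the highest average loss (the ones the MLM
        # finds most surprising); selecting the most natural insertions
        # (lowest perplexity) would instead require largest=False.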
        ret = [(idxs[i], wss[i]) for i in selected_idx]
        seq = []
        for i in selected_idx:
            seq.append(candidates[i])
        if other_ws is not None:
            seq.append(" ".join(other_ws.code))
        return ret, seq


if __name__ == "__main__":

    from uid import UIDStruct
    from cmp import CMPStruct
    from codebert import codebert_mlm, codebert_cls
    from utils import normalize, denormalize
    from cparser import CCode
    import argparse
    import os
    import pickle
    import json
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj/checkpoint-9000-1.0555",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")
    parser.add_argument('--testset', type=str, default="../data/test.pkl")
    opt = parser.parse_args()

    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    mlm_model = codebert_mlm(opt.mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    with open(os.path.join(opt.mlm_path, vocab_path), "r") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n" + idx2txt[-1] + "\n" + str(txt2idx[idx2txt[-1]]) + "\n" + str(len(idx2txt) - 1)
    cls_model = codebert_cls(opt.cls_path, device)

    atk = CodeBERT_Attack_CMP()
    # atk = CodeBERT_Attack_UID()

    with open(opt.testset, "rb") as f:
        d = pickle.load(f)

    len_threshold = -1
    times_per_example = 5
    time_st = time.time()
    for i in range(len(d['src'])):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        s_norm = normalize(s)
        y = d['label'][i] - 1
        logits = cls_model.run([" ".join(s_norm)])
        if logits[0].argmax().item() != y:
            continue
        cmp = CMPStruct(s, mask=mlm_model.tokenizer.unk_token)
        # uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token)
        for cnt_per_example in range(times_per_example):
            vulnerables = atk.find_vulnerable_cmps(cls_model, cmp, y)
            # vulnerables = atk.find_vulnerable_uids(cls_model, uid, y)
            if len(vulnerables) <= 0:
                continue
            candidates, _, _ = atk.generate_candidates(cmp, mlm_model, vulnerables, n_candidate=10)
            # candidates, _ = atk.generate_candidates(uid, mlm_model, list(vulnerables.keys())[0],
            #                                         idx2txt, n_candidate=10, smoothing=0.1)
            if candidates is None:
                continue
            cmp.update_cmp(candidates[0])
            # uid.update_sym(list(vulnerables.keys())[0], candidates[0])
            assert CCode(" ".join(denormalize(cmp.code))), \
                "\n" + " ".join(s_norm) + "\n" + " ".join(cmp.code)
        print("\r%.1f min | %6s / %6s |" % ((time.time() - time_st) / 60., str(i), str(len(d['src']))), end="\r")
    print("\n%.1f min" % ((time.time() - time_st) / 60.))
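    # A minimal sketch of driving CodeBERT_Attack_WS the same way. This is
    # hypothetical: it assumes a WSStruct wrapper (analogous to UIDStruct and
    # CMPStruct) that exposes .code and gen_mask(n_mask); no such module is
    # imported above.
    #
    #   ws_atk = CodeBERT_Attack_WS()
    #   ws_struct = WSStruct(s, mask=mlm_model.tokenizer.unk_token)
    #   vul = ws_atk.find_vulnerable_tokens(cls_model, ws_struct, y)
    #   if len(vul) > 0:
    #       cands, seqs = ws_atk.generate_candidates(ws_struct, mlm_model, vul)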