# -*- coding: utf-8 -*- """ Created on Tue Oct 27 15:58:42 2020 @author: DrLC """ from token_level import UIDStruct, WSStruct, UIDStruct_Java from exp_level import CMPStruct from utils import is_uid, is_special_id, normalize, denormalize, is_java_uid, is_java_special_id from torch.nn import CrossEntropyLoss, Softmax from codebert_attack import CodeBERT_Attack_UID, CodeBERT_Attack_CMP, CodeBERT_Attack_WS from mhm import MHM_Baseline # Import codebert & rnn later import torch import argparse import os import pickle import json import time import numpy import random import copy if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--gpu', type=str, default='-1', help="Gpu selection") parser.add_argument('--mlm_path', type=str, default="/var/data/lushuai/bertvsbert/save/only_mlm/poj/checkpoint-9000-1.0555", help="Path to the masked language model") parser.add_argument('--cls_path', type=str, default="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986", help="Path to the downstream CodeBERT classifier") parser.add_argument('--trainset', type=str, default="../data/train.pkl", help="Path to the train set") parser.add_argument('--testset', type=str, default="../data/test.pkl", help="Path to the test set") parser.add_argument('--uid_path', type=str, default="../data/all_uids.pkl", help="Path to the uid file") parser.add_argument('--attack', type=str, default='cba', help="Attack approach") parser.add_argument('--max_perturb_iter', type=int, default=20, help="Maximal iteration of perturbtions") parser.add_argument('--max_vulnerable', type=int, default=5, help="Maximal vulnerable number") parser.add_argument('--max_mask_ws', type=int, default=30, help="Maximal mask number for white space attack") parser.add_argument('--max_candidate', type=int, default=10, help="Maximal candidate number") parser.add_argument('--smooth_factor', type=float, default=0.1, help="Smoothing factor during merging") parser.add_argument('--init_temperature', type=float, default=1, help="Temperature initialization for SA") parser.add_argument('--cooling_factor', type=float, default=0.8, help="Temperature decreasement for SA") parser.add_argument('--model', type=str, default="CB", help="Target model / victim model") parser.add_argument('--so_path', type=str, default='../data/java-language.so', help="Path to the java parser library") parser.add_argument('--word2vec', type=str, default="../data/w2v.model", help="Path to the word2vec matrix") parser.add_argument('--rnn_path', type=str, default='../model/oj_lstm/model.pt', help="Path to the downstream RNN classifier") parser.add_argument('--data_name', type=str, default="OJ", help='Name of the dataset') parser.add_argument('--rand_idx', type=bool, default=False, help="Random indexing without vulnerablility") parser.add_argument('--rand_cand', type=bool, default=True, help="Random candidate without codebert") opt = parser.parse_args() print ("CONFIGS") args = vars(opt) for k in args.keys(): print (" " + k + " = " + str(args[k])) _attack = opt.attack.upper() # CodeBERT-Attack, MHM, Simulated Annealling CBA, Compare-swapping CBA, Compare-swapping SACBA, White-space-inserting CBA, White-space-inserting SACBA assert _attack in ['CBA', 'MHM', 'SACBA', 'CBA-CMP', 'SACBA-CMP', 'CBA-WS', 'SACBA-WS'] _victim_model = opt.model.upper() assert _victim_model in ['CB', 'LSTM'] # CodeBERT, LSTM n_perturb = opt.max_perturb_iter n_vulnerable = opt.max_vulnerable n_mask_ws = opt.max_mask_ws n_candidate = opt.max_candidate smoothing = opt.smooth_factor # Load the models (MLM and downstream CLS) if int(opt.gpu) < 0: device = torch.device("cpu") else: os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu device = torch.device("cuda") # Cannot select GPU after importing transformers... from codebert import codebert_mlm, codebert_cls from rnn import RNNClassifier data_name = opt.data_name.upper() n_class = {"OJ": 104, "JDP": 2}[data_name] class_op = {"OJ": -1, "JDP": 0}[data_name] if data_name == "JDP": _is_uid = is_java_uid _is_special_id = is_java_special_id elif data_name == "OJ": _is_uid = is_uid _is_special_id = is_special_id _lang = {"OJ": "C", "JDP": "JAVA"}[data_name] mlm_model = codebert_mlm(opt.mlm_path, device) if _victim_model == "CB": cls_model = codebert_cls(opt.cls_path, device, n_class=n_class) else: hidden_size = 600 n_layers = 2 max_len = 512 attn = True bidirection = True cls_model = RNNClassifier(num_class=n_class, hidden_dim=hidden_size, n_layers=n_layers, tokenizer_path=opt.cls_path, w2v_path=opt.word2vec, max_len=max_len, drop_prob=0., model=_victim_model, brnn=bidirection, attn=attn, device=device).to(device) cls_model.load_state_dict(torch.load(opt.rnn_path)) cls_model.eval() # Load the vocabulary of MLM vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"] with open(os.path.join(opt.mlm_path, vocab_path), "r") as f: txt2idx = json.load(f) tmp = sorted(txt2idx.items(), key=lambda it: it[1]) idx2txt = [it[0] for it in tmp] assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \ "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1) # Load the test set with open(opt.testset, "rb") as f: d = pickle.load(f) if opt.rand_cand: with open(opt.uid_path, "rb") as f: all_uids = pickle.load(f) print (len(all_uids)) # Renaming attack if _attack == 'CBA': atk = CodeBERT_Attack_UID(lang=_lang) elif _attack == 'MHM': # When using MHM, collect all uids from the training set with open(opt.uid_path, "rb") as f: all_uids = pickle.load(f) print (len(all_uids)) atk = MHM_Baseline(all_uids, lang=_lang) elif _attack == 'SACBA': atk = CodeBERT_Attack_UID(lang=_lang) temperature = lambda n: opt.init_temperature * (opt.cooling_factor ** n) elif _attack == 'CBA-CMP': atk = CodeBERT_Attack_CMP() elif _attack == 'SACBA-CMP': atk = CodeBERT_Attack_CMP() temperature = lambda n: opt.init_temperature * (opt.cooling_factor ** n) elif _attack == 'CBA-WS': atk = CodeBERT_Attack_WS() elif _attack == 'SACBA-WS': atk = CodeBERT_Attack_WS() temperature = lambda n: opt.init_temperature * (opt.cooling_factor ** n) else: assert False n_total_including_originally_wrong = 0 n_total = 0 n_succ = 0 time_total = 0 ce = CrossEntropyLoss(reduction="none") softmax = Softmax(dim=-1) for i in range(len(d['norm'])): print ("Attack %d / %d. Class %d" % \ (i+1, len(d['norm']), d['label'][i])) succ = False n_total_including_originally_wrong += 1 time_st = time.time() s = [t.strip() for t in d['norm'][i]] s_norm = normalize(s) s = denormalize(s) y = d['label'][i] + class_op logits = cls_model.run([" ".join(s_norm)])[0] # Skip those original erroneously predicted examples if logits.argmax().item() != y: print (" WRONG. SKIP!") continue # Start adversarial attack old_prob = softmax(logits)[y].item() if _attack in ['MHM', 'CBA', 'SACBA']: if _lang == 'C': uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token) elif _lang == 'JAVA': uid = UIDStruct_Java(s, mask=mlm_model.tokenizer.unk_token, so_path=opt.so_path) print (" UIDs: ", end="") for i in uid.sym2pos.keys(): print (i, end=" ") print () elif _attack in ['CBA-CMP', 'SACBA-CMP']: cmp = CMPStruct(s, mask=mlm_model.tokenizer.unk_token) print (" CMPs: ", end="") for s, _, t in cmp.cmptab: for w in cmp.code[s: t]: print (w, end=" ") print (",", end=" ") print () elif _attack in ['CBA-WS', 'SACBA-WS']: ws = WSStruct(s, mask=mlm_model.tokenizer.unk_token, max_len=512) # MHM if _attack == 'MHM': res = atk.mcmc(uid=uid, label=y, classifier=cls_model, n_candi=n_candidate, max_iter=n_perturb) if res['succ']: succ = True # CodeBert-Attack (UID Renaming) elif _attack in ['CBA', 'SACBA']: for it in range(n_perturb): # Find vulnerable identifiers if opt.rand_idx: vulnerables = random.sample(list(uid.sym2pos.keys()), 1) vulnerables = {i: None for i in vulnerables} else: vulnerables = atk.find_vulnerable_uids(cls=cls_model, uid=uid, ground_truth_label=y, n_vul=n_vulnerable) candidate_uids, candidate_seqs, candidate_old_uids = [], [], [] for v in vulnerables.keys(): # Generate possible candidates for each vulnerable uids if opt.rand_cand: c = [] s = [] for _c in random.sample(all_uids, n_candidate): if _c not in uid.sym2pos.keys(): c.append(_c) s.append(copy.deepcopy(uid.code)) for jjj in uid.sym2pos[v]: s[-1][jjj] = _c s[-1] = " ".join(s[-1]) else: c, s = atk.generate_candidates(uid=uid, mlm=mlm_model, vulnerable=v, idx2txt=idx2txt, bpe_indicator='Ġ', n_candidate=n_candidate, smoothing=smoothing, batch_size=n_candidate*2, criterion=ce, max_computational=1e6, len_threshold=510) if c is None or s is None: continue candidate_old_uids += [v for _ in c] candidate_uids += c candidate_seqs += s if len(candidate_seqs) <= 0: break # Probe the target model probs = softmax(cls_model.run(candidate_seqs)) preds = probs.argmax(dim=-1) for pi in range(len(preds)): # Find an adversarial example if preds[pi].item() != y: print (" %s => %s, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \ (candidate_old_uids[pi], candidate_uids[pi], y, old_prob*100, y, preds[pi], probs[pi][y].item()*100, probs[pi][preds[pi]].item()*100)) succ = True assert uid.update_sym(candidate_old_uids[pi], candidate_uids[pi]), \ "\n"+str(uid.sym2pos.keys())+"\n"+candidate_old_uids[pi]+"\n"+candidate_uids[pi] assert _is_uid(candidate_uids[pi]) break if succ: break next_i = probs[:, y].argmin().item() # CBA - Test if the ground truth probability decreases # SACBA - Accept / reject to jump to the candidate accept = (probs[next_i, y] < old_prob) if _attack == 'SACBA' and (not accept): acc_prob = torch.exp(-(probs[next_i, y] - old_prob) / temperature(it+1)) accept = (numpy.random.uniform(0,1) < acc_prob) if accept: print (" %s => %s, %d (%.5f%%) => %d (%.5f%%)" % \ (candidate_old_uids[next_i], candidate_uids[next_i], y, old_prob*100, y, probs[next_i][y].item()*100)) old_prob = probs[next_i][y].item() assert uid.update_sym(candidate_old_uids[next_i], candidate_uids[next_i]), \ "\n"+str(uid.sym2pos.keys())+"\n"+candidate_old_uids[next_i]+"\n"+candidate_uids[next_i] assert _is_uid(candidate_uids[next_i]), candidate_uids[next_i] else: break # CodeBert-Attack (white space insertion) elif _attack in ['CBA-WS', 'SACBA-WS']: for it in range(n_perturb): # Find vulnerable identifiers vulnerables = atk.find_vulnerable_tokens(cls=cls_model, ws=ws, ground_truth_label=y, n_vul=n_vulnerable, n_mask=n_mask_ws) if len(vulnerables) <= 0: break candidate_wss, candidate_seqs = atk.generate_candidates(ws=ws, mlm=mlm_model, vulnerable=vulnerables, n_candidate=n_candidate, batch_size=n_candidate*3, criterion=ce, len_threshold=510) if len(candidate_seqs) <= 0: break # Probe the target model probs = softmax(cls_model.run(candidate_seqs)) preds = probs.argmax(dim=-1) for pi in range(len(preds)): # Find an adversarial example if preds[pi].item() != y: print (" idx %d <%s>, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \ (candidate_wss[pi][0], {" ":"SPC", "\t":"TAB", "\n":"NL"}[candidate_wss[pi][1]], y, old_prob*100, y, preds[pi], probs[pi][y].item()*100, probs[pi][preds[pi]].item()*100)) succ = True assert ws.update_ws(candidate_wss[pi][0], candidate_wss[pi][1]) break if succ: break next_i = probs[:, y].argmin().item() # CBA-WS - Test if the ground truth probability decreases # SACBA-WS - Accept / reject to jump to the candidate accept = (probs[next_i, y] < old_prob) if _attack == 'SACBA' and (not accept): acc_prob = torch.exp(-(probs[next_i, y] - old_prob) / temperature(it+1)) accept = (numpy.random.uniform(0,1) < acc_prob) if accept: print (" idx %d <%s>, %d (%.5f%%) => %d (%.5f%%)" % \ (candidate_wss[next_i][0], {" ":"SPC", "\t":"TAB", "\n":"NL"}[candidate_wss[next_i][1]], y, old_prob*100, y, probs[next_i][y].item()*100)) old_prob = probs[next_i][y].item() assert ws.update_ws(candidate_wss[next_i][0], candidate_wss[next_i][1]) else: break # CodeBert-Attack (CMP Swapping) elif _attack in ['CBA-CMP', 'SACBA-CMP']: for it in range(n_perturb): # Find vulnerable identifiers vulnerables = atk.find_vulnerable_cmps(cls=cls_model, cmp=cmp, ground_truth_label=y, n_vul=n_vulnerable) if len(vulnerables) <= 0: break candidate_cmps, candidate_new_cmps, candidate_seqs = atk.generate_candidates(cmp=cmp, mlm=mlm_model, vulnerable=vulnerables, batch_size=n_candidate, n_candidate=n_candidate) candidate_old_cmps = [" ".join(cmp.code[cmp.cmptab[i][0]: cmp.cmptab[i][2]]) \ for i in candidate_cmps] if len(candidate_seqs) <= 0: break # Probe the target model probs = softmax(cls_model.run(candidate_seqs)) preds = probs.argmax(dim=-1) for pi in range(len(preds)): # Find an adversarial example if preds[pi].item() != y: print (" %s => %s, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \ (candidate_old_cmps[pi], candidate_new_cmps[pi], y, old_prob*100, y, preds[pi], probs[pi][y].item()*100, probs[pi][preds[pi]].item()*100)) succ = True assert cmp.update_cmp(candidate_cmps[pi]) break if succ: break next_i = probs[:, y].argmin().item() # CBA-CMP - Test if the ground truth probability decreases # SACBA-CMP - Accept / reject to jump to the candidate accept = (probs[next_i, y] < old_prob) if _attack == 'SACBA-CMP' and (not accept): acc_prob = torch.exp(-(probs[next_i, y] - old_prob) / temperature(it+1)) accept = (numpy.random.uniform(0,1) < acc_prob) if accept: print (" %s => %s, %d (%.5f%%) => %d (%.5f%%)" % \ (candidate_old_cmps[next_i], candidate_new_cmps[next_i], y, old_prob*100, y, probs[next_i][y].item()*100)) old_prob = probs[next_i][y].item() assert cmp.update_cmp(candidate_cmps[next_i]) else: break if succ: n_succ += 1 n_total += 1 time_total += time.time() - time_st print (" SUCC!") else: n_total += 1 print (" FAIL!") if n_total > 0: succ_rate = n_succ/n_total else: succ_rate = 0 acc_rate = (n_total-n_succ)/n_total_including_originally_wrong if n_succ > 0: avg_time = time_total/n_succ else: avg_time = float("NaN") print (" Succ %% = %.5f%%, Acc %% = %.5f%%, Avg time = %.5f sec" % \ (succ_rate*100, acc_rate*100, avg_time))