# -*- coding: utf-8 -*-
"""
Created on Mon Dec  7 15:43:26 2020

@author: DrLC
"""

import pickle
import tqdm
import random
import json
import torch
import os
import time

from tree_sitter import Language, Parser

from utils import denormalize, normalize
from cparser import CCode
from token_level import UIDStruct, WSStruct, UIDStruct_Java
from codebert_attack import CodeBERT_Attack_UID, CodeBERT_Attack_WS


def test_uid_c(len_threshold=-1, data_path="../data/norm.pkl"):
    # Check that UIDStruct tokenizes/normalizes C code consistently and that
    # identifier renaming rejects invalid, reserved, or colliding names.
    with open(data_path, "rb") as f:
        d = pickle.load(f)
    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        uid = UIDStruct(s)
        assert uid.ccode.getTokenSeq() == s, \
            "\n"+str(uid.ccode.getTokenSeq())+"\n"+str(s)
        assert uid.code == normalize(s), \
            "\n"+str(uid.code)+"\n"+str(d['norm'][i])
        # Every recorded position of a symbol must hold that symbol.
        for sym in uid.sym2pos.keys():
            for p in uid.sym2pos[sym]:
                assert uid.code[p] == sym, \
                    "\n"+str(uid.code)+"\n"+str(uid.sym2pos)+"\n"+str(p)+" ["+uid.code[p]+"] ["+sym+"]"
        keys = list(uid.sym2pos.keys())
        # Renaming must fail for unknown symbols, existing symbols, and
        # illegal or reserved identifiers.
        assert len(keys) < 2 or not uid.update_sym("AnImpossiblyLongVariableWhichCanNeverShowUpInCode", "Prefix_"+keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], "123")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "1_wrong_id")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "+_()_$#*UIQ")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "for")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "main")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "NULL")
        # Rename every symbol twice and check the recorded history.
        update = []
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        keys = list(uid.sym2pos.keys())
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        assert update == uid.history
        # The perturbed code must still parse as C.
        denorm = denormalize(uid.code)
        try:
            CCode(" ".join(denorm))
        except Exception:
            assert False, \
                "\n"+str(uid.code)+"\n"+str(denorm)


def test_ws_c(len_threshold=-1, data_path="../data/norm.pkl"):
    # Check that WSStruct only accepts whitespace insertions and that the
    # result still parses as C.
    with open(data_path, "rb") as f:
        d = pickle.load(f)
    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        ws = WSStruct(s, mask="<mask>", max_len=50)
        for _ in range(16):
            masked = ws.gen_mask(1)
            idx = list(masked.keys())[0]
            assert "<mask>" in masked[idx]
            assert idx < 50
            # Non-whitespace insertions must be rejected.
            assert not ws.update_ws(idx, "123")
            assert not ws.update_ws(idx, "for")
            assert not ws.update_ws(idx, "main")
            assert not ws.update_ws(idx, "NULL")
            insert = random.choice([' ', '\t', '\n'])
            assert ws.update_ws(idx, insert)
            assert ws.code[idx] == insert
        denorm = denormalize(ws.code)
        try:
            CCode(" ".join(denorm))
        except Exception:
            assert False, \
                "\n"+str(ws.code)+"\n"+str(denorm)
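
# The Java tests below load a tree-sitter grammar compiled to a shared
# library (so_path).  A minimal sketch of how such a library can be built
# with the same py-tree-sitter API this file already uses; the clone path
# of the tree-sitter-java grammar repository is an assumption:
#
#     from tree_sitter import Language
#     Language.build_library('../data/java-language.so', ['tree-sitter-java'])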

def test_uid_java(len_threshold=-1, data_path="../data/bcb_norm.jsonl",
                  so_path='../data/java-language.so'):
    # Same identifier-renaming checks as test_uid_c, but for Java code,
    # validated by wrapping each function in a fake class for tree-sitter.
    with open(data_path, "rb") as f:
        d = []
        for l in f.readlines():
            d.append([t.strip() for t in json.loads(l)['func'].split()])
    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)
    for i in tqdm.tqdm(range(len(d))):
        if len_threshold > 0 and len(d[i]) >= len_threshold:
            continue
        _s = [t.strip() for t in d[i]]
        uid = UIDStruct_Java(_s, so_path=so_path)
        for sym in uid.sym2pos.keys():
            for p in uid.sym2pos[sym]:
                assert uid.code[p] == sym, \
                    "\n"+str(uid.code)+"\n"+str(uid.sym2pos)+"\n"+str(p)+" ["+uid.code[p]+"] ["+sym+"]"
        keys = list(uid.sym2pos.keys())
        assert len(keys) < 2 or not uid.update_sym("AnImpossiblyLongVariableWhichCanNeverShowUpInCode", "Prefix_"+keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], "123")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "1_wrong_id")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "+_()_$#*UIQ")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "for")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "main")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "null")
        # gen_mask must replace exactly the occurrences of each symbol with
        # '<unk>' and leave every other token unchanged.
        for _uid, m in uid.gen_mask().items():
            assert len(m) == len(uid.code)
            for j in range(len(m)):
                if j in uid.sym2pos[_uid]:
                    assert m[j] == '<unk>'
                else:
                    assert m[j] == uid.code[j]
        update = []
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        keys = list(uid.sym2pos.keys())
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        assert update == uid.history
        assert len(_s) == len(uid.code)
        # After two rounds of renaming, exactly the identifier positions
        # carry the double prefix.
        pos = []
        for p in uid.sym2pos.values():
            pos += p
        for k, (_t, t) in enumerate(zip(normalize(_s), uid.code)):
            if k in pos:
                assert "Prefix_Prefix_" + _t == t, \
                    "\n["+_t+"] ["+t+"]\n"+str(normalize(_s))+"\n"+str(uid.code)
            else:
                assert _t == t, \
                    "\n["+_t+"] ["+t+"]\n"+str(normalize(_s))+"\n"+str(uid.code)
        fake_class = ["public", "class", "ThisIsAFakeClass", "{"] + uid.code + ["}"]
        denorm = denormalize(fake_class)
        try:
            parser.parse(bytes(" ".join(denorm), encoding="utf-8"))
        except Exception:
            assert False, \
                "\n"+str(uid.code)+"\n"+str(denorm)


def test_ws_java(len_threshold=-1, data_path="../data/bcb_norm.jsonl",
                 so_path='../data/java-language.so'):
    # Whitespace-insertion checks as in test_ws_c, validated with the
    # tree-sitter Java parser instead of the C parser.
    with open(data_path, "rb") as f:
        d = []
        for l in f.readlines():
            d.append([t.strip() for t in json.loads(l)['func'].split()])
    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)
    for i in tqdm.tqdm(range(len(d))):
        if len_threshold > 0 and len(d[i]) >= len_threshold:
            continue
        _s = [t.strip() for t in d[i]]
        ws = WSStruct(_s, mask="<mask>", max_len=50)
        for _ in range(16):
            masked = ws.gen_mask(1)
            idx = list(masked.keys())[0]
            assert "<mask>" in masked[idx]
            assert idx < 50
            assert not ws.update_ws(idx, "123")
            assert not ws.update_ws(idx, "for")
            assert not ws.update_ws(idx, "main")
            assert not ws.update_ws(idx, "NULL")
            insert = random.choice([' ', '\t', '\n'])
            assert ws.update_ws(idx, insert)
            assert ws.code[idx] == insert
        denorm = denormalize(ws.code)
        try:
            parser.parse(bytes(" ".join(denorm), encoding="utf-8"))
        except Exception:
            assert False, \
                "\n"+str(ws.code)+"\n"+str(denorm)
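
# Each test_cba_* function below rebuilds an index-to-token list from the
# tokenizer's vocab file before running the attack.  A minimal standalone
# sketch of that shared step (hypothetical helper; the tests keep the logic
# inlined and nothing calls this function):
def _idx2txt_sketch(mlm_path, vocab_file="vocab.json"):
    # Load the token -> id map, then invert it so that the i-th entry of the
    # returned list is the token whose id is i (ids are assumed to be
    # contiguous from 0, which the tests below also assert).
    with open(os.path.join(mlm_path, vocab_file), "r") as f:
        txt2idx = json.load(f)
    return [tok for tok, _ in sorted(txt2idx.items(), key=lambda it: it[1])]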

def test_cba_uid_c(mlm_path="/var/data/lushuai/bertvsbert/save/only_mlm/poj/checkpoint-9000-1.0555",
                   cls_path="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986",
                   testset="../data/test.pkl", gpu="0"):
    # End-to-end identifier-renaming attack against a CodeBERT classifier on
    # C code; each perturbed program must remain parseable.
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")
    from codebert import codebert_mlm, codebert_cls
    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    with open(os.path.join(mlm_path, vocab_path), "r") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_cls(cls_path, device)
    atk = CodeBERT_Attack_UID()
    with open(testset, "rb") as f:
        d = pickle.load(f)
    len_threshold = -1
    times_per_example = 5
    time_st = time.time()
    for i in range(len(d['src'])):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        s_norm = normalize(s)
        y = d['label'][i] - 1
        logits = cls_model.run([" ".join(s_norm)])
        if logits[0].argmax().item() != y:
            continue    # only attack correctly classified examples
        uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token)
        for cnt_per_example in range(times_per_example):
            vulnerables = atk.find_vulnerable_uids(cls_model, uid, y)
            if len(vulnerables) <= 0:
                continue
            candidates, _ = atk.generate_candidates(uid, mlm_model,
                                                    list(vulnerables.keys())[0],
                                                    idx2txt, n_candidate=10,
                                                    smoothing=0.1)
            if candidates is None:
                continue
            uid.update_sym(list(vulnerables.keys())[0], candidates[0])
            assert CCode(" ".join(denormalize(uid.code))), \
                "\n"+" ".join(s_norm)+"\n"+" ".join(uid.code)
        print ("\r%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d['src']))), end="\r")
    print ("\n%.1f min" % ((time.time()-time_st)/60.))


def test_cba_uid_java(mlm_path="/var/data/zhanghz/codebert-base-mlm",
                      cls_path="/var/data/lushuai/bertvsbert/save/bcb-mlm/checkpoint-best-f1",
                      data_path="../data/bcb_norm.jsonl",
                      testset="../data/bcb_test_downsample.txt",
                      so_path='../data/java-language.so', gpu="3"):
    # Identifier-renaming attack against a CodeBERT clone detector on Java
    # pairs; only the first program of each pair is perturbed.
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")
    from codebert import codebert_mlm, codebert_clone
    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    with open(os.path.join(mlm_path, vocab_path), "r") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_clone(cls_path, device)
    atk = CodeBERT_Attack_UID(lang="java")
    with open(data_path, "r") as f:
        src = {}
        for l in f.readlines():
            r = json.loads(l)
            src[r['idx']] = r['func'].strip().split()
    with open(testset, "r") as f:
        d = f.readlines()
    d = [i.strip().split() for i in d]
    len_threshold = -1
    times_per_example = 3
    time_st = time.time()
    for i in range(len(d)):
        if len_threshold > 0 and \
           (len(src[d[i][0]]) >= len_threshold or len(src[d[i][1]]) >= len_threshold):
            continue
        s = [t.strip() for t in src[d[i][0]]]
        s2 = [t.strip() for t in src[d[i][1]]]
        y = int(d[i][2])
        logits = cls_model.run([" ".join(s), " ".join(s2)])
        if logits[0].argmax().item() != y:
            continue
        uid = UIDStruct_Java(s, mask=mlm_model.tokenizer.unk_token, so_path=so_path, fake_class=True)
        uid2 = UIDStruct_Java(s2, mask=mlm_model.tokenizer.unk_token, so_path=so_path, fake_class=True)
        for cnt_per_example in range(times_per_example):
            vulnerables = atk.find_vulnerable_uids(cls_model, uid, y, uid2)
            if len(vulnerables) <= 0:
                break
            print (list(vulnerables.keys())[0], end="=>")
            print (",".join(mlm_model.tokenize(list(vulnerables.keys())[0])[0][1:-1]), end=" ")
            print (uid.sym2pos[list(vulnerables.keys())[0]], end=" ", flush=True)
            candidates, seq = atk.generate_candidates(uid, mlm_model,
                                                      list(vulnerables.keys())[0],
                                                      idx2txt, n_candidate=10,
                                                      smoothing=0.1, len_threshold=400,
                                                      max_computational=1e6,
                                                      other_uid=uid2)
            # Check for None before calling len(candidates); the original
            # ordering would raise a TypeError instead of skipping.
            if candidates is None:
                break
            assert len(seq) == len(candidates) * 2, \
                str(len(seq)) + " " + str(len(candidates))
            uid.update_sym(list(vulnerables.keys())[0], candidates[0])
            print (len(candidates), end=" ")
            print (" ".join(candidates), flush=True)
            try:
                UIDStruct_Java(denormalize(uid.code), so_path=so_path, fake_class=True)
            except Exception as e:
                print (e)
                assert False, \
                    "\n"+" ".join(s)+"\n"+" ".join(uid.code)+"\n"
        print ("%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d))))
    print ("\n%.1f min" % ((time.time()-time_st)/60.))


def test_cba_ws_java(mlm_path="/var/data/zhanghz/codebert-base-mlm",
                     cls_path="/var/data/lushuai/bertvsbert/save/bcb-mlm/checkpoint-best-f1",
                     data_path="../data/bcb_norm.jsonl",
                     testset="../data/bcb_test_downsample.txt",
                     so_path='../data/java-language.so', gpu="1"):
    # Whitespace-insertion attack against the same clone detector; each
    # perturbed program must still be accepted by the Java parser.
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")
    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)
    from codebert import codebert_mlm, codebert_clone
    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    with open(os.path.join(mlm_path, vocab_path), "r") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_clone(cls_path, device)
    atk = CodeBERT_Attack_WS()
    with open(data_path, "r") as f:
        src = {}
        for l in f.readlines():
            r = json.loads(l)
            src[r['idx']] = r['func'].strip().split()
    with open(testset, "r") as f:
        d = f.readlines()
    d = [i.strip().split() for i in d]
    len_threshold = -1
    times_per_example = 3
    time_st = time.time()
    for i in range(len(d)):
        if len_threshold > 0 and \
           (len(src[d[i][0]]) >= len_threshold or len(src[d[i][1]]) >= len_threshold):
            continue
        s = [t.strip() for t in src[d[i][0]]]
        s2 = [t.strip() for t in src[d[i][1]]]
        y = int(d[i][2])
        logits = cls_model.run([" ".join(s), " ".join(s2)])
        if logits[0].argmax().item() != y:
            continue
        ws = WSStruct(s, mask=mlm_model.tokenizer.unk_token)
        ws2 = WSStruct(s2, mask=mlm_model.tokenizer.unk_token)
        for cnt_per_example in range(times_per_example):
            vulnerables = atk.find_vulnerable_tokens(cls_model, ws, y, ws2)
            if len(vulnerables) <= 0:
                break
            candidates, seq = atk.generate_candidates(ws, mlm_model, vulnerables,
                                                      n_candidate=10, len_threshold=400,
                                                      other_ws=ws2)
            # As above, the None check must precede the len() assert.
            if candidates is None:
                break
            assert len(seq) == len(candidates) * 2, \
                str(len(seq)) + " " + str(len(candidates))
            assert ws.update_ws(candidates[0][0], candidates[0][1])
            assert ws.code[candidates[0][0]] == candidates[0][1]
            try:
                parser.parse(bytes(" ".join(denormalize(ws.code)), encoding="utf-8"))
            except Exception as e:
                print (e)
                assert False, \
                    "\n"+" ".join(s)+"\n"+" ".join(ws.code)+"\n"
        print ("%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d))))
    print ("\n%.1f min" % ((time.time()-time_st)/60.))


if __name__ == "__main__":
    # Only the whitespace attack on Java runs by default; any of the other
    # tests can be invoked the same way (gpu="-1" forces a CPU run).
    test_cba_ws_java()