# CodeBERT-Attack / oj-attack / assert_test.py
# -*- coding: utf-8 -*-
"""
Created on Mon Dec  7 15:43:26 2020

@author: DrLC
"""

import pickle
import tqdm
import random
import json
import torch
import os
import time

from tree_sitter import Language, Parser

from utils import denormalize, normalize
from cparser import CCode
from token_level import UIDStruct, WSStruct, UIDStruct_Java
from codebert_attack import CodeBERT_Attack_UID, CodeBERT_Attack_WS
    
def test_uid_c(len_threshold=-1, data_path="../data/norm.pkl"):
    """Sanity-check ``UIDStruct`` on the pickled C dataset.

    For every sample (optionally only those shorter than ``len_threshold``
    tokens) this verifies that:

    * ``uid.ccode.getTokenSeq()`` equals the stripped source tokens and
      ``uid.code`` equals ``normalize`` of them;
    * every position listed in ``uid.sym2pos`` holds its symbol;
    * invalid renames (unknown source symbol, collision with an existing
      symbol, numeric / malformed identifiers, ``for``, ``main``, ``NULL``)
      are rejected by ``update_sym``;
    * two rounds of ``"Prefix_"`` renames are all recorded in ``uid.history``;
    * the renamed, denormalized code can still be parsed by ``CCode``.

    :param len_threshold: if positive, samples with at least this many tokens
        are skipped.
    :param data_path: pickle holding at least a ``'src'`` entry (a list of
        token lists).
    :raises AssertionError: on the first failed check.
    """
    # NOTE: pickle.load executes arbitrary code -- only use trusted data files.
    with open(data_path, "rb") as f:
        d = pickle.load(f)

    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue

        s = [t.strip() for t in d['src'][i]]

        uid = UIDStruct(s)

        tok_seq = uid.ccode.getTokenSeq()
        assert tok_seq == s, \
            "\n"+str(tok_seq)+"\n"+str(s)
        expected = normalize(s)
        # Report the value actually compared against (the original message
        # printed d['norm'][i], which may not exist in the pickle).
        assert uid.code == expected, \
            "\n"+str(uid.code)+"\n"+str(expected)

        for sym, positions in uid.sym2pos.items():
            for p in positions:
                assert uid.code[p] == sym, \
                    "\n"+str(uid.code)+"\n"+str(uid.sym2pos)+"\n"+str(p)+" ["+uid.code[p]+"] ["+sym+"]"

        keys = list(uid.sym2pos.keys())

        # Each of these renames must be refused; they need two distinct symbols.
        if len(keys) >= 2:
            invalid_renames = [
                ("AnImpossiblyLongVariableWhichCanNeverShowUpInCode", "Prefix_"+keys[1]),
                (keys[0], keys[1]),         # collides with an existing symbol
                (keys[0], "123"),
                (keys[0], "1_wrong_id"),
                (keys[0], "+_()_$#*UIQ"),
                (keys[0], "for"),
                (keys[0], "main"),
                (keys[0], "NULL"),
            ]
            for old, new in invalid_renames:
                assert not uid.update_sym(old, new), \
                    "update_sym(%r, %r) should have been rejected" % (old, new)

        # Two rounds of prefixing; every successful call must be in history.
        update = []
        for _round in range(2):
            for m in keys:
                new_m = "Prefix_"+m
                uid.update_sym(m, new_m)
                update.append((m, new_m))
            keys = list(uid.sym2pos.keys())
        assert update == uid.history

        denorm = denormalize(uid.code)
        try:
            CCode(" ".join(denorm))
        except Exception as e:
            # Explicit raise (not `assert False`) so the check survives -O
            # and keeps the parser error as the cause.
            raise AssertionError("\n"+str(uid.code)+"\n"+str(denorm)) from e
                
def test_ws_c(len_threshold=-1, data_path="../data/norm.pkl"):
    """Sanity-check ``WSStruct`` (whitespace insertion) on the pickled C dataset.

    For each sample, 16 times: ask ``gen_mask(1)`` for a masked position and
    check that the mask token is present and the index is below ``max_len``.
    Then check that non-whitespace insertions are rejected and that a random
    whitespace insertion is accepted and stored at that index. Finally, the
    denormalized code must still be parseable by ``CCode``.

    :param len_threshold: if positive, samples with at least this many tokens
        are skipped.
    :param data_path: pickle holding at least a ``'src'`` entry.
    :raises AssertionError: on the first failed check.
    """
    max_len = 50
    n_trials = 16
    rejected = ("123", "for", "main", "NULL")
    whitespace = [' ', '\t', '\n']

    # NOTE: pickle.load executes arbitrary code -- only use trusted data files.
    with open(data_path, "rb") as f:
        d = pickle.load(f)

    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue

        s = [t.strip() for t in d['src'][i]]

        ws = WSStruct(s, mask="<mask>", max_len=max_len)
        for _trial in range(n_trials):
            masked = ws.gen_mask(1)
            idx = list(masked.keys())[0]
            assert "<mask>" in masked[idx]
            assert idx < max_len
            for bad in rejected:
                assert not ws.update_ws(idx, bad), \
                    "update_ws(%r, %r) should have been rejected" % (idx, bad)
            insert = random.choice(whitespace)
            assert ws.update_ws(idx, insert)
            assert ws.code[idx] == insert

        denorm = denormalize(ws.code)
        try:
            CCode(" ".join(denorm))
        except Exception as e:
            # Explicit raise (not a bare except + `assert False`): does not
            # swallow KeyboardInterrupt, survives -O, and keeps the cause.
            raise AssertionError("\n"+str(ws.code)+"\n"+str(denorm)) from e
                
def test_uid_java(len_threshold=-1, data_path="../data/bcb_norm.jsonl",
                  so_path='../data/java-language.so'):
    """Sanity-check ``UIDStruct_Java`` on the BigCloneBench jsonl dataset.

    For each function (the whitespace-split ``'func'`` field of every line):

    * every position in ``uid.sym2pos`` holds its symbol;
    * invalid renames (unknown source symbol, collision, numeric / malformed
      identifiers, ``for``, ``main``, ``null``) are rejected;
    * ``gen_mask()`` masks exactly the positions of each symbol with
      ``'<unk>'`` and leaves all other tokens unchanged;
    * two rounds of ``"Prefix_"`` renames are recorded in ``uid.history``, and
      afterwards exactly the symbol positions carry ``"Prefix_Prefix_"``;
    * the result, wrapped in a fake class, parses without tree-sitter errors.

    :param len_threshold: if positive, functions with at least this many
        tokens are skipped.
    :param data_path: jsonl file; each line is an object with a ``'func'`` key.
    :param so_path: compiled tree-sitter library containing the Java grammar.
    :raises AssertionError: on the first failed check.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        d = [[t.strip() for t in json.loads(line)['func'].split()] for line in f]

    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    for i in tqdm.tqdm(range(len(d))):
        if len_threshold > 0 and len(d[i]) >= len_threshold:
            continue

        _s = [t.strip() for t in d[i]]

        uid = UIDStruct_Java(_s, so_path=so_path)

        for sym, positions in uid.sym2pos.items():
            for p in positions:
                assert uid.code[p] == sym, \
                    "\n"+str(uid.code)+"\n"+str(uid.sym2pos)+"\n"+str(p)+" ["+uid.code[p]+"] ["+sym+"]"

        keys = list(uid.sym2pos.keys())

        # Each of these renames must be refused; they need two distinct symbols.
        if len(keys) >= 2:
            invalid_renames = [
                ("AnImpossiblyLongVariableWhichCanNeverShowUpInCode", "Prefix_"+keys[1]),
                (keys[0], keys[1]),         # collides with an existing symbol
                (keys[0], "123"),
                (keys[0], "1_wrong_id"),
                (keys[0], "+_()_$#*UIQ"),
                (keys[0], "for"),
                (keys[0], "main"),
                (keys[0], "null"),
            ]
            for old, new in invalid_renames:
                assert not uid.update_sym(old, new), \
                    "update_sym(%r, %r) should have been rejected" % (old, new)

        # Masks: symbol positions become '<unk>', everything else is untouched.
        for _uid, m in uid.gen_mask().items():
            assert len(m) == len(uid.code)
            sym_pos = set(uid.sym2pos[_uid])
            for k, tok in enumerate(m):
                if k in sym_pos:
                    assert tok == '<unk>'
                else:
                    assert tok == uid.code[k]

        # Two rounds of prefixing; every successful call must be in history.
        update = []
        for _round in range(2):
            for m in keys:
                new_m = "Prefix_"+m
                uid.update_sym(m, new_m)
                update.append((m, new_m))
            keys = list(uid.sym2pos.keys())
        assert update == uid.history

        assert len(_s) == len(uid.code)
        # Set for O(1) membership (a list made this check quadratic).
        renamed = {p for positions in uid.sym2pos.values() for p in positions}
        norm_s = normalize(_s)
        for k, (_t, t) in enumerate(zip(norm_s, uid.code)):
            if k in renamed:
                assert "Prefix_Prefix_" + _t == t, \
                    "\n["+_t+"] ["+t+"]\n"+str(norm_s)+"\n"+str(uid.code)
            else:
                assert _t == t, \
                    "\n["+_t+"] ["+t+"]\n"+str(norm_s)+"\n"+str(uid.code)

        # A bare method is not a Java compilation unit; wrap it in a class.
        fake_class = ["public", "class", "ThisIsAFakeClass", "{"] + uid.code + ["}"]
        denorm = denormalize(fake_class)
        # tree-sitter never raises on malformed input: it recovers and marks
        # ERROR/MISSING nodes, so the check must inspect the tree itself.
        tree = parser.parse(bytes(" ".join(denorm), encoding="utf-8"))
        assert not tree.root_node.has_error, \
            "\n"+str(uid.code)+"\n"+str(denorm)
                
def test_ws_java(len_threshold=-1, data_path="../data/bcb_norm.jsonl",
                  so_path='../data/java-language.so'):
    """Sanity-check ``WSStruct`` (whitespace insertion) on the Java jsonl dataset.

    For each function, 16 times: ``gen_mask(1)`` must yield a masked index
    below ``max_len``. Non-whitespace insertions must be rejected, and a
    random whitespace insertion must be accepted and stored at that index.
    Finally the function, wrapped in a fake class, must parse without
    tree-sitter errors.

    :param len_threshold: if positive, functions with at least this many
        tokens are skipped.
    :param data_path: jsonl file; each line is an object with a ``'func'`` key.
    :param so_path: compiled tree-sitter library containing the Java grammar.
    :raises AssertionError: on the first failed check.
    """
    max_len = 50
    n_trials = 16
    rejected = ("123", "for", "main", "NULL")
    whitespace = [' ', '\t', '\n']

    with open(data_path, "r", encoding="utf-8") as f:
        d = [[t.strip() for t in json.loads(line)['func'].split()] for line in f]

    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    for i in tqdm.tqdm(range(len(d))):
        if len_threshold > 0 and len(d[i]) >= len_threshold:
            continue

        _s = [t.strip() for t in d[i]]

        ws = WSStruct(_s, mask="<mask>", max_len=max_len)
        for _trial in range(n_trials):
            masked = ws.gen_mask(1)
            idx = list(masked.keys())[0]
            assert "<mask>" in masked[idx]
            assert idx < max_len
            for bad in rejected:
                assert not ws.update_ws(idx, bad), \
                    "update_ws(%r, %r) should have been rejected" % (idx, bad)
            insert = random.choice(whitespace)
            assert ws.update_ws(idx, insert)
            assert ws.code[idx] == insert

        # Same wrapping as test_uid_java: a bare method is not a valid
        # compilation unit for the Java grammar.
        fake_class = ["public", "class", "ThisIsAFakeClass", "{"] + ws.code + ["}"]
        denorm = denormalize(fake_class)
        # tree-sitter never raises on malformed input; inspect the tree instead.
        tree = parser.parse(bytes(" ".join(denorm), encoding="utf-8"))
        assert not tree.root_node.has_error, \
            "\n"+str(ws.code)+"\n"+str(denorm)
                
def test_cba_uid_c(mlm_path="/var/data/lushuai/bertvsbert/save/only_mlm/poj/checkpoint-9000-1.0555",
                   cls_path="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986",
                   testset="../data/test.pkl", gpu="0"):
    """Smoke-test the identifier-renaming attack (``CodeBERT_Attack_UID``) on C.

    Only programs the classifier already labels correctly are attacked. For
    each one, up to ``times_per_example`` times, the first vulnerable
    identifier returned by ``find_vulnerable_uids`` is renamed to the first
    MLM candidate. The attacked program must still be accepted by ``CCode``.

    :param mlm_path: directory of the masked-LM checkpoint; its vocab file is
        read from here.
    :param cls_path: directory of the classifier checkpoint.
    :param testset: pickle with ``'src'`` (token lists) and ``'label'``
        entries; labels are shifted by -1 below (presumably 1-based -- verify).
    :param gpu: device id string; a negative value selects the CPU.
    :raises AssertionError: if the vocab ids look inconsistent or an attacked
        program is rejected by ``CCode``.
    """
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        # Only effective if CUDA has not been initialised yet in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")

    from codebert import codebert_mlm, codebert_cls

    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    # Explicit utf-8: the vocab may contain non-ASCII tokens and the platform
    # default encoding is not guaranteed to be utf-8.
    with open(os.path.join(mlm_path, vocab_path), "r", encoding="utf-8") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    # idx2txt is used as an id -> token table, so the largest id should be len-1.
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_cls(cls_path, device)
    atk = CodeBERT_Attack_UID()
    # NOTE: pickle.load executes arbitrary code -- only use trusted data files.
    with open(testset, "rb") as f:
        d = pickle.load(f)

    len_threshold = -1
    times_per_example = 5
    time_st = time.time()
    n_total = len(d['src'])

    for i in range(n_total):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        s_norm = normalize(s)
        y = d['label'][i] - 1
        logits = cls_model.run([" ".join(s_norm)])
        if logits[0].argmax().item() != y:
            continue
        uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token)
        for _attempt in range(times_per_example):
            vulnerables = atk.find_vulnerable_uids(cls_model, uid, y)
            # uid is unchanged when we bail out here, so retrying would
            # presumably repeat the same result; stop like the Java tests do.
            if not vulnerables:
                break
            target = list(vulnerables.keys())[0]
            candidates, _seq = atk.generate_candidates(uid, mlm_model, target,
                                                       idx2txt, n_candidate=10, smoothing=0.1)
            if candidates is None:
                break
            uid.update_sym(target, candidates[0])
        fail_msg = "\n"+" ".join(s_norm)+"\n"+" ".join(uid.code)
        try:
            parsed = CCode(" ".join(denormalize(uid.code)))
        except Exception as e:
            # Keep the diagnostic message and the parser error as the cause.
            raise AssertionError(fail_msg) from e
        assert parsed, fail_msg
        print ("\r%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(n_total)), end="\r")

    print ("\n%.1f min" % ((time.time()-time_st)/60.))
    
def test_cba_uid_java(mlm_path="/var/data/zhanghz/codebert-base-mlm",
                      cls_path="/var/data/lushuai/bertvsbert/save/bcb-mlm/checkpoint-best-f1",
                      data_path="../data/bcb_norm.jsonl",
                      testset="../data/bcb_test_downsample.txt",
                      so_path='../data/java-language.so', gpu="3"):
    """Smoke-test the identifier-renaming attack on Java clone pairs.

    Each line of ``testset`` is ``<idx1> <idx2> <label>``, referring to
    functions in ``data_path``. Only pairs the clone classifier labels
    correctly are attacked. The first function of the pair is renamed up to
    ``times_per_example`` times, and progress is printed. The result must
    still be accepted by ``UIDStruct_Java`` (with ``fake_class=True``).

    :param mlm_path: directory of the masked-LM checkpoint; its vocab file is
        read from here.
    :param cls_path: directory of the clone-detection checkpoint.
    :param data_path: jsonl with ``'idx'`` and ``'func'`` keys per line.
    :param testset: whitespace-separated pair file described above.
    :param so_path: compiled tree-sitter library containing the Java grammar.
    :param gpu: device id string; a negative value selects the CPU.
    :raises AssertionError: on inconsistent vocab ids, an unexpected
        candidate/sequence count, or an attacked function that fails to load.
    """
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        # Only effective if CUDA has not been initialised yet in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")

    from codebert import codebert_mlm, codebert_clone

    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    # Explicit utf-8: the vocab may contain non-ASCII tokens.
    with open(os.path.join(mlm_path, vocab_path), "r", encoding="utf-8") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    # idx2txt is used as an id -> token table, so the largest id should be len-1.
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_clone(cls_path, device)
    atk = CodeBERT_Attack_UID(lang="java")

    with open(data_path, "r", encoding="utf-8") as f:
        src = {}
        for line in f:
            r = json.loads(line)
            src[r['idx']] = r['func'].strip().split()
    with open(testset, "r", encoding="utf-8") as f:
        d = [row.strip().split() for row in f]

    len_threshold = -1
    times_per_example = 3
    time_st = time.time()

    for i in range(len(d)):
        if len_threshold > 0 and \
            (len(src[d[i][0]]) >= len_threshold or len(src[d[i][1]]) >= len_threshold):
            continue
        s = [t.strip() for t in src[d[i][0]]]
        s2 = [t.strip() for t in src[d[i][1]]]
        y = int(d[i][2])
        logits = cls_model.run([" ".join(s), " ".join(s2)])
        if logits[0].argmax().item() != y:
            continue
        uid = UIDStruct_Java(s, mask=mlm_model.tokenizer.unk_token, so_path=so_path, fake_class=True)
        uid2 = UIDStruct_Java(s2, mask=mlm_model.tokenizer.unk_token, so_path=so_path, fake_class=True)
        for _attempt in range(times_per_example):
            vulnerables = atk.find_vulnerable_uids(cls_model, uid, y, uid2)
            if not vulnerables:
                break
            target = list(vulnerables.keys())[0]
            print (target, end="=>")
            print (",".join(mlm_model.tokenize(target)[0][1:-1]), end=" ")
            print (uid.sym2pos[target], end=" ", flush=True)
            candidates, seq = atk.generate_candidates(uid, mlm_model, target,
                                                      idx2txt, n_candidate=10, smoothing=0.1,
                                                      len_threshold=400, max_computational=1e6,
                                                      other_uid=uid2)
            # Check for "no candidates" first: the length assertion below
            # would otherwise crash with TypeError on len(None).
            if candidates is None:
                break
            assert len(seq) == len(candidates) * 2, \
                str(len(seq)) + " " + str(len(candidates))
            uid.update_sym(target, candidates[0])
            print (len(candidates), end=" ")
            print (" ".join(candidates), flush=True)
        try:
            UIDStruct_Java(denormalize(uid.code), so_path=so_path, fake_class=True)
        except Exception as e:
            # Explicit raise survives -O and chains the original error.
            raise AssertionError("\n"+" ".join(s)+"\n"+" ".join(uid.code)+"\n") from e
        print ("%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d))))

    print ("\n%.1f min" % ((time.time()-time_st)/60.))
        
def test_cba_ws_java(mlm_path="/var/data/zhanghz/codebert-base-mlm",
                      cls_path="/var/data/lushuai/bertvsbert/save/bcb-mlm/checkpoint-best-f1",
                      data_path="../data/bcb_norm.jsonl",
                      testset="../data/bcb_test_downsample.txt",
                      so_path='../data/java-language.so', gpu="1"):
    """Smoke-test the whitespace-insertion attack (``CodeBERT_Attack_WS``) on Java pairs.

    Each line of ``testset`` is ``<idx1> <idx2> <label>``, referring to
    functions in ``data_path``. Only pairs the clone classifier labels
    correctly are attacked. Up to ``times_per_example`` whitespace
    insertions are applied to the first function. The result, wrapped in a
    fake class, must parse without tree-sitter errors.

    :param mlm_path: directory of the masked-LM checkpoint; its vocab file is
        read from here.
    :param cls_path: directory of the clone-detection checkpoint.
    :param data_path: jsonl with ``'idx'`` and ``'func'`` keys per line.
    :param testset: whitespace-separated pair file described above.
    :param so_path: compiled tree-sitter library containing the Java grammar.
    :param gpu: device id string; a negative value selects the CPU.
    :raises AssertionError: on inconsistent vocab ids, an unexpected
        candidate/sequence count, a rejected update, or a parse error.
    """
    if int(gpu) < 0:
        device = torch.device("cpu")
    else:
        # Only effective if CUDA has not been initialised yet in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        device = torch.device("cuda")

    JAVA_LANGUAGE = Language(so_path, 'java')
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    from codebert import codebert_mlm, codebert_clone

    mlm_model = codebert_mlm(mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    # Explicit utf-8: the vocab may contain non-ASCII tokens. The id table is
    # only sanity-checked here; the WS attack below does not take it.
    with open(os.path.join(mlm_path, vocab_path), "r", encoding="utf-8") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_clone(cls_path, device)
    atk = CodeBERT_Attack_WS()

    with open(data_path, "r", encoding="utf-8") as f:
        src = {}
        for line in f:
            r = json.loads(line)
            src[r['idx']] = r['func'].strip().split()
    with open(testset, "r", encoding="utf-8") as f:
        d = [row.strip().split() for row in f]

    len_threshold = -1
    times_per_example = 3
    time_st = time.time()

    for i in range(len(d)):
        if len_threshold > 0 and \
            (len(src[d[i][0]]) >= len_threshold or len(src[d[i][1]]) >= len_threshold):
            continue
        s = [t.strip() for t in src[d[i][0]]]
        s2 = [t.strip() for t in src[d[i][1]]]
        y = int(d[i][2])
        logits = cls_model.run([" ".join(s), " ".join(s2)])
        if logits[0].argmax().item() != y:
            continue
        ws = WSStruct(s, mask=mlm_model.tokenizer.unk_token)
        ws2 = WSStruct(s2, mask=mlm_model.tokenizer.unk_token)
        for _attempt in range(times_per_example):
            vulnerables = atk.find_vulnerable_tokens(cls_model, ws, y, ws2)
            if len(vulnerables) <= 0:
                break
            candidates, seq = atk.generate_candidates(ws, mlm_model, vulnerables,
                                                      n_candidate=10, len_threshold=400,
                                                      other_ws=ws2)
            # Check for "no candidates" first: the length assertion below
            # would otherwise crash with TypeError on len(None).
            if candidates is None:
                break
            assert len(seq) == len(candidates) * 2, \
                str(len(seq)) + " " + str(len(candidates))
            pos, token = candidates[0][0], candidates[0][1]
            assert ws.update_ws(pos, token)
            assert ws.code[pos] == token
        # Wrap the method in a class (as test_uid_java does) so it is a valid
        # compilation unit; tree-sitter never raises on malformed input, so
        # inspect the tree for ERROR/MISSING nodes instead of catching.
        fake_class = ["public", "class", "ThisIsAFakeClass", "{"] + ws.code + ["}"]
        tree = parser.parse(bytes(" ".join(denormalize(fake_class)), encoding="utf-8"))
        assert not tree.root_node.has_error, \
            "\n"+" ".join(s)+"\n"+" ".join(ws.code)+"\n"
        print ("%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d))))

    print ("\n%.1f min" % ((time.time()-time_st)/60.))
    
                
if __name__ == "__main__":
    # Run a single test chosen on the command line, e.g.
    #   python assert_test.py test_uid_c
    # With no argument, test_cba_ws_java runs (the previous fixed behaviour).
    import sys

    test_name = sys.argv[1] if len(sys.argv) > 1 else "test_cba_ws_java"
    test_fn = globals().get(test_name)
    if not (test_name.startswith("test_") and callable(test_fn)):
        raise SystemExit("unknown test: %s" % test_name)
    test_fn()