# Source: CodeBERT-Attack / oj-attack / token_level.py
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 14 21:45:02 2020

@author: DrLC
"""

from tree_sitter import Language, Parser

from cparser import CCode
from utils import normalize, is_uid, is_special_id, is_java_uid, is_java_special_id

import copy
import random
import string

class UIDStruct(object):
    """Identifier-renaming view of a tokenized C program.

    The program is parsed with ``CCode`` to obtain its symbol table. Every
    symbol that is not a special identifier (per ``is_special_id``) is mapped
    to the list of token positions where it occurs. ``gen_mask`` produces one
    masked copy of the token sequence per symbol. ``update_sym`` renames a
    symbol in place and records the rename in ``history``.
    """

    def __init__(self, code, norm=True, mask="<unk>"):
        """
        :param code: sequence of C tokens; each token is ``strip()``-ed.
        :param norm: if True, ``self.code`` is passed through ``normalize``
            after the symbol table has been built.
        :param mask: token substituted for a symbol by ``gen_mask``.
        """
        self.history = []   # (old_sym, new_sym) pairs accepted by update_sym
        self.norm = norm
        self.code = [t.strip() for t in code]
        self._build_cc()
        self._build_symtab()
        # Positions were computed on the un-normalized tokens.
        # NOTE(review): this assumes ``normalize`` keeps the token count and
        # leaves identifiers unchanged -- confirm against utils.normalize.
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask

    def _build_cc(self):
        """Parse the space-joined token sequence into a ``CCode`` object."""
        self.ccode = CCode(" ".join(self.code))

    def _build_symtab(self):
        """Map each non-special symbol to the token positions where it occurs."""
        # Index every token once instead of rescanning the code per symbol.
        positions = {}
        for i, tok in enumerate(self.code):
            positions.setdefault(tok, []).append(i)
        self.sym2pos = {}
        for sym in self.ccode.getSymbolTable().getSymbols():
            if is_special_id(sym):
                continue
            self.sym2pos[sym] = list(positions.get(sym, ()))

    def gen_mask(self):
        """Return ``{symbol: tokens}`` with every occurrence of that symbol masked.

        Each value is an independent copy of ``self.code``; ``self.code``
        itself is not modified.
        """
        masked = {}
        for sym, sym_positions in self.sym2pos.items():
            # Tokens are strings (presumably also after ``normalize``), so a
            # shallow copy is as independent as a deepcopy and far cheaper.
            tokens = list(self.code)
            for p in sym_positions:
                tokens[p] = self.mask
            masked[sym] = tokens
        return masked

    def update_sym(self, old_sym, new_sym):
        """Rename every occurrence of ``old_sym`` to ``new_sym``.

        Returns False and changes nothing when:
        - ``old_sym`` is not a tracked symbol;
        - ``new_sym`` is already a tracked symbol or already appears as a
          token anywhere in the code;
        - ``new_sym`` is not a valid, non-special identifier.

        Otherwise it updates the code, the symbol table and ``history``, and
        returns True.
        """
        if old_sym not in self.sym2pos or new_sym in self.sym2pos:
            return False
        # Names that occur in the code but are not tracked symbols (e.g.
        # identifiers absent from the symbol table) must not be reused:
        # renaming onto them would merge two distinct names.
        if new_sym in self.code:
            return False
        if (not is_uid(new_sym)) or is_special_id(new_sym):
            return False
        self.history.append((old_sym, new_sym))
        sym_positions = self.sym2pos.pop(old_sym)
        self.sym2pos[new_sym] = sym_positions
        for p in sym_positions:
            self.code[p] = new_sym
        return True
    
class UIDStruct_Java(object):
    """Identifier-renaming view of a tokenized Java program.

    The token sequence is parsed with tree-sitter. For every node whose type
    contains "declaration" or "declarator", the first ``identifier`` in its
    subtree (pre-order) is taken as a declared symbol. That symbol is mapped
    to all token positions where it occurs. ``gen_mask`` and ``update_sym``
    behave like their ``UIDStruct`` counterparts.
    """

    def __init__(self, code, norm=True, mask="<unk>",
                 so_path='../data/java-language.so', fake_class=True):
        """
        :param code: sequence of Java tokens; each token is ``strip()``-ed.
        :param norm: if True, ``self.code`` is passed through ``normalize``
            after the symbol table has been built.
        :param mask: token substituted for a symbol by ``gen_mask``.
        :param so_path: path to the compiled tree-sitter Java grammar.
        :param fake_class: if True, the tokens are wrapped in
            ``public class ThisIsAFakeClass_<random> { ... }`` before parsing.
            This is presumably because inputs are class members rather than a
            full compilation unit -- confirm against the data. The random
            suffix is drawn from the global ``random`` module, so this
            advances the global RNG state.
        """
        self.history = []   # (old_sym, new_sym) pairs accepted by update_sym
        self.norm = norm
        self.code = [t.strip() for t in code]
        if fake_class:
            fake_code = ["public", "class", "ThisIsAFakeClass_"+''.join(random.sample(string.ascii_letters + string.digits, 10)), "{"] \
                + self.code + ["}"]
        else:
            fake_code = self.code
        self.code_bytes = bytes(" ".join(fake_code), "utf-8")
        JAVA_LANGUAGE = Language(so_path, 'java')
        self.parser = Parser()
        self.parser.set_language(JAVA_LANGUAGE)
        self.sym2pos = {}
        self._build_jc()
        self._build_symtab(self.jcode.root_node)
        # Positions were computed on the un-normalized tokens.
        # NOTE(review): assumes ``normalize`` preserves token count and
        # identifiers -- confirm against utils.normalize.
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask

    def _build_jc(self):
        """Parse ``self.code_bytes`` into a tree-sitter tree (``self.jcode``)."""
        self.jcode = self.parser.parse(self.code_bytes)

    def _build_symtab(self, node):
        """Add the declared identifiers found under ``node`` to ``self.sym2pos``.

        The subtree is walked in pre-order with an explicit stack, so deeply
        nested trees cannot exceed Python's recursion limit. Identifiers
        already in ``sym2pos``, the fake class name, and special ids (per
        ``is_java_special_id``) are skipped.
        """
        # Index every token once instead of rescanning the code per symbol.
        positions = {}
        for i, tok in enumerate(self.code):
            positions.setdefault(tok, []).append(i)
        stack = [node]
        while stack:
            cur = stack.pop()
            if "declaration" in cur.type or "declarator" in cur.type:
                _uid = self._find_identifier(cur)
                if _uid is not None and "ThisIsAFakeClass_" not in _uid and (not is_java_special_id(_uid)):
                    if _uid not in self.sym2pos:
                        self.sym2pos[_uid] = list(positions.get(_uid, ()))
            # Push children reversed so the leftmost child is visited first,
            # matching a recursive pre-order traversal.
            stack.extend(reversed(cur.children))

    def _find_identifier(self, node):
        """Return the text of the first ``identifier`` node under ``node``.

        The search is pre-order and includes ``node`` itself. Returns None
        when the subtree has no identifier. The walk is iterative for the
        same recursion-depth reason as ``_build_symtab``.
        """
        stack = [node]
        while stack:
            cur = stack.pop()
            if cur.type == "identifier":
                return str(self.code_bytes[cur.start_byte: cur.end_byte], 'utf-8').strip()
            stack.extend(reversed(cur.children))
        return None

    def gen_mask(self):
        """Return ``{symbol: tokens}`` with every occurrence of that symbol masked.

        Each value is an independent copy of ``self.code``; ``self.code``
        itself is not modified.
        """
        masked = {}
        for sym, sym_positions in self.sym2pos.items():
            # Tokens are strings (presumably also after ``normalize``), so a
            # shallow copy is as independent as a deepcopy and far cheaper.
            tokens = list(self.code)
            for p in sym_positions:
                tokens[p] = self.mask
            masked[sym] = tokens
        return masked

    def update_sym(self, old_sym, new_sym):
        """Rename every occurrence of ``old_sym`` to ``new_sym``.

        Returns False and changes nothing when:
        - ``old_sym`` is not a tracked symbol;
        - ``new_sym`` is already a tracked symbol or already appears as a
          token anywhere in the code;
        - ``new_sym`` is not a valid, non-special Java identifier.

        Otherwise it updates the code, the symbol table and ``history``, and
        returns True.
        """
        if old_sym not in self.sym2pos or new_sym in self.sym2pos:
            return False
        # Only declared identifiers are tracked. Names used but not declared
        # here (e.g. library classes) still occur in the code; renaming onto
        # them would merge two distinct names.
        if new_sym in self.code:
            return False
        if (not is_java_uid(new_sym)) or is_java_special_id(new_sym):
            return False
        self.history.append((old_sym, new_sym))
        sym_positions = self.sym2pos.pop(old_sym)
        self.sym2pos[new_sym] = sym_positions
        for p in sym_positions:
            self.code[p] = new_sym
        return True
    
class WSStruct(object):
    """Whitespace-insertion view of a token sequence.

    Only the first ``maxlen`` positions can be masked or edited. ``maxlen``
    starts as ``len(code)``, capped by ``max_len`` when given, and grows by
    one (still capped) with every successful insertion.
    """

    def __init__(self, code, norm=True, mask="<unk>", max_len=None):
        """
        :param code: sequence of tokens; each token is ``strip()``-ed.
        :param norm: if True, the tokens are passed through ``normalize``.
        :param mask: token placed at a position by ``gen_mask``.
        :param max_len: optional cap on the number of editable positions.
        """
        self.history = []   # (idx, inserted_whitespace) pairs accepted by update_ws
        self.norm = norm
        self.code = [t.strip() for t in code]
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask
        self.len_threshold = max_len
        if max_len is None:
            self.maxlen = len(self.code)
        else:
            self.maxlen = min(max_len, len(self.code))

    def gen_mask(self, n_mask=None):
        """Return ``{position: tokens}`` with that single position masked.

        :param n_mask: if None, every position in ``[0, maxlen)`` is used.
            Otherwise ``min(maxlen, n_mask)`` distinct positions are drawn
            with the global ``random`` module.
        Each value is an independent copy of ``self.code``.
        """
        if n_mask is None:
            idxs = list(range(self.maxlen))
        else:
            n_mask = min(self.maxlen, n_mask)
            idxs = random.sample(list(range(self.maxlen)), n_mask)
        mask = {}
        for i in idxs:
            # Tokens are strings (presumably also after ``normalize``), so a
            # shallow copy is as independent as a deepcopy and far cheaper.
            tokens = list(self.code)
            tokens[i] = self.mask
            mask[i] = tokens
        return mask

    def update_ws(self, idx, insert):
        """Insert the whitespace token ``insert`` before position ``idx``.

        Returns False and changes nothing unless ``insert`` is one of
        ``' '``, ``'\\t'``, ``'\\n'`` and ``0 <= idx < maxlen``. On success
        it records ``(idx, insert)`` in ``history``, grows ``maxlen`` (capped
        by the original ``max_len``), and returns True.
        """
        if insert not in [' ', '\t', '\n']:
            return False
        if idx < 0 or idx >= self.maxlen:
            return False
        self.history.append((idx, insert))
        # Rebind to a new list (as before) rather than mutating in place.
        self.code = self.code[:idx] + [insert] + self.code[idx:]
        if self.len_threshold is None:
            self.maxlen += 1
        else:
            self.maxlen = min(self.len_threshold, self.maxlen + 1)
        return True
        
    
    
if __name__ == "__main__":
    # Self-test for WSStruct over the normalized dataset. The UIDStruct
    # checks are kept below, disabled, inside a string literal.

    import pickle
    import tqdm
    from utils import denormalize

    len_threshold = -1    # <= 0 disables the source-length filter
    max_len = 50          # max_len passed to WSStruct in the whitespace test

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted
    # dataset files here.
    with open("../data/norm.pkl", "rb") as f:
        d = pickle.load(f)

    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue

        s = [t.strip() for t in d['src'][i]]
        if not s:
            # An empty sample has no maskable position: gen_mask(1) would
            # return {} and indexing its first key would raise IndexError.
            continue
        
        '''
        uid = UIDStruct(s)

        assert uid.ccode.getTokenSeq() == s, \
            "\n"+str(uid.ccode.getTokenSeq())+"\n"+str(s)
        assert uid.code == normalize(s), \
            "\n"+str(uid.code)+"\n"+str(d['norm'][i])
        
        for s in uid.sym2pos.keys():
            for p in uid.sym2pos[s]:
                assert uid.code[p] == s, \
                    "\n"+str(uid.code)+"\n"+str(uid.sym2pos)+"\n"+str(p)+" ["+uid.code[p]+"] ["+s+"]"

        keys = list(uid.sym2pos.keys())
        
        assert len(keys) < 2 or not uid.update_sym("AnImpossiblyLongVariableWhichCanNeverShowUpInCode", "Prefix_"+keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], keys[1])
        assert len(keys) < 2 or not uid.update_sym(keys[0], "123")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "1_wrong_id")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "+_()_$#*UIQ")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "for")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "main")
        assert len(keys) < 2 or not uid.update_sym(keys[0], "NULL")

        update = []
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        keys = list(uid.sym2pos.keys())
        for m in keys:
            new_m = "Prefix_"+m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
        assert update == uid.history
        denorm = denormalize(uid.code)
        try:
            CCode(" ".join(denormalize(uid.code)))
        except:
            assert False, \
                "\n"+str(uid.code)+"\n"+str(denorm)
        '''

        ws = WSStruct(s, mask="<mask>", max_len=max_len)
        for _ in range(16):
            masked = ws.gen_mask(1)
            idx = next(iter(masked))
            assert "<mask>" in masked[idx]
            assert idx < max_len
            # Non-whitespace insertions must always be rejected.
            for bad in ("123", "for", "main", "NULL"):
                assert not ws.update_ws(idx, bad)
            insert = random.sample([' ', '\t', '\n'], 1)[0]
            assert ws.update_ws(idx, insert)
            assert ws.code[idx] == insert
        denorm = denormalize(ws.code)
        try:
            CCode(" ".join(denorm))
        except Exception as err:
            # Raise explicitly: `assert False` is a no-op under `python -O`
            # and would silently swallow the parse failure.
            raise AssertionError(
                "\n"+str(ws.code)+"\n"+str(denorm)) from err