# -*- coding: utf-8 -*-
"""
Created on Wed Oct 14 21:45:02 2020

@author: DrLC

Token-level program wrappers that produce masked copies of a token sequence
and apply small edits to it:

* ``UIDStruct`` -- C tokens; renamable identifiers come from the symbol table
  of ``cparser.CCode``.
* ``UIDStruct_Java`` -- Java tokens; renamable identifiers are found by
  walking a tree-sitter parse tree.
* ``WSStruct`` -- any tokens; whitespace tokens (space, tab, newline) can be
  inserted.

Running this file as a script checks ``WSStruct`` (and, optionally,
``UIDStruct``) on every program stored in ``../data/norm.pkl``.
"""

import copy
import random
import string
from collections import defaultdict

from tree_sitter import Language, Parser

from cparser import CCode
from utils import normalize, is_uid, is_special_id, is_java_uid, is_java_special_id

# The only tokens WSStruct.update_ws accepts for insertion.
_WS_TOKENS = (' ', '\t', '\n')


def _token_positions(code):
    """Map each distinct token of *code* to the ascending indices where it occurs.

    Built once so that symbol-table construction is a single pass over the
    tokens instead of one full scan per symbol.
    """
    positions = defaultdict(list)
    for idx, tok in enumerate(code):
        positions[tok].append(idx)
    return positions


def _masked_variants(code, sym2pos, mask):
    """Return ``{sym: copy of code with every index in sym2pos[sym] set to mask}``.

    Keys appear in the iteration order of *sym2pos*.
    """
    variants = {}
    for sym, positions in sym2pos.items():
        masked = copy.deepcopy(code)
        for pos in positions:
            masked[pos] = mask
        variants[sym] = masked
    return variants


def _rename_symbol(struct, old_sym, new_sym, is_valid_uid, is_reserved):
    """Shared implementation of ``update_sym`` for the identifier structs.

    Renames ``old_sym`` to ``new_sym`` in ``struct.code`` and ``struct.sym2pos``
    (the new key is appended at the end of ``sym2pos``) and records the pair
    in ``struct.history``.

    Returns:
        False, changing nothing, when ``old_sym`` is not tracked, ``new_sym`` is
        already tracked, ``is_valid_uid(new_sym)`` is false, or
        ``is_reserved(new_sym)`` is true; True after a successful rename.
    """
    if old_sym not in struct.sym2pos or new_sym in struct.sym2pos:
        return False
    if not is_valid_uid(new_sym) or is_reserved(new_sym):
        return False
    struct.history.append((old_sym, new_sym))
    positions = struct.sym2pos.pop(old_sym)
    struct.sym2pos[new_sym] = positions
    for pos in positions:
        struct.code[pos] = new_sym
    return True


class UIDStruct:
    """Renamable view of a tokenised C program.

    Attributes:
        code: token list (passed through ``utils.normalize`` when ``norm``).
        sym2pos: identifier -> list of token indices where it occurs, for every
            symbol in the ``CCode`` symbol table that is not a special id.
        history: ``(old_sym, new_sym)`` pairs of successful renames, in order.
        mask: token written over a symbol's occurrences by ``gen_mask``.
    """

    def __init__(self, code, norm=True, mask="<unk>"):
        """
        Args:
            code: iterable of token strings; each is stripped of whitespace.
            norm: apply ``utils.normalize`` to the tokens after parsing.
            mask: placeholder token used by ``gen_mask``.
        """
        self.history = []
        self.norm = norm
        self.code = [t.strip() for t in code]
        self._build_cc()
        # sym2pos is computed on the un-normalised tokens; this assumes
        # normalize() keeps the token count/order -- TODO confirm in utils.
        self._build_symtab()
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask

    def _build_cc(self):
        """Parse the space-joined tokens with ``cparser.CCode``."""
        self.ccode = CCode(" ".join(self.code))

    def _build_symtab(self):
        """Fill ``sym2pos`` from the parser's symbol table, skipping special ids."""
        self.sym2pos = {}
        positions = _token_positions(self.code)
        for sym in self.ccode.getSymbolTable().getSymbols():
            if is_special_id(sym):
                continue
            # Fresh list per symbol so later renames never share state.
            self.sym2pos[sym] = list(positions.get(sym, ()))

    def gen_mask(self):
        """Return ``{sym: copy of code with every occurrence of sym replaced by mask}``."""
        return _masked_variants(self.code, self.sym2pos, self.mask)

    def update_sym(self, old_sym, new_sym):
        """Rename ``old_sym`` to ``new_sym`` everywhere in ``code``.

        Returns False (no change) if ``old_sym`` is unknown, ``new_sym`` is
        already a tracked symbol, ``new_sym`` fails ``is_uid`` or is a special
        id; otherwise applies the rename, logs it in ``history`` and returns True.
        """
        return _rename_symbol(self, old_sym, new_sym, is_uid, is_special_id)


class UIDStruct_Java:
    """Renamable view of a tokenised Java program, parsed with tree-sitter.

    Attributes:
        code: token list (passed through ``utils.normalize`` when ``norm``).
        code_bytes: UTF-8 bytes that were parsed (includes the fake-class
            wrapper when ``fake_class``).
        sym2pos: declared identifier -> list of indices into ``code``.
        history: ``(old_sym, new_sym)`` pairs of successful renames, in order.
        mask: token written over a symbol's occurrences by ``gen_mask``.
    """

    def __init__(self, code, norm=True, mask="<unk>",
                 so_path='../data/java-language.so', fake_class=True):
        """
        Args:
            code: iterable of token strings; each is stripped of whitespace.
            norm: apply ``utils.normalize`` to the tokens after parsing.
            mask: placeholder token used by ``gen_mask``.
            so_path: compiled tree-sitter grammar library containing 'java'.
            fake_class: wrap the tokens in ``public class ThisIsAFakeClass_XXXX { ... }``
                before parsing (presumably so member-level snippets parse as a
                compilation unit -- confirm).
        """
        self.history = []
        self.norm = norm
        self.code = [t.strip() for t in code]
        if fake_class:
            # Random 10-character suffix drawn from the global RNG.
            fake_name = "ThisIsAFakeClass_" + ''.join(
                random.sample(string.ascii_letters + string.digits, 10))
            fake_code = ["public", "class", fake_name, "{"] + self.code + ["}"]
        else:
            fake_code = self.code
        self.code_bytes = bytes(" ".join(fake_code), "utf-8")
        JAVA_LANGUAGE = Language(so_path, 'java')
        self.parser = Parser()
        self.parser.set_language(JAVA_LANGUAGE)
        self.sym2pos = {}
        self._build_jc()
        # As in UIDStruct, positions are taken before normalisation.
        self._build_symtab(self.jcode.root_node)
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask

    def _build_jc(self):
        """Parse ``code_bytes`` into ``self.jcode``."""
        self.jcode = self.parser.parse(self.code_bytes)

    def _build_symtab(self, node, _positions=None):
        """Recursively record identifiers declared under *node* in ``sym2pos``.

        For every node whose type contains "declaration" or "declarator", the
        first identifier found below it (see ``_find_identifier``) is tracked,
        unless it contains "ThisIsAFakeClass_", is a Java special id, or is
        already tracked. Positions index ``self.code``, i.e. the tokens without
        the fake-class wrapper.

        Args:
            node: tree-sitter node to walk (the root on the outer call).
            _positions: internal token-position index shared by the recursion;
                built from ``self.code`` when omitted.
        """
        if _positions is None:
            _positions = _token_positions(self.code)
        if "declaration" in node.type or "declarator" in node.type:
            uid = self._find_identifier(node)
            if (uid is not None
                    and "ThisIsAFakeClass_" not in uid
                    and not is_java_special_id(uid)
                    and uid not in self.sym2pos):
                self.sym2pos[uid] = list(_positions.get(uid, ()))
        for child in node.children:
            self._build_symtab(child, _positions)

    def _find_identifier(self, node):
        """Return the stripped text of the first ``identifier`` node in a
        pre-order walk starting at *node*, or None if there is none."""
        if node.type == "identifier":
            return str(self.code_bytes[node.start_byte: node.end_byte], 'utf-8').strip()
        for child in node.children:
            found = self._find_identifier(child)
            if found is not None:
                return found
        return None

    def gen_mask(self):
        """Return ``{sym: copy of code with every occurrence of sym replaced by mask}``."""
        return _masked_variants(self.code, self.sym2pos, self.mask)

    def update_sym(self, old_sym, new_sym):
        """Rename ``old_sym`` to ``new_sym`` everywhere in ``code``.

        Returns False (no change) if ``old_sym`` is unknown, ``new_sym`` is
        already a tracked symbol, ``new_sym`` fails ``is_java_uid`` or is a Java
        special id; otherwise applies the rename, logs it in ``history`` and
        returns True.
        """
        return _rename_symbol(self, old_sym, new_sym, is_java_uid, is_java_special_id)


class WSStruct:
    """Token sequence that supports inserting whitespace tokens.

    Attributes:
        code: token list (passed through ``utils.normalize`` when ``norm``).
        len_threshold: the ``max_len`` given at construction (may be None).
        maxlen: number of leading positions eligible for masking/insertion --
            ``len(code)``, capped at ``max_len`` when one was given.
        history: ``(idx, inserted_token)`` pairs of successful inserts.
        mask: token written at the chosen position by ``gen_mask``.
    """

    def __init__(self, code, norm=True, mask="<unk>", max_len=None):
        self.history = []
        self.norm = norm
        self.code = [t.strip() for t in code]
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask
        self.len_threshold = max_len
        if max_len is None:
            self.maxlen = len(self.code)
        else:
            self.maxlen = min(max_len, len(self.code))

    def gen_mask(self, n_mask=None):
        """Return ``{i: copy of code with code[i] replaced by mask}``.

        Args:
            n_mask: if None, every index in ``range(maxlen)`` is used, in order;
                otherwise ``min(maxlen, n_mask)`` distinct indices are sampled
                with the global ``random`` RNG.
        """
        if n_mask is None:
            idxs = list(range(self.maxlen))
        else:
            n_mask = min(self.maxlen, n_mask)
            idxs = random.sample(range(self.maxlen), n_mask)
        variants = {}
        for i in idxs:
            variants[i] = copy.deepcopy(self.code)
            variants[i][i] = self.mask
        return variants

    def update_ws(self, idx, insert):
        """Insert whitespace token *insert* so that it ends up at ``code[idx]``.

        Returns False (no change) unless *insert* is a space, tab or newline and
        ``0 <= idx < maxlen``. On success the insert is logged in ``history``,
        ``maxlen`` grows by one (still capped at ``len_threshold`` if set) and
        True is returned.
        """
        if insert not in _WS_TOKENS:
            return False
        if idx < 0 or idx >= self.maxlen:
            return False
        self.history.append((idx, insert))
        # Rebind rather than mutate, so references to the old list are unaffected.
        self.code = self.code[:idx] + [insert] + self.code[idx:]
        if self.len_threshold is None:
            self.maxlen += 1
        else:
            self.maxlen = min(self.len_threshold, self.maxlen + 1)
        return True


def _check_uid_struct(s, norm_ref):
    """Consistency checks for ``UIDStruct`` on token list *s* (disabled by default).

    *norm_ref* is only used in a failure message.
    """
    from utils import denormalize

    uid = UIDStruct(s)
    assert uid.ccode.getTokenSeq() == s, \
        "\n" + str(uid.ccode.getTokenSeq()) + "\n" + str(s)
    assert uid.code == normalize(s), \
        "\n" + str(uid.code) + "\n" + str(norm_ref)
    # Loop variable deliberately not named `s`, which would shadow the input.
    for sym, positions in uid.sym2pos.items():
        for p in positions:
            assert uid.code[p] == sym, \
                "\n" + str(uid.code) + "\n" + str(uid.sym2pos) + "\n" + str(p) \
                + " [" + uid.code[p] + "] [" + sym + "]"
    keys = list(uid.sym2pos.keys())
    if len(keys) >= 2:
        assert not uid.update_sym("AnImpossiblyLongVariableWhichCanNeverShowUpInCode",
                                  "Prefix_" + keys[1])
        assert not uid.update_sym(keys[0], keys[1])
        for bad in ("123", "1_wrong_id", "+_()_$#*UIQ", "for", "main", "NULL"):
            assert not uid.update_sym(keys[0], bad), bad
    update = []
    # Two rounds of prefixing: "x" -> "Prefix_x" -> "Prefix_Prefix_x".
    for _ in range(2):
        for m in list(uid.sym2pos.keys()):
            new_m = "Prefix_" + m
            uid.update_sym(m, new_m)
            update.append((m, new_m))
    assert update == uid.history
    denorm = denormalize(uid.code)
    try:
        CCode(" ".join(denorm))
    except Exception as err:
        raise AssertionError("\n" + str(uid.code) + "\n" + str(denorm)) from err


def _check_ws_struct(s, n_edits=16, max_len=50):
    """Apply *n_edits* random whitespace inserts to *s* and check the result still parses."""
    from utils import denormalize

    ws = WSStruct(s, mask="<mask>", max_len=max_len)
    for _ in range(n_edits):
        masked = ws.gen_mask(1)
        idx = next(iter(masked))
        assert "<mask>" in masked[idx]
        assert idx < max_len
        for bad in ("123", "for", "main", "NULL"):
            assert not ws.update_ws(idx, bad), bad
        insert = random.choice(_WS_TOKENS)
        assert ws.update_ws(idx, insert)
        assert ws.code[idx] == insert
    denorm = denormalize(ws.code)
    try:
        CCode(" ".join(denorm))
    except Exception as err:
        raise AssertionError("\n" + str(ws.code) + "\n" + str(denorm)) from err


def _main(data_path="../data/norm.pkl", len_threshold=-1, check_uid=False):
    """Run the sanity checks over every program in *data_path*.

    Programs with ``len_threshold`` or more tokens are skipped when
    ``len_threshold > 0``.
    """
    import pickle
    import tqdm

    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(data_path, "rb") as f:
        d = pickle.load(f)
    for i in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        if check_uid:
            _check_uid_struct(s, d['norm'][i])
        _check_ws_struct(s)


if __name__ == "__main__":
    _main()