# -*- coding: utf-8 -*-
"""
Created on Tue Nov 17 14:34:30 2020

@author: DrLC

Comparison-mirroring transformation over tokenized C code.

``CMPStruct`` finds every binary equality/relational comparison in a token
sequence (via the parse tree built by ``cparser.CCode``) and can rewrite a
comparison ``a OP b`` into its mirrored form ``b OP' a``, where ``OP'`` is
``>`` for ``<``, ``>=`` for ``<=`` (and vice versa), and ``==`` / ``!=``
stay unchanged.
"""

import copy

from cparser import CCode
from utils import normalize


class CMPStruct(object):
    """Locate and mirror binary comparisons in a C token sequence.

    Attributes:
        history: indices passed to successful ``update_cmp`` calls, in order.
        norm: whether ``code`` was passed through ``utils.normalize``.
        code: current token list (possibly normalized, possibly rewritten).
        ccode: ``CCode`` parsed from the stripped, un-normalized tokens.
        mask: token written by ``gen_mask`` over a comparison's span.
        cmptab: one ``[start, op, end]`` list per comparison, where
            ``code[start:end]`` is the whole comparison and ``code[op]`` is
            its operator token.
    """

    # Operator to emit once the two operands have been exchanged.
    _MIRRORED_OP = {
        '==': '==',
        '!=': '!=',
        '<': '>',
        '<=': '>=',
        '>': '<',
        '>=': '<=',
    }
    # Parse-tree labels of the (3-child) nodes treated as comparisons.
    _CMP_NODES = ('equality_expression', 'relational_expression')

    def __init__(self, code, norm=True, mask="<unk>"):
        """Strip, parse, optionally normalize, and index the comparisons.

        Args:
            code: iterable of token strings; each one is ``strip()``-ed.
            norm: if true, replace ``self.code`` with ``normalize(self.code)``
                after parsing.
            mask: replacement token used by ``gen_mask``.
        """
        self.history = []
        self.norm = norm
        self.code = [t.strip() for t in code]
        # The tree is built from the un-normalized tokens, while cmptab
        # indices are later applied to the (possibly normalized) tokens.
        # NOTE(review): assumes normalize() keeps one output token per
        # input token — confirm.
        self._build_cc()
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask
        self._build_cmptab()

    def _build_cc(self):
        """Parse the current token list into ``self.ccode``."""
        self.ccode = CCode(" ".join(self.code))

    def _build_cmptab(self):
        """Populate ``self.cmptab`` with a single walk of the parse tree."""
        self.cmptab = []
        self._idx = 0         # index of the next leaf token
        self._sym_stack = []  # operator indices of comparisons still open
        try:
            self._traverse(self.ccode.getParsingTree())
        finally:
            # Walk-only scratch state; never leave it on the instance.
            del self._idx
            del self._sym_stack

    def _traverse(self, node, cnt=None):
        """Recursively walk *node*, appending comparison spans to ``cmptab``.

        Nodes whose ``getType()`` is ``'*'`` are descended into; every other
        node is counted as one leaf token.

        Args:
            node: parse-tree node.
            cnt: for each comparison node open on the current recursion path,
                the number of leaf tokens seen since it was entered.

        Returns:
            The updated counter list (a new list; *cnt* is not mutated).
        """
        if cnt is None:
            cnt = []
        if node.getType() == '*':
            children = node.getChildren()
            is_cmp = len(children) == 3 and node.getValue() in self._CMP_NODES
            if is_cmp:
                cnt = cnt + [0]
            for child in children:
                cnt = self._traverse(child, cnt)
            if is_cmp:
                # The span started cnt[-1] leaves ago; nested comparisons
                # have already popped their operators, so the stack top is
                # this node's operator.
                self.cmptab.append(
                    [self._idx - cnt[-1], self._sym_stack[-1], self._idx])
                cnt = cnt[:-1]
                self._sym_stack.pop()
        else:
            if node.getValue() in self._MIRRORED_OP:
                self._sym_stack.append(self._idx)
            cnt = [n + 1 for n in cnt]
            self._idx += 1
        return cnt

    def gen_mask(self):
        """Return one masked copy of ``code`` per comparison.

        In the i-th copy every token of ``cmptab[i]``'s span
        ``[start, end)`` is replaced by ``self.mask``.
        """
        masks = []
        for start, _, end in self.cmptab:
            masked = copy.deepcopy(self.code)
            for p in range(start, end):
                masked[p] = self.mask
            masks.append(masked)
        return masks

    def update_cmp(self, index):
        """Mirror comparison ``cmptab[index]``: ``a OP b`` -> ``b OP' a``.

        ``cmptab`` entries nested inside either operand are shifted so they
        keep pointing at the same tokens.

        Args:
            index: position in ``cmptab``; must be a plain ``int``.

        Returns:
            True if the tokens were rewritten. False if *index* is not an
            ``int``, is out of range, or its operator token is not a known
            comparison operator; in those cases nothing is modified,
            including ``history``.
        """
        if type(index) is not int:  # deliberately rejects bool / subclasses
            return False
        if index < 0 or index >= len(self.cmptab):
            return False
        start, op, end = self.cmptab[index]
        mirrored = self._MIRRORED_OP.get(self.code[op])
        if mirrored is None:
            return False

        # Record only swaps that actually happen.
        self.history.append(index)
        self.code = (self.code[:start]
                     + self.code[op + 1:end]   # right operand moves first
                     + [mirrored]
                     + self.code[start:op]     # then the left operand
                     + self.code[end:])

        # The comparison keeps its total length, so enclosing or disjoint
        # entries are unaffected; only entries inside an operand move.
        left_shift = end - op                 # len(right operand) + 1
        right_shift = -(op + 1 - start)       # -(len(left operand) + 1)
        for i, entry in enumerate(self.cmptab):
            if i == index:
                continue
            if entry[0] >= op and entry[2] <= end:
                delta = right_shift
            elif entry[0] >= start and entry[2] <= op:
                delta = left_shift
            else:
                continue
            entry[0] += delta
            entry[1] += delta
            entry[2] += delta
        # New operator position: start + len(right operand).
        self.cmptab[index][1] = start + end - op - 1
        return True


if __name__ == "__main__":
    import pickle
    import tqdm
    from utils import denormalize

    CMP_OPS = ['==', '!=', '<', '<=', '>', '>=']
    len_threshold = -1  # values <= 0 disable the length filter

    def _check_operators(cmp):
        """Assert each cmptab entry's operator index holds an operator."""
        for _s, _t, _e in cmp.cmptab:
            assert cmp.code[_t] in CMP_OPS, \
                "\n"+str(cmp.code)+"\n"+str(cmp.code[_s:_t])+"\n"+str(cmp.code[_t:_e])

    def _swap_all(cmp):
        """Mirror every comparison once, asserting each swap succeeds."""
        for j in range(len(cmp.cmptab)):
            # Keep the mutating call outside the assert: under `python -O`
            # an assert statement (and any call inside it) is removed.
            swapped = cmp.update_cmp(j)
            assert swapped, \
                "\n"+str(cmp.code)+"\n"+str(cmp.code[cmp.cmptab[j][0]:cmp.cmptab[j][1]]) \
                +"\n"+str(cmp.code[cmp.cmptab[j][1]:cmp.cmptab[j][2]])

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open("../data/norm.pkl", "rb") as f:
        d = pickle.load(f)

    for src in tqdm.tqdm(d['src']):
        if len_threshold > 0 and len(src) >= len_threshold:
            continue
        try:
            s = [t.strip() for t in src]
            s_norm = normalize(s)
            cmp = CMPStruct(s)
            assert cmp.ccode.getTokenSeq() == s, \
                "\n"+str(cmp.ccode.getTokenSeq())+"\n"+str(s)
            assert cmp.code == s_norm, \
                "\n"+str(cmp.code)+"\n"+str(s_norm)
            _check_operators(cmp)
            # Mirroring every comparison twice must be the identity.
            _swap_all(cmp)
            _check_operators(cmp)
            _swap_all(cmp)
            _check_operators(cmp)
            assert cmp.code == s_norm, \
                "\n"+str(cmp.code)+"\n"+str(s_norm)
            # NOTE(review): this parses the round-tripped code; parsing after
            # the first swap round would exercise the transformation itself.
            denorm = denormalize(cmp.code)
            try:
                CCode(" ".join(denorm))
            except Exception as err:
                raise AssertionError(
                    "\n"+str(cmp.code)+"\n"+str(denorm)) from err
        except RecursionError:
            # Parse trees too deep for the recursive walk are skipped.
            continue