# CodeBERT-Attack / oj-attack / exp_level.py
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 17 14:34:30 2020

@author: DrLC
"""

import copy
import numbers

from cparser import CCode
from utils import normalize

class CMPStruct(object):
    """Index the binary comparisons of a C program and rewrite them by operand swap.

    A comparison ``lhs OP rhs`` (OP in ==, !=, <, <=, >, >=) is rewritten as
    ``rhs OP' lhs`` with the mirrored operator OP' (``<`` <-> ``>``,
    ``<=`` <-> ``>=``; ``==`` and ``!=`` are kept).

    Attributes:
        history: indices passed to ``update_cmp`` calls that applied a swap,
            in call order.
        norm: whether ``code`` was passed through ``utils.normalize``.
        code: current token list (normalized when ``norm`` is true).
        ccode: ``CCode`` parsed from the stripped tokens *before*
            normalization.
        mask: placeholder token written by ``gen_mask``.
        cmptab: one ``[start, op, end]`` list per comparison, in post-order
            (nested comparisons precede the comparison containing them):
            ``code[start:op]`` is the left operand, ``code[op]`` the
            operator and ``code[op + 1:end]`` the right operand (``end``
            exclusive). Positions count parse-tree leaves and are applied
            directly to ``code``, which presumably relies on ``normalize``
            keeping tokens one-to-one -- TODO confirm.
    """

    # Leaf tokens that are comparison operators.
    _CMP_OPS = ('==', '!=', '<', '<=', '>', '>=')
    # Grammar rules whose 3-child instances are binary comparisons.
    _CMP_RULES = ('equality_expression', 'relational_expression')
    # Operator to emit once the operands are swapped, so that
    # ``rhs NEW lhs`` has the same value as ``lhs OLD rhs``.
    _SWAPPED_OP = {
        '==': '==',
        '!=': '!=',
        '<': '>',
        '<=': '>=',
        '>': '<',
        '>=': '<=',
    }

    def __init__(self, code, norm=True, mask="<unk>"):
        """Parse ``code`` and index its comparisons.

        Args:
            code: iterable of token strings; surrounding whitespace is stripped.
            norm: if true, ``self.code`` is replaced by ``normalize(self.code)``
                after parsing.
            mask: placeholder token used by ``gen_mask``.
        """
        self.history = []
        self.norm = norm
        self.code = [t.strip() for t in code]
        # Parse the original, un-normalized tokens.
        self._build_cc()
        if self.norm:
            self.code = normalize(self.code)
        self.mask = mask
        self._build_cmptab()

    def _build_cc(self):
        """Parse the space-joined tokens of ``self.code`` into ``self.ccode``."""
        self.ccode = CCode(" ".join(self.code))

    def _build_cmptab(self):
        """Rebuild ``self.cmptab`` by walking the parse tree of ``self.ccode``."""
        self.cmptab = []
        self._idx = 0          # position of the next leaf token to visit
        self._sym_stack = []   # positions of comparison operators seen so far
        self._traverse(self.ccode.getParsingTree())
        # Scratch state used only by _traverse.
        del self._idx
        del self._sym_stack

    def _traverse(self, node, cnt=None):
        """Depth-first walk that appends ``[start, op, end]`` entries to ``self.cmptab``.

        Nodes whose ``getType()`` is ``'*'`` are internal and recursed into;
        every other node is a leaf token and advances ``self._idx``. An
        internal node with exactly three children whose ``getValue()`` is an
        equality/relational rule is recorded after its children are visited.

        Args:
            node: parse-tree node exposing ``getType``, ``getValue`` and
                ``getChildren``.
            cnt: one leaf counter per comparison node currently open
                (innermost last); ``None`` starts a fresh walk.

        Returns:
            The updated counter list; callers must continue with it.
        """
        if cnt is None:
            cnt = []
        if node.getType() == '*':
            children = node.getChildren()
            is_cmp = len(children) == 3 and node.getValue() in self._CMP_RULES
            if is_cmp:
                cnt.append(0)
            for child in children:
                cnt = self._traverse(child, cnt)
            if is_cmp:
                # cnt[-1] leaves were visited under this node, so its first leaf
                # is that far back; nested comparisons have already popped their
                # operators, leaving this node's operator on top of the stack.
                self.cmptab.append([self._idx - cnt[-1], self._sym_stack[-1], self._idx])
                cnt = cnt[:-1]
                self._sym_stack = self._sym_stack[:-1]
        else:
            if node.getValue() in self._CMP_OPS:
                self._sym_stack.append(self._idx)
            cnt = [i + 1 for i in cnt]
            self._idx += 1
        return cnt

    def gen_mask(self):
        """Return one masked copy of ``self.code`` per ``cmptab`` entry.

        In the i-th copy, every token of comparison i (left operand, operator
        and right operand) is replaced by ``self.mask``. ``self.code`` itself
        is not modified.
        """
        masked_copies = []
        for start, _, end in self.cmptab:
            # Only list slots are reassigned, so a shallow copy is sufficient.
            masked = list(self.code)
            for pos in range(start, end):
                masked[pos] = self.mask
            masked_copies.append(masked)
        return masked_copies

    def update_cmp(self, index):
        """Swap the operands of comparison ``index`` in place.

        ``lhs OP rhs`` becomes ``rhs OP' lhs``, where OP' is the mirrored
        operator. ``self.code`` is rebuilt, and every ``cmptab`` entry nested
        inside either operand is shifted so the table describes the new token
        list. The total length is unchanged, so enclosing or disjoint entries
        keep their positions.

        Args:
            index: position in ``self.cmptab``. Any integral type is accepted
                (e.g. ``numpy.int64``); ``bool`` is rejected.

        Returns:
            True if the swap was applied (``index`` is then appended to
            ``self.history``). False for a non-integral or out-of-range index,
            or when the recorded operator is not a comparison.

        NOTE(review): if the parser builds left-recursive 3-child nodes for
        chains such as ``a < b < c``, swapping the outer comparison yields
        ``c > a < b``, which C parses as ``(c > a) < b`` -- a different
        expression. Confirm how cparser shapes chains before relying on
        semantic equivalence.
        """
        if isinstance(index, bool) or not isinstance(index, numbers.Integral):
            return False
        index = int(index)
        if index < 0 or index >= len(self.cmptab):
            return False
        start, op_pos, end = self.cmptab[index]
        new_op = self._SWAPPED_OP.get(self.code[op_pos])
        if new_op is None:
            return False
        # Record only swaps that are actually applied.
        self.history.append(index)
        self.code = (self.code[:start]
                     + self.code[op_pos + 1:end]   # old right operand
                     + [new_op]
                     + self.code[start:op_pos]     # old left operand
                     + self.code[end:])
        # The old left operand moves right by len(rhs) + 1 == end - op_pos;
        # the old right operand moves left by len(lhs) + 1 == op_pos + 1 - start.
        l_add = end - op_pos
        r_sub = op_pos + 1 - start
        for i, entry in enumerate(self.cmptab):
            if i == index:
                continue
            if entry[0] >= op_pos and entry[2] <= end:
                shift = -r_sub      # nested in the old right operand
            elif entry[0] >= start and entry[2] <= op_pos:
                shift = l_add       # nested in the old left operand
            else:
                continue            # encloses, or is disjoint from, the swap
            # Mutate in place: entries are shared lists.
            entry[0] += shift
            entry[1] += shift
            entry[2] += shift
        # The operator now directly follows the moved right operand.
        self.cmptab[index][1] = start + (end - op_pos - 1)
        return True
    
        
if __name__ == "__main__":
    # Self-test over ../data/norm.pkl. For each program, check that:
    #   * CCode's token sequence matches the stripped source tokens,
    #   * CMPStruct's tokens match normalize() of the source,
    #   * every cmptab entry points at a comparison operator, before and after
    #     each swap,
    #   * two full passes of swaps restore the normalized tokens, and
    #   * the denormalized result still parses.
    # Programs whose processing raises RecursionError are skipped.

    import pickle
    import tqdm
    from utils import denormalize

    CMP_OPS = ['==', '!=', '<', '<=', '>', '>=']
    len_threshold = -1  # skip programs with >= this many tokens; <= 0 disables

    # NOTE: pickle.load can execute code embedded in the file; only load
    # trusted data here.
    with open("../data/norm.pkl", "rb") as f:
        d = pickle.load(f)

    for idx in tqdm.tqdm(range(len(d['src']))):
        if len_threshold > 0 and len(d['src'][idx]) >= len_threshold:
            continue

        try:
            s = [t.strip() for t in d['src'][idx]]
            s_norm = normalize(s)
            cmp = CMPStruct(s)

            assert cmp.ccode.getTokenSeq() == s, \
                "\n"+str(cmp.ccode.getTokenSeq())+"\n"+str(s)
            # Report the sequence actually compared against.
            assert cmp.code == s_norm, \
                "\n"+str(cmp.code)+"\n"+str(s_norm)
            for _s, _t, _e in cmp.cmptab:
                assert cmp.code[_t] in CMP_OPS, \
                    "\n"+str(cmp.code)+"\n"+str(cmp.code[_s:_t])+"\n"+str(cmp.code[_t:_e])

            # Two full passes; the final check expects the second pass to
            # undo every swap made by the first.
            for _ in range(2):
                for k in range(len(cmp.cmptab)):
                    assert cmp.update_cmp(k), \
                        "\n"+str(cmp.code)+"\n"+str(cmp.code[cmp.cmptab[k][0]:cmp.cmptab[k][1]]) \
                        +"\n"+str(cmp.code[cmp.cmptab[k][1]:cmp.cmptab[k][2]])
                    for _s, _t, _e in cmp.cmptab:
                        assert cmp.code[_t] in CMP_OPS, \
                            "\n"+str(cmp.code)+"\n"+str(cmp.code[_s:_t])+"\n"+str(cmp.code[_t:_e])

            # Whole-list equality; zip() would silently ignore a length change.
            assert cmp.code == s_norm, \
                "\n"+str(cmp.code)+"\n"+str(s_norm)

            denorm = denormalize(cmp.code)
            # Call the parser outside an assert so the check also runs under
            # `python -O`; catch Exception (not bare except) so Ctrl-C still works.
            try:
                CCode(" ".join(denorm))
            except Exception as err:
                raise AssertionError(
                    "\n"+str(cmp.code)+"\n"+str(denorm)) from err
        except RecursionError:
            continue