# Source: CodeBERT-Attack / oj-attack / utils.py
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 17 21:11:59 2020

@author: DrLC
"""

import re
from tree_sitter import Language, Parser

def is_fp(num):
    """Return True if the token *num* parses as a floating-point number.

    Spellings of infinity or NaN (``inf``, ``-inf``, ``+Infinity``, ``nan``,
    ``NaN`` ...) are rejected even though :func:`float` accepts them. They are
    not numeric literals in C/Java source.

    Raises AttributeError if *num* is not a string (as before).
    """
    # float() tolerates surrounding whitespace, an optional sign and any
    # letter case. Normalise the same way before looking for the special
    # spellings, so "-inf" or " NaN " cannot slip through.
    if num.strip().lstrip("+-").lower() in ("inf", "infinity", "nan"):
        return False
    try:
        float(num)
    except ValueError:
        return False
    return True
    
def is_uid(uid):
    """Tell whether *uid* is a C-style identifier that is not a C keyword.

    The whole string must match ``[a-zA-Z_][a-zA-Z0-9_]*``; the keyword check
    is only consulted once the shape is valid.
    """
    if re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", uid) is None:
        return False
    return not is_keyword(uid)

def is_keyword(uid):
    """Tell whether *uid* is a reserved word of C (through C11).

    ``__func__`` is included as well, so it is never treated as a user name.
    """
    c_reserved = (
        "auto", "break", "case", "char", "const", "continue",
        "default", "do", "double", "else", "enum", "extern",
        "float", "for", "goto", "if", "inline", "int", "long",
        "register", "restrict", "return", "short", "signed",
        "sizeof", "static", "struct", "switch", "typedef",
        "union", "unsigned", "void", "volatile", "while",
        "_Alignas", "_Alignof", "_Atomic", "_Bool", "_Complex",
        "_Generic", "_Imaginary", "_Noreturn", "_Static_assert",
        "_Thread_local", "__func__",
    )
    return uid in c_reserved

# Hand-curated C/C++ standard-library names (headers, macros, types,
# functions, iostream manipulators) plus ``main``; see is_special_id().
_C_SPECIAL_IDS = frozenset([
    # <stdio.h> macros
    "NULL", "_IOFBF", "_IOLBF", "BUFSIZ", "EOF", "FOPEN_MAX", "TMP_MAX",
    "FILENAME_MAX", "L_tmpnam", "SEEK_CUR", "SEEK_END", "SEEK_SET",
    # <stdlib.h> macros (NULL is listed above)
    "EXIT_FAILURE", "EXIT_SUCCESS", "RAND_MAX", "MB_CUR_MAX",
    # main function
    "main",
    # <stdio.h> & <cstdio>
    "stdio", "cstdio", "stdio.h",
    # <stdio.h> types & streams
    "size_t", "FILE", "fpos_t", "stdin", "stdout", "stderr",
    # <stdio.h> functions
    "remove", "rename", "tmpfile", "tmpnam", "fclose", "fflush",
    "fopen", "freopen", "setbuf", "setvbuf", "fprintf", "fscanf",
    "printf", "scanf", "snprintf", "sprintf", "sscanf", "vprintf",
    "vscanf", "vsnprintf", "vsprintf", "vsscanf", "fgetc", "fgets",
    "fputc", "getc", "getchar", "putc", "putchar", "puts", "ungetc",
    "fread", "fwrite", "fgetpos", "fseek", "fsetpos", "ftell",
    "rewind", "clearerr", "feof", "ferror", "perror", "getline",
    # <stdlib.h> & <cstdlib>
    "stdlib", "cstdlib", "stdlib.h",
    # <stdlib.h> types (size_t is listed above)
    "div_t", "ldiv_t", "lldiv_t",
    # <stdlib.h> functions
    "atof", "atoi", "atol", "atoll", "strtod", "strtof", "strtold",
    "strtol", "strtoll", "strtoul", "strtoull", "rand", "srand",
    "aligned_alloc", "calloc", "malloc", "realloc", "free", "abort",
    "atexit", "exit", "at_quick_exit", "_Exit", "getenv",
    "quick_exit", "system", "bsearch", "qsort", "abs", "labs",
    "llabs", "div", "ldiv", "lldiv", "mblen", "mbtowc", "wctomb",
    "mbstowcs", "wcstombs",
    # <string.h> & <cstring>
    "string", "cstring", "string.h",
    # <string.h> functions
    "memcpy", "memmove", "memchr", "memcmp", "memset", "strcat",
    "strncat", "strchr", "strrchr", "strcmp", "strncmp", "strcoll",
    "strcpy", "strncpy", "strerror", "strlen", "strspn", "strcspn",
    "strpbrk", "strstr", "strtok", "strxfrm",
    # <string.h> extension functions
    "memccpy", "mempcpy", "strcat_s", "strcpy_s", "strdup",
    "strerror_r", "strlcat", "strlcpy", "strsignal", "strtok_r",
    # <iostream> family
    "iostream", "istream", "ostream", "fstream", "sstream",
    "iomanip", "iosfwd",
    # <iostream> types
    "ios", "wios", "streamoff", "streampos", "wstreampos",
    "streamsize", "cout", "cerr", "clog", "cin",
    # <iostream> manipulators
    "boolalpha", "noboolalpha", "skipws", "noskipws", "showbase",
    "noshowbase", "showpoint", "noshowpoint", "showpos",
    "noshowpos", "unitbuf", "nounitbuf", "uppercase", "nouppercase",
    "left", "right", "internal", "dec", "oct", "hex", "fixed",
    "scientific", "hexfloat", "defaultfloat", "width", "fill",
    "precision", "endl", "ends", "flush", "ws",
    # <math.h> functions (abs is listed above)
    "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "sinh",
    "cosh", "tanh", "exp", "sqrt", "log", "log10", "pow", "powf",
    "ceil", "floor", "fabs", "cabs", "frexp", "ldexp",
    "modf", "fmod", "hypot", "poly", "matherr",
])


def is_special_id(_id):
    """Return True if *_id* is a known C/C++ standard-library name or ``main``.

    The names are listed in ``_C_SPECIAL_IDS``. A comma was previously
    missing after ``"getline"``, which silently fused it with ``"stdlib"``
    into ``"getlinestdlib"``, so neither name was recognised.
    """
    return _id in _C_SPECIAL_IDS

def is_java_uid(uid):
    """Tell whether *uid* is an ASCII identifier that is not a Java keyword.

    The whole string must match ``[a-zA-Z_][a-zA-Z0-9_]*``; the keyword check
    is only consulted once the shape is valid.
    """
    if re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", uid) is None:
        return False
    return not is_java_keyword(uid)

def is_java_keyword(uid):
    """Tell whether *uid* is a reserved or contextual word in Java.

    Covers the keywords from https://en.wikipedia.org/wiki/List_of_Java_keywords,
    the contextual words ``permits``/``record``/``sealed``/``var``/``yield``
    and the literals ``true``/``false``/``null``.
    """
    java_reserved = (
        "abstract", "assert", "boolean", "break", "byte", "case", "catch",
        "char", "class", "const", "continue", "default", "do", "double",
        "else", "enum", "extends", "final", "finally", "float", "for",
        "goto", "if", "implements", "import", "instanceof", "int",
        "interface", "long", "native", "new", "package", "private",
        "protected", "public", "return", "strictfp", "short", "static",
        "super", "switch", "synchronized", "this", "throw", "throws",
        "transient", "try", "void", "volatile", "while",
        "permits", "record", "sealed", "var", "yield",
        "true", "false", "null",
    )
    return uid in java_reserved

def is_java_special_id(uid, path="../data/java_special_ids.txt"):
    """Return True if *uid* is listed in the Java special-identifier file or is ``main``.

    Args:
        uid: token to look up.
        path: whitespace-separated word list. The default is relative to the
            current working directory. It is only read on the first call;
            the result is cached on the function as
            ``is_java_special_id.special_ids`` and later calls ignore *path*.

    Returns:
        bool.
    """
    special_ids = getattr(is_java_special_id, "special_ids", None)
    if special_ids is None:
        with open(path, "r", encoding="utf-8") as f:
            # A set gives O(1) membership; the list could be long and this
            # function is called once per token.
            special_ids = set(f.read().split())
        special_ids.add("main")
        # Publish the cache only after it is fully built.
        is_java_special_id.special_ids = special_ids
    return uid in special_ids

def tokenize_java(src, so_path='../data/java-language.so'):
    """Split Java source *src* into the text of its tree-sitter leaf nodes.

    Args:
        src: Java source code as ``str``.
        so_path: compiled tree-sitter Java grammar. Only used on the first
            call; the parser is cached as ``tokenize_java.parser`` and later
            calls ignore this argument.

    Returns:
        list of ``str`` tokens, in source order, with comment leaves dropped.
    """
    parser = getattr(tokenize_java, 'parser', None)
    if parser is None:
        parser = Parser()
        parser.set_language(Language(so_path, 'java'))
        # Cache only a fully configured parser, so a failed set_language
        # does not leave a language-less parser behind for later calls.
        tokenize_java.parser = parser
    byte_seq = bytes(src, encoding='utf-8')
    tree = parser.parse(byte_seq)
    # Older tree-sitter-java grammars use 'comment'; newer ones split it into
    # 'line_comment' and 'block_comment'.
    comment_types = ('comment', 'line_comment', 'block_comment')
    tokens = []
    # Iterative pre-order DFS (children pushed in reverse so they pop in
    # source order). A recursive walk hits Python's recursion limit on deeply
    # nested trees, e.g. long chains of binary operators.
    stack = [tree.root_node]
    while stack:
        node = stack.pop()
        children = node.children
        if children:
            stack.extend(reversed(children))
        elif node.type not in comment_types:
            tokens.append(byte_seq[node.start_byte:node.end_byte].decode('utf-8'))
    return tokens
        

def normalize(seq):
    """Replace literal tokens in *seq* with placeholder tokens.

    Rules:
    - Character literals become ``<char>``.
    - String literals become ``<str>``.
    - Decimal digit runs, and hex/binary literals (``0x``/``0X``/``0b``/``0B``
      prefix), become ``<int>``.
    - Anything accepted by :func:`is_fp` becomes ``<fp>``.
    - All other tokens are kept as-is.

    Returns a new list; *seq* is not modified.
    """
    norm = []
    for t in seq:
        # Classify quoted literals by their FIRST quote character. A string
        # that contains an apostrophe ("it's") stays <str>, and a char
        # literal of a double quote ('"') stays <char>. Prefixed literals
        # such as L'a' or u8"s" are handled the same way.
        quote = next((c for c in t if c in "'\""), None)
        if quote == "'":
            norm.append("<char>")
        elif quote == '"':
            norm.append("<str>")
        elif t.isdigit() or t[:2].lower() in ("0x", "0b"):
            norm.append("<int>")
        elif is_fp(t):
            norm.append("<fp>")
        else:
            norm.append(t)
    return norm

def denormalize(seq):
    """Swap placeholder tokens back for concrete dummy literals.

    ``<char>`` -> ``'0'``, ``<str>`` -> ``"0"``, ``<int>`` -> ``0`` and
    ``<fp>`` -> ``0.0``; every other token is passed through unchanged.
    Returns a new list.
    """
    replacements = {
        "<char>": "'0'",
        "<str>": "\"0\"",
        "<int>": "0",
        "<fp>": "0.0",
    }
    return [replacements.get(tok, tok) for tok in seq]

def align_subtokens(tokens, subtokens, bpe_indicator='Ġ', specials=None):
    """Align BPE sub-tokens with the tokens they were produced from.

    *subtokens* is expected to carry one leading and one trailing boundary
    sub-token (e.g. ``<s>`` ... ``</s>``); those two positions map to -1.

    Each inner sub-token starts a new token when it begins with
    *bpe_indicator* or is one of *specials*; otherwise it continues the
    current token. The counter starts at 0 and is bumped before assignment,
    so the first inner sub-token should not itself be marked as a word start.

    Args:
        tokens: original token sequence.
        subtokens: sub-token sequence, including the two boundary entries.
        bpe_indicator: prefix that marks a word-initial sub-token.
        specials: sub-tokens that always start a new token (default: none).

    Returns:
        ``(align, align_rev)``:
        - ``align[i]`` is the token index of sub-token ``i`` (-1 at the ends).
        - ``align_rev[j]`` lists the sub-token indices making up token ``j``.

    Raises:
        AssertionError: if the token count implied by *subtokens* does not
            match ``len(tokens)``. Raised explicitly, so it still fires under
            ``python -O``.
    """
    if specials is None:
        specials = ()
    inner = range(1, len(subtokens) - 1)
    align = [-1]
    idx = 0
    for i in inner:
        sub = subtokens[i]
        # startswith() instead of sub[0]: an empty sub-token must not raise.
        if sub.startswith(bpe_indicator) or sub in specials:
            idx += 1
        align.append(idx)
    if idx != len(tokens) - 1:
        # Build the diagnostic defensively. When too many tokens were
        # inferred, tokens[align[i]] would raise IndexError and hide the
        # real mismatch.
        pairs = [(subtokens[i], tokens[align[i]] if align[i] < len(tokens) else None)
                 for i in inner]
        raise AssertionError("\n" + str(pairs))
    align.append(-1)
    if len(align) != len(subtokens):
        raise AssertionError
    align_rev = []
    for i in inner:
        if align[i] == len(align_rev):
            align_rev.append([])
        align_rev[align[i]].append(i)
    if len(align_rev) != len(tokens):
        raise AssertionError
    return align, align_rev