# CodeBERT-Attack / preprocess / oj104.py
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 14 16:22:24 2020

@author: DrLC
"""

from cparser import CCode

import tarfile
import os
import tqdm
import random
import pickle
import argparse
import shutil

def unzip(file="../data/oj.tar.gz", dest="./"):
    """Extract the tar archive ``file`` into the directory ``dest``.

    Args:
        file: Path of the archive. ``tarfile.open`` auto-detects the
            compression.
        dest: Directory to extract into; created if it does not exist.

    Returns:
        True on success. False if the archive could not be opened or
        extracted; in that case the exception is printed, not raised.
    """
    try:
        os.makedirs(dest, exist_ok=True)
        with tarfile.open(file) as t:
            if hasattr(tarfile, "data_filter"):
                # Refuse absolute paths, ".." components and special files
                # (PEP 706) so a malformed archive cannot write outside dest.
                t.extractall(dest, filter="data")
            else:
                # Older Python without extraction filters.
                t.extractall(dest)
        return True
    except Exception as e:
        print(e)
        return False
  
def tokenize(token_file, _dir="ProgramData", min_tokens=10):
    """Tokenize every program under ``_dir`` with CCode and pickle the result.

    ``_dir`` must contain one sub-directory per class whose name parses as
    an int (used as the label); every file inside is one program, read as
    latin-1. Programs CCode fails to parse, or whose token sequence has at
    most ``min_tokens`` tokens, are skipped.

    Args:
        token_file: Output pickle path.
        _dir: Root directory of the extracted dataset.
        min_tokens: Programs with this many tokens or fewer are dropped.

    The pickle holds a dict of parallel lists: ``src`` (token sequences),
    ``label`` (int class from the sub-directory name) and ``id``
    (``"<dir>_<file name without its last 4 characters>"``).
    """
    seq, label, idxs = [], [], []
    for d in tqdm.tqdm(os.listdir(_dir)):
        curr_dir = os.path.join(_dir, d)
        for idx in os.listdir(curr_dir):
            with open(os.path.join(curr_dir, idx), "r", encoding='latin1') as f:
                _input = f.read()
            try:
                cc = CCode(_input)
            except Exception:
                # Unparseable program: skip it, but let Ctrl-C through.
                continue
            tokens = cc.getTokenSeq()
            if len(tokens) <= min_tokens:
                continue
            seq.append(tokens)
            label.append(int(d))
            # idx[:-4] strips a 4-character extension (presumably ".txt").
            idxs.append(d + "_" + idx[:-4])

    with open(token_file, "wb") as f:
        pickle.dump({"src": seq, "label": label, "id": idxs}, f)
    
def is_fp(num):
    """Return True if the token ``num`` parses as a finite float literal.

    Python's ``float`` also accepts the words ``inf``, ``infinity`` and
    ``nan`` (any case, optionally signed). In C source those are ordinary
    identifiers, so they are rejected explicitly.
    """
    if num.strip().lower().lstrip("+-") in ("inf", "infinity", "nan"):
        return False
    try:
        float(num)
    except (TypeError, ValueError):
        return False
    return True

def _normalize_token(t):
    """Map one source token to a literal placeholder, or return it unchanged."""
    if "'" in t:
        return "<char>"
    if '"' in t:
        return "<str>"
    # Integer literals: decimal/octal digits or a hex prefix of either case
    # (0x1F and 0X1F alike), optionally followed by C suffixes (10u, 0xFFUL).
    core = t.rstrip("uUlL")
    if core.isdigit() or core[:2].lower() == "0x":
        return "<int>"
    if is_fp(t):
        return "<fp>"
    # Float literals with a C suffix (1.0f, 2.5L) are rejected by float().
    # Retry without the suffix only when the token starts like a number,
    # so identifiers are never affected.
    if (t[:1].isdigit() or t[:1] == ".") and is_fp(t.rstrip("fFlL")):
        return "<fp>"
    return t


def normalize(token_file, tgt_file):
    """Replace literal tokens with placeholders and pickle both versions.

    Reads the pickle written by ``tokenize`` (parallel lists ``src``,
    ``label``, ``id``). Each token is mapped as follows:

    - a token containing ``'`` becomes ``<char>``;
    - a token containing ``"`` becomes ``<str>``;
    - an integer literal becomes ``<int>``;
    - a floating literal becomes ``<fp>``;
    - anything else is kept unchanged.

    Writes ``{"src", "norm", "label", "id"}`` to ``tgt_file``, where ``norm``
    is the normalized counterpart of ``src``.
    """
    with open(token_file, "rb") as f:
        d = pickle.load(f)
    seqs = d["src"]
    labels = d["label"]
    ids = d["id"]
    assert len(seqs) == len(labels) and len(labels) == len(ids)
    norm = [[_normalize_token(t) for t in s] for s in seqs]
    with open(tgt_file, "wb") as f:
        pickle.dump({"src": seqs, "norm": norm, "label": labels, "id": ids}, f)
                    
def _subset(columns, indices):
    """Return a dict holding, for each key in ``columns``, the items at ``indices``."""
    return {key: [values[i] for i in indices] for key, values in columns.items()}


def split(norm_file, train_file, test_file, test_ratio=0.2, seed=None):
    """Randomly split a normalized dataset pickle into test and train pickles.

    Args:
        norm_file: Pickle written by ``normalize``, holding parallel lists
            under ``src``, ``norm``, ``label`` and ``id``.
        train_file: Output path for the remaining samples.
        test_file: Output path for the first ``int(test_ratio * N)``
            samples of a random permutation.
        test_ratio: Fraction of samples placed in the test set.
        seed: If given, a private ``random.Random(seed)`` drives the
            permutation, making the split reproducible. If ``None``, the
            global ``random`` state is used.

    Both outputs have the same key layout as ``norm_file``.
    """
    with open(norm_file, "rb") as f:
        d = pickle.load(f)
    columns = {"src": d["src"], "norm": d["norm"], "label": d["label"], "id": d["id"]}
    n = len(columns["norm"])
    assert all(len(values) == n for values in columns.values())

    rng = random if seed is None else random.Random(seed)
    order = rng.sample(range(n), n)
    n_test = int(test_ratio * n)
    with open(test_file, "wb") as f:
        pickle.dump(_subset(columns, order[:n_test]), f)
    with open(train_file, "wb") as f:
        pickle.dump(_subset(columns, order[n_test:]), f)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # ``choices`` makes argparse reject unknown modes with a usage error;
    # an ``assert`` check would vanish under ``python -O``.
    parser.add_argument('-m', '--mode', type=str, default='pipeline',
                        choices=["pipeline", "clean", "remove", "verify", "move"],
                        help="[pipeline] for preprocessing, [clean] for cleaning, [remove] for removing all, [verify] for verifying, [move] for moving to data directory")

    opt = parser.parse_args()

    unzip_dir = "ProgramData"
    unzip_done = "unzip.done"
    token_file = "token.pkl"
    token_done = "token.done"
    norm_file = "norm.pkl"
    norm_done = "norm.done"
    train_file = "train.pkl"
    test_file = "test.pkl"
    split_done = "split.done"
    data_dir = "../data"

    def _touch(path):
        """Create an empty marker file recording that a stage finished."""
        with open(path, "w"):
            pass

    if opt.mode == "pipeline":
        # Each stage is skipped when its ".done" marker exists, so a rerun
        # resumes at the first unfinished stage.
        if not os.path.isfile(unzip_done):
            # Only mark extraction done if it succeeded. Otherwise a failed
            # extraction would be skipped on every later run.
            if not unzip():
                raise SystemExit("Failed to extract the dataset archive.")
            _touch(unzip_done)

        if not os.path.isfile(token_done):
            tokenize(token_file, unzip_dir)
            _touch(token_done)

        if not os.path.isfile(norm_done):
            normalize(token_file, norm_file)
            _touch(norm_done)

        if not os.path.isfile(split_done):
            split(norm_file, train_file, test_file)
            _touch(split_done)

    elif opt.mode == "verify":
        # Check both splits have consistent column lengths and that every
        # original token sequence, joined by spaces, still parses with CCode.
        for path in (train_file, test_file):
            with open(path, "rb") as f:
                d = pickle.load(f)
            seqs = d["norm"]
            oris = d["src"]
            labels = d["label"]
            idxs = d["id"]
            assert len(oris) == len(seqs) and len(seqs) == len(labels) and len(labels) == len(idxs)
            for s in tqdm.tqdm(oris):
                CCode(" ".join(s))

    elif opt.mode == "move":
        for name in (norm_file, train_file, test_file):
            shutil.copyfile(name, os.path.join(data_dir, name))

    else:
        # "clean" and "remove" are listed in --help but not implemented;
        # both are currently no-ops.
        # TODO(review): define which files each mode should delete.
        pass