# CodeBERT-Attack / oj-attack / compute_ppl.py
# -*- coding: utf-8 -*-
"""
Created on Sun Jan  3 13:50:46 2021

@author: DrLC
"""

import pickle
import tqdm
import argparse
import torch
import os
import random

from token_level import UIDStruct_Java

if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--gpt_path', type=str,
                        default="/var/data/zhanghz/codegpt",
                        help="Path to the codegpt")
    parser.add_argument('--train_path', type=str,
                        default="/var/data/lizhuo/bigJava/datasets/train.pkl",
                        help="Path to the testset")
    parser.add_argument('--tgt_path', type=str,
                        default="log/jdp_ppl.pkl",
                        help="Target ppl file path")
    parser.add_argument('--all_uid_path', type=str,
                        default="../data/jdp_all_uids.pkl",
                        help="Path to the UID pickle file")
    parser.add_argument('--mode', type=str, default="NONE",
                        help="Mode selection")
    
    opt = parser.parse_args()
    
    _mode = opt.mode.upper()
    assert _mode in ["NONE", "UID"]
    if _mode == "UID":
        max_n_rep = 100
        with open(opt.all_uid_path, "rb") as f:
            all_uids = pickle.load(f)

    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    from codegpt import codegpt
    codegpt_model = codegpt(opt.gpt_path, device, block_size=512)
            
    with open(opt.train_path, "rb") as f:
        data = pickle.load(f)
    
    ppl = []
    for r in tqdm.tqdm(data['raw']):
        if _mode == "NONE":
            _r = r
        elif _mode == "UID":
            _r = r
            uid = UIDStruct_Java(r)
            n_rep = len(uid.sym2pos)
            rep_from = random.sample(list(uid.sym2pos.keys()), n_rep)
            for _from in rep_from:
                _to = random.sample(all_uids, 1)[0]
                while not uid.update_sym(_from, _to):
                    _to = random.sample(all_uids, 1)[0]
                while _from in _r:
                    _r[_r.index(_from)] = _to
        else:
            assert False
        ppl.append(codegpt_model.run(" ".join(_r)).item())
        
    with open(opt.tgt_path, "wb") as f:
        pickle.dump(ppl, f)