# CodeBERT-Attack / oj-attack / main_attack_cls.py
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 27 15:58:42 2020

@author: DrLC
"""

from token_level import UIDStruct, WSStruct, UIDStruct_Java
from exp_level import CMPStruct
from utils import is_uid, is_special_id, normalize, denormalize, is_java_uid, is_java_special_id
from torch.nn import CrossEntropyLoss, Softmax
from codebert_attack import CodeBERT_Attack_UID, CodeBERT_Attack_CMP, CodeBERT_Attack_WS
from mhm import MHM_Baseline

# Import codebert & rnn later

import torch    
import argparse
import os
import pickle
import json
import time
import numpy
import random
import copy

def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` is a trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so such a flag can never be
    switched off from the command line. This converter accepts the usual
    spellings instead. An empty string maps to ``False``, as it did with
    ``type=bool``.

    :param value: raw CLI string, or an already-parsed bool (e.g. a default).
    :returns: the parsed boolean.
    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised
        boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % (value,))


def sa_accept(new_prob, old_prob, temperature=None):
    """Decide whether the attack moves to the best candidate of an iteration.

    Greedy rule (CBA variants, ``temperature is None``): accept only if the
    ground-truth probability decreases. Simulated-annealing rule (SACBA
    variants): additionally accept a non-improving candidate with probability
    ``exp(-(new_prob - old_prob) / temperature)``. A non-positive temperature
    never accepts a non-improving move.

    :param new_prob: ground-truth class probability of the candidate (float).
    :param old_prob: ground-truth class probability of the current code (float).
    :param temperature: current SA temperature, or None for the greedy rule.
    :returns: True if the candidate should be accepted.
    """
    if new_prob < old_prob:
        return True
    if temperature is None or temperature <= 0:
        return False
    acc_prob = float(numpy.exp(-(new_prob - old_prob) / temperature))
    # Random draw only happens on the SA rejection path, as before.
    return bool(numpy.random.uniform(0, 1) < acc_prob)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1', help="Gpu selection")
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/only_mlm/poj/checkpoint-9000-1.0555",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/only_mlm/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the downstream CodeBERT classifier")
    parser.add_argument('--trainset', type=str, default="../data/train.pkl",
                        help="Path to the train set")
    parser.add_argument('--testset', type=str, default="../data/test.pkl",
                        help="Path to the test set")
    parser.add_argument('--uid_path', type=str, default="../data/all_uids.pkl",
                        help="Path to the uid file")
    parser.add_argument('--attack', type=str, default='cba',
                        help="Attack approach")
    parser.add_argument('--max_perturb_iter', type=int, default=20,
                        help="Maximal iteration of perturbtions")
    parser.add_argument('--max_vulnerable', type=int, default=5,
                        help="Maximal vulnerable number")
    parser.add_argument('--max_mask_ws', type=int, default=30,
                        help="Maximal mask number for white space attack")
    parser.add_argument('--max_candidate', type=int, default=10,
                        help="Maximal candidate number")
    parser.add_argument('--smooth_factor', type=float, default=0.1,
                        help="Smoothing factor during merging")
    parser.add_argument('--init_temperature', type=float, default=1,
                        help="Temperature initialization for SA")
    parser.add_argument('--cooling_factor', type=float, default=0.8,
                        help="Temperature decreasement for SA")
    parser.add_argument('--model', type=str, default="CB",
                        help="Target model / victim model")
    parser.add_argument('--so_path', type=str, default='../data/java-language.so',
                        help="Path to the java parser library")
    parser.add_argument('--word2vec', type=str, default="../data/w2v.model",
                        help="Path to the word2vec matrix")
    parser.add_argument('--rnn_path', type=str, default='../model/oj_lstm/model.pt',
                        help="Path to the downstream RNN classifier")

    parser.add_argument('--data_name', type=str, default="OJ",
                        help='Name of the dataset')
    # str2bool (not type=bool) so that e.g. "--rand_cand False" really disables it.
    parser.add_argument('--rand_idx', type=str2bool, default=False,
                        help="Random indexing without vulnerablility")
    parser.add_argument('--rand_cand', type=str2bool, default=True,
                        help="Random candidate without codebert")

    opt = parser.parse_args()

    print("CONFIGS")
    for k, v in vars(opt).items():
        print("  " + k + " = " + str(v))

    _attack = opt.attack.upper()
    # CodeBERT-Attack, MHM, Simulated Annealing CBA, Compare-swapping CBA,
    # Compare-swapping SACBA, White-space-inserting CBA, White-space-inserting SACBA
    if _attack not in ['CBA', 'MHM', 'SACBA', 'CBA-CMP', 'SACBA-CMP', 'CBA-WS', 'SACBA-WS']:
        parser.error("unsupported --attack %r" % opt.attack)
    _victim_model = opt.model.upper()
    if _victim_model not in ['CB', 'LSTM']:   # CodeBERT, LSTM
        parser.error("unsupported --model %r" % opt.model)
    # Every "SA*" attack uses the simulated-annealing acceptance rule; the
    # others accept a move only if the ground-truth probability decreases.
    use_sa = _attack.startswith('SA')

    n_perturb = opt.max_perturb_iter
    n_vulnerable = opt.max_vulnerable
    n_mask_ws = opt.max_mask_ws
    n_candidate = opt.max_candidate
    smoothing = opt.smooth_factor

    def temperature(n):
        """SA temperature after ``n`` cooling steps (used only when use_sa)."""
        return opt.init_temperature * (opt.cooling_factor ** n)

    # Load the models (MLM and downstream CLS)
    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    # Cannot select GPU after importing transformers...
    from codebert import codebert_mlm, codebert_cls
    from rnn import RNNClassifier

    data_name = opt.data_name.upper()
    n_class = {"OJ": 104, "JDP": 2}[data_name]
    class_op = {"OJ": -1, "JDP": 0}[data_name]   # shifts dataset labels to 0-based
    if data_name == "JDP":
        _is_uid = is_java_uid
        _is_special_id = is_java_special_id
    elif data_name == "OJ":
        _is_uid = is_uid
        _is_special_id = is_special_id
    _lang = {"OJ": "C", "JDP": "JAVA"}[data_name]

    mlm_model = codebert_mlm(opt.mlm_path, device)
    if _victim_model == "CB":
        cls_model = codebert_cls(opt.cls_path, device, n_class=n_class)
    else:
        hidden_size = 600
        n_layers = 2
        max_len = 512
        attn = True
        bidirection = True
        cls_model = RNNClassifier(num_class=n_class,
                                  hidden_dim=hidden_size,
                                  n_layers=n_layers,
                                  tokenizer_path=opt.cls_path,
                                  w2v_path=opt.word2vec,
                                  max_len=max_len,
                                  drop_prob=0.,
                                  model=_victim_model,
                                  brnn=bidirection,
                                  attn=attn,
                                  device=device).to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only run.
        cls_model.load_state_dict(torch.load(opt.rnn_path, map_location=device))
        cls_model.eval()

    # Load the vocabulary of MLM (token -> id JSON), and build the id -> token list.
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    with open(os.path.join(opt.mlm_path, vocab_path), "r", encoding="utf-8") as f:
        txt2idx = json.load(f)
    idx2txt = [tok for tok, _ in sorted(txt2idx.items(), key=lambda it: it[1])]
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)

    # Load the test set. NOTE: pickle executes code on load; only use trusted files.
    with open(opt.testset, "rb") as f:
        d = pickle.load(f)

    # UIDs are needed for random candidates and for MHM; load them only once.
    all_uids = None
    uid_pool = []
    if opt.rand_cand or _attack == 'MHM':
        with open(opt.uid_path, "rb") as f:
            all_uids = pickle.load(f)
        print(len(all_uids))
        # random.sample needs a sequence (sets are rejected on Python >= 3.11).
        uid_pool = list(all_uids)

    if _attack == 'MHM':
        atk = MHM_Baseline(all_uids, lang=_lang)
    elif _attack in ('CBA', 'SACBA'):       # Renaming attack
        atk = CodeBERT_Attack_UID(lang=_lang)
    elif _attack in ('CBA-CMP', 'SACBA-CMP'):
        atk = CodeBERT_Attack_CMP()
    else:                                   # 'CBA-WS', 'SACBA-WS'
        atk = CodeBERT_Attack_WS()

    ws_names = {" ": "SPC", "\t": "TAB", "\n": "NL"}

    n_total_including_originally_wrong = 0
    n_total = 0
    n_succ = 0
    time_total = 0

    ce = CrossEntropyLoss(reduction="none")
    softmax = Softmax(dim=-1)

    for i in range(len(d['norm'])):
        print("Attack %d / %d. Class %d" % \
              (i+1, len(d['norm']), d['label'][i]))
        succ = False
        n_total_including_originally_wrong += 1
        time_st = time.time()
        s = [t.strip() for t in d['norm'][i]]
        s_norm = normalize(s)
        s = denormalize(s)
        y = d['label'][i] + class_op
        logits = cls_model.run([" ".join(s_norm)])[0]
        # Skip those original erroneously predicted examples
        if logits.argmax().item() != y:
            print("  WRONG. SKIP!")
            continue
        # Start adversarial attack
        old_prob = softmax(logits)[y].item()
        if _attack in ['MHM', 'CBA', 'SACBA']:
            if _lang == 'C':
                uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token)
            elif _lang == 'JAVA':
                uid = UIDStruct_Java(s, mask=mlm_model.tokenizer.unk_token, so_path=opt.so_path)
            print("  UIDs: ", end="")
            for sym in uid.sym2pos.keys():   # distinct name: don't shadow sample index i
                print(sym, end=" ")
            print()
        elif _attack in ['CBA-CMP', 'SACBA-CMP']:
            cmp = CMPStruct(s, mask=mlm_model.tokenizer.unk_token)
            print("  CMPs: ", end="")
            for st, _, ed in cmp.cmptab:     # distinct names: don't clobber tokens s
                for w in cmp.code[st: ed]:
                    print(w, end=" ")
                print(",", end=" ")
            print()
        elif _attack in ['CBA-WS', 'SACBA-WS']:
            ws = WSStruct(s, mask=mlm_model.tokenizer.unk_token, max_len=512)

        # MHM
        if _attack == 'MHM':

            res = atk.mcmc(uid=uid, label=y,
                           classifier=cls_model,
                           n_candi=n_candidate,
                           max_iter=n_perturb)
            if res['succ']:
                succ = True

        # CodeBert-Attack (UID Renaming)
        elif _attack in ['CBA', 'SACBA']:

            for it in range(n_perturb):
                # Find vulnerable identifiers
                if opt.rand_idx:
                    if not uid.sym2pos:
                        break
                    vulnerables = dict.fromkeys(random.sample(list(uid.sym2pos.keys()), 1))
                else:
                    vulnerables = atk.find_vulnerable_uids(cls=cls_model,
                                                           uid=uid,
                                                           ground_truth_label=y,
                                                           n_vul=n_vulnerable)
                candidate_uids, candidate_seqs, candidate_old_uids = [], [], []
                for v in vulnerables.keys():
                    # Generate possible candidates for each vulnerable uids
                    if opt.rand_cand:
                        new_uids, new_seqs = [], []
                        n_draw = min(n_candidate, len(uid_pool))
                        for cand in random.sample(uid_pool, n_draw):
                            if cand in uid.sym2pos:
                                continue   # renaming to an existing uid would merge symbols
                            seq = copy.deepcopy(uid.code)
                            for pos in uid.sym2pos[v]:
                                seq[pos] = cand
                            new_uids.append(cand)
                            new_seqs.append(" ".join(seq))
                    else:
                        new_uids, new_seqs = atk.generate_candidates(uid=uid,
                                                                     mlm=mlm_model,
                                                                     vulnerable=v,
                                                                     idx2txt=idx2txt,
                                                                     bpe_indicator='Ġ',
                                                                     n_candidate=n_candidate,
                                                                     smoothing=smoothing,
                                                                     batch_size=n_candidate*2,
                                                                     criterion=ce,
                                                                     max_computational=1e6,
                                                                     len_threshold=510)
                    if new_uids is None or new_seqs is None:
                        continue
                    candidate_old_uids += [v for _ in new_uids]
                    candidate_uids += new_uids
                    candidate_seqs += new_seqs
                if len(candidate_seqs) <= 0:
                    break
                # Probe the target model
                probs = softmax(cls_model.run(candidate_seqs))
                preds = probs.argmax(dim=-1)
                for pi, pred in enumerate(preds.tolist()):
                    # Find an adversarial example
                    if pred != y:
                        print("  %s => %s, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \
                              (candidate_old_uids[pi], candidate_uids[pi], y, old_prob*100,
                               y, pred, probs[pi][y].item()*100, probs[pi][pred].item()*100))
                        succ = True
                        # Mutate first, then check: an assert body is skipped under -O.
                        updated = uid.update_sym(candidate_old_uids[pi], candidate_uids[pi])
                        assert updated, \
                            "\n"+str(uid.sym2pos.keys())+"\n"+candidate_old_uids[pi]+"\n"+candidate_uids[pi]
                        assert _is_uid(candidate_uids[pi])
                        break
                if succ:
                    break

                # CBA - Test if the ground truth probability decreases
                # SACBA - Accept / reject to jump to the candidate
                next_i = probs[:, y].argmin().item()
                new_prob = probs[next_i, y].item()
                if sa_accept(new_prob, old_prob, temperature(it+1) if use_sa else None):
                    print("  %s => %s, %d (%.5f%%) => %d (%.5f%%)" % \
                          (candidate_old_uids[next_i], candidate_uids[next_i], y, old_prob*100,
                           y, new_prob*100))
                    old_prob = new_prob
                    updated = uid.update_sym(candidate_old_uids[next_i], candidate_uids[next_i])
                    assert updated, \
                        "\n"+str(uid.sym2pos.keys())+"\n"+candidate_old_uids[next_i]+"\n"+candidate_uids[next_i]
                    assert _is_uid(candidate_uids[next_i]), candidate_uids[next_i]
                else:
                    break

        # CodeBert-Attack (white space insertion)
        elif _attack in ['CBA-WS', 'SACBA-WS']:

            for it in range(n_perturb):
                # Find vulnerable tokens
                vulnerables = atk.find_vulnerable_tokens(cls=cls_model,
                                                         ws=ws,
                                                         ground_truth_label=y,
                                                         n_vul=n_vulnerable,
                                                         n_mask=n_mask_ws)
                if len(vulnerables) <= 0:
                    break
                candidate_wss, candidate_seqs = atk.generate_candidates(ws=ws,
                                                                        mlm=mlm_model,
                                                                        vulnerable=vulnerables,
                                                                        n_candidate=n_candidate,
                                                                        batch_size=n_candidate*3,
                                                                        criterion=ce,
                                                                        len_threshold=510)
                if len(candidate_seqs) <= 0:
                    break
                # Probe the target model
                probs = softmax(cls_model.run(candidate_seqs))
                preds = probs.argmax(dim=-1)
                for pi, pred in enumerate(preds.tolist()):
                    # Find an adversarial example
                    if pred != y:
                        print("  idx %d <%s>, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \
                              (candidate_wss[pi][0],
                               ws_names[candidate_wss[pi][1]],
                               y, old_prob*100, y, pred, probs[pi][y].item()*100,
                               probs[pi][pred].item()*100))
                        succ = True
                        updated = ws.update_ws(candidate_wss[pi][0], candidate_wss[pi][1])
                        assert updated
                        break
                if succ:
                    break

                # CBA-WS - Test if the ground truth probability decreases
                # SACBA-WS - Accept / reject to jump to the candidate
                # (previously compared against 'SACBA', so SA never ran here)
                next_i = probs[:, y].argmin().item()
                new_prob = probs[next_i, y].item()
                if sa_accept(new_prob, old_prob, temperature(it+1) if use_sa else None):
                    print("  idx %d <%s>, %d (%.5f%%) => %d (%.5f%%)" % \
                          (candidate_wss[next_i][0],
                           ws_names[candidate_wss[next_i][1]],
                           y, old_prob*100, y, new_prob*100))
                    old_prob = new_prob
                    updated = ws.update_ws(candidate_wss[next_i][0], candidate_wss[next_i][1])
                    assert updated
                else:
                    break

        # CodeBert-Attack (CMP Swapping)
        elif _attack in ['CBA-CMP', 'SACBA-CMP']:

            for it in range(n_perturb):
                # Find vulnerable comparisons
                vulnerables = atk.find_vulnerable_cmps(cls=cls_model,
                                                       cmp=cmp,
                                                       ground_truth_label=y,
                                                       n_vul=n_vulnerable)
                if len(vulnerables) <= 0:
                    break
                candidate_cmps, candidate_new_cmps, candidate_seqs = atk.generate_candidates(cmp=cmp,
                                                                                             mlm=mlm_model,
                                                                                             vulnerable=vulnerables,
                                                                                             batch_size=n_candidate,
                                                                                             n_candidate=n_candidate)
                candidate_old_cmps = [" ".join(cmp.code[cmp.cmptab[ci][0]: cmp.cmptab[ci][2]])
                                      for ci in candidate_cmps]
                if len(candidate_seqs) <= 0:
                    break
                # Probe the target model
                probs = softmax(cls_model.run(candidate_seqs))
                preds = probs.argmax(dim=-1)
                for pi, pred in enumerate(preds.tolist()):
                    # Find an adversarial example
                    if pred != y:
                        print("  %s => %s, %d (%.5f%%) => %d %d (%.5f%% %.5f%%)" % \
                              (candidate_old_cmps[pi], candidate_new_cmps[pi], y, old_prob*100,
                               y, pred, probs[pi][y].item()*100, probs[pi][pred].item()*100))
                        succ = True
                        updated = cmp.update_cmp(candidate_cmps[pi])
                        assert updated
                        break
                if succ:
                    break

                # CBA-CMP - Test if the ground truth probability decreases
                # SACBA-CMP - Accept / reject to jump to the candidate
                next_i = probs[:, y].argmin().item()
                new_prob = probs[next_i, y].item()
                if sa_accept(new_prob, old_prob, temperature(it+1) if use_sa else None):
                    print("  %s => %s, %d (%.5f%%) => %d (%.5f%%)" % \
                          (candidate_old_cmps[next_i], candidate_new_cmps[next_i], y, old_prob*100,
                           y, new_prob*100))
                    old_prob = new_prob
                    updated = cmp.update_cmp(candidate_cmps[next_i])
                    assert updated
                else:
                    break

        n_total += 1
        if succ:
            n_succ += 1
            time_total += time.time() - time_st
            print("  SUCC!")
        else:
            print("  FAIL!")

        # n_total >= 1 here, since it was just incremented.
        succ_rate = n_succ / n_total
        acc_rate = (n_total - n_succ) / n_total_including_originally_wrong
        avg_time = time_total / n_succ if n_succ > 0 else float("NaN")

        print("  Succ %% = %.5f%%, Acc %% = %.5f%%, Avg time = %.5f sec" % \
              (succ_rate*100, acc_rate*100, avg_time))