# CodeBERT-Attack / oj-attack / codebert_attack.py
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 19 11:27:11 2020

@author: DrLC
"""

from utils import align_subtokens, is_uid, is_special_id, is_java_uid, is_java_special_id

import torch
import copy
from torch.nn import CrossEntropyLoss


class CodeBERT_Attack_UID(object):
    """Identifier-renaming attack guided by a CodeBERT classifier and an MLM.

    ``find_vulnerable_uids`` ranks user-defined identifiers by how much masking
    them lowers the classifier's ground-truth logit; ``generate_candidates``
    proposes replacement names for one identifier using a masked language model.
    """

    def __init__(self, lang="c"):
        """Select the identifier predicates for *lang*.

        Args:
            lang: ``"c"`` or ``"java"`` (case-insensitive).

        Raises:
            ValueError: if *lang* is not a supported language.
        """
        lang = lang.upper()
        if lang == "C":
            self.is_uid = is_uid
            self.is_special_id = is_special_id
        elif lang == "JAVA":
            self.is_uid = is_java_uid
            self.is_special_id = is_java_special_id
        else:
            # Not an ``assert``: asserts vanish under ``python -O`` and the
            # instance would then fail much later with an AttributeError.
            raise ValueError("unsupported language %r; expected 'c' or 'java'" % lang)

    def find_vulnerable_uids(self, cls, uid, ground_truth_label, other_uid=None, n_vul=5):
        """Rank identifiers by how much masking them lowers the true-label logit.

        Args:
            cls: classifier wrapper; ``cls.run(list_of_str)`` is indexed as
                ``[:, ground_truth_label]``, so it presumably returns a
                (rows, n_labels) logit tensor -- TODO confirm.
            uid: identifier structure with ``gen_mask()`` (identifier -> masked
                token list) and ``code`` (token list).
            ground_truth_label: column of the logits that is inspected.
            other_uid: optional second snippet. When given, every input is sent
                as two consecutive batch entries (snippet, other snippet);
                presumably ``cls.run`` yields one row per pair -- confirm.
            n_vul: maximum number of identifiers returned.

        Returns:
            dict mapping each selected identifier to its masked token list, for
            the (at most ``n_vul``) identifiers whose masking strictly lowers
            the logit.
        """
        masked_code = uid.gen_mask()
        masked_uids = list(masked_code.keys())
        if other_uid is None:
            batch_ = [masked_code[i] for i in masked_uids] + [uid.code]  # The last one is the original
            stride = 1
        else:
            batch_ = []
            for i in masked_uids:
                batch_.append(masked_code[i])
                batch_.append(other_uid.code)
            batch_.append(uid.code)  # The last pair is the original
            batch_.append(other_uid.code)
            stride = 2  # batch entries per classifier row
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        # Importance score of each UID: drop of the logit caused by masking it.
        score = logits[-1] - logits[:-1]
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        ret = {}
        for i in idxs.tolist():
            if score[i] > 0:
                ret[masked_uids[i]] = batch_[stride * i]
        return ret

    @staticmethod
    def _merge_position_candidates(candidates):
        """Combine per-occurrence candidate scores into one score per name.

        Args:
            candidates: non-empty list of dicts, one per occurrence, mapping a
                candidate name to its score at that occurrence, plus a ``None``
                key holding the score used when a name was not proposed there.

        Returns:
            list of ``(name, score)`` pairs sorted by descending score, where a
            name's score is the sum over occurrences of its own score, or of
            that occurrence's ``None`` score when it was not proposed.
        """
        merged = dict(candidates[0])
        for position in candidates[1:]:
            prev_missing = merged.pop(None)  # summed "missing" score so far
            cur_missing = position[None]
            combined = {}
            for name, score in merged.items():
                combined[name] = score + position.get(name, cur_missing)
            for name, score in position.items():
                if name is not None and name not in combined:
                    combined[name] = prev_missing + score
            combined[None] = prev_missing + cur_missing
            merged = combined
        merged.pop(None)
        return sorted(merged.items(), key=lambda it: it[1], reverse=True)

    def generate_candidates(self, uid, mlm, vulnerable, idx2txt,
                            bpe_indicator='Ġ', n_candidate=10, smoothing=0.1,
                            batch_size=32, criterion=None, len_threshold=512,
                            max_computational=1e8, other_uid=None):
        """Propose MLM-predicted replacement names for identifier *vulnerable*.

        Args:
            uid: identifier structure with ``code`` (token list) and ``sym2pos``
                (identifier -> positions in ``code``).
            mlm: masked-LM wrapper exposing ``tokenizer``, ``device``,
                ``tokenize`` and ``run``.
            vulnerable: identifier to replace (a key of ``uid.sym2pos``).
            idx2txt: list mapping a subtoken id to its text.
            bpe_indicator: prefix marking a word-initial BPE subtoken.
            n_candidate: top-k subtokens kept per position; also the maximum
                number of names returned.
            smoothing: factor (< 1) applied to the lowest score at an occurrence
                to obtain the score of names not proposed there.
            batch_size: number of merged names considered before filtering;
                also the MLM batch size in step 5.
            criterion: per-token loss; defaults to unreduced cross-entropy.
            len_threshold: subtoken positions at or beyond this are ignored.
            max_computational: budget that limits how many occurrences are used.
            other_uid: optional second snippet appended after each candidate in
                the returned sequences.

        Returns:
            ``(names, seqs)`` -- the selected names and, for each, the snippet
            with every occurrence of *vulnerable* renamed (followed by
            ``other_uid``'s snippet when given) -- or ``(None, None)`` when no
            valid candidate exists.

        Raises:
            ValueError: if ``smoothing`` is not below 1.
        """
        if not smoothing < 1:
            raise ValueError("smoothing must be < 1, got %r" % (smoothing,))
        specials = mlm.tokenizer.all_special_tokens
        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device

        # 1. Top-k subtoken predictions at every subtoken position.
        subtokens = mlm.tokenize(" ".join(uid.code), False, False)[0]
        mlm_logits = mlm.run(" ".join(uid.code))[0]
        _, align_token2sub = align_subtokens(uid.code, subtokens, bpe_indicator, specials)
        all_candidate_subtoken_logits, all_candidate_subtoken_ids = mlm_logits.topk(n_candidate, dim=-1)

        # 2. For each occurrence of ``vulnerable``, combine the top-k subtokens
        #    of its subtoken positions into candidate names. When the search
        #    gets too large, only the first occurrences are used (the breaks).
        candidates = []
        computational = 1
        for i, token in enumerate(uid.code):
            if vulnerable != token.strip():
                continue
            if n_candidate ** len(align_token2sub[i]) >= max_computational:
                break
            new_candidates = []
            for j in align_token2sub[i]:
                if j >= len_threshold:
                    break
                top_ids = all_candidate_subtoken_ids[j]
                top_logits = all_candidate_subtoken_logits[j]
                if len(new_candidates) <= 0:
                    # Word-initial subtoken: must carry the BPE indicator.
                    for idx, logit in zip(top_ids, top_logits):
                        if idx2txt[idx][0] == bpe_indicator:
                            new_candidates.append((idx2txt[idx][1:], logit.item()))
                else:
                    # Continuation subtoken: extend every partial name.
                    prefixes = new_candidates
                    new_candidates = []
                    for idx, logit in zip(top_ids, top_logits):
                        if idx2txt[idx][0] != bpe_indicator and idx2txt[idx] not in specials:
                            for _c, _l in prefixes:
                                new_candidates.append((_c + idx2txt[idx], logit.item() + _l))
            if len(new_candidates) <= 0:
                continue
            computational *= len(new_candidates)
            new_candidates = {name: score for name, score in new_candidates}
            # Score for names not proposed at this occurrence.
            new_candidates[None] = min(new_candidates.values()) * smoothing
            candidates.append(new_candidates)
            # NOTE(review): ``computational`` already includes this occurrence,
            # so it is effectively counted twice in this bound; kept so that the
            # same occurrences are considered.
            if computational * len(new_candidates) >= max_computational:
                break
        if len(candidates) <= 0:
            return None, None

        # 3. Merge all occurrences into one score per candidate name.
        merged_candidates = self._merge_position_candidates(candidates)

        # 4. Keep valid, non-special identifiers not already used in the snippet.
        names = []
        for _c, _ in merged_candidates[:batch_size]:
            c = _c.strip()
            if self.is_uid(c) and (not self.is_special_id(c)) and (c not in uid.sym2pos.keys()):
                names.append(c)
        if len(names) <= 0:
            return None, None

        # 5. Select top-k according to log-perplexity under the MLM.
        targets = mlm.tokenize(names, cut_and_pad=True, ret_id=True)
        logits = mlm.run(names, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        # NOTE(review): ``topk`` keeps the LARGEST log-perplexities (least fluent
        # names); confirm this is intended rather than ``largest=False``.
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))
        ret = [names[i] for i in selected_idx.tolist()]

        seq = []
        for name in ret:
            renamed = copy.deepcopy(uid.code)
            for pos in uid.sym2pos[vulnerable]:
                renamed[pos] = name
            seq.append(" ".join(renamed))
            if other_uid is not None:
                seq.append(" ".join(other_uid.code))
        return ret, seq
        
class CodeBERT_Attack_CMP(object):
    """Comparison-rewriting attack guided by a CodeBERT classifier and an MLM.

    Candidates are produced by toggling a comparison through
    ``cmp.update_cmp(idx)``. Calling it twice restores the original code, so
    it is presumably an involution -- confirm against the ``cmp`` module.
    """

    def __init__(self):
        """No state is needed."""
        pass

    def find_vulnerable_cmps(self, cls, cmp, ground_truth_label,
                             n_vul=10):
        """Rank comparisons by how much masking them lowers the true-label logit.

        Args:
            cls: classifier wrapper; ``cls.run(list_of_str)`` is indexed as
                ``[:, ground_truth_label]`` -- presumably (rows, n_labels).
            cmp: comparison structure with ``gen_mask()`` (list of masked token
                lists, one per comparison) and ``code`` (token list).
            ground_truth_label: column of the logits that is inspected.
            n_vul: maximum number of comparisons returned.

        Returns:
            list of comparison indices (ints) whose masking strictly lowers
            the logit, in descending order of that drop.
        """
        masked_code = cmp.gen_mask()
        batch_ = masked_code + [cmp.code]  # The last one is the original
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        score = logits[-1] - logits[:-1]  # Importance score of each comparison
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        return [int(i) for i in idxs.tolist() if score[i] > 0]

    def generate_candidates(self, cmp, mlm, vulnerable,
                            batch_size=32, n_candidate=10, criterion=None):
        """Score each toggled comparison by MLM log-perplexity and keep the top-k.

        Args:
            cmp: comparison structure with ``code``, ``cmptab`` and
                ``update_cmp(idx)``.
            mlm: masked-LM wrapper exposing ``device``, ``tokenize`` and ``run``.
            vulnerable: comparison indices to try.
            batch_size: MLM batch size.
            n_candidate: maximum number of candidates returned.
            criterion: per-token loss; defaults to unreduced cross-entropy.

        Returns:
            ``(indices, snippets, codes)`` -- the selected comparison indices,
            the toggled token span ``cmptab[idx][0]:cmptab[idx][2]`` for each,
            and the full toggled code for each -- or ``(None, None, None)`` when
            ``vulnerable`` is empty.
        """
        # Nothing to score: the MLM cannot run on an empty batch.
        if len(vulnerable) <= 0:
            return None, None, None
        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device

        # 1. Generate all candidate swapped sequences. ``cmp`` is toggled only
        #    temporarily; ``finally`` guarantees it is restored.
        candidates = []
        seqs = []
        for idx in vulnerable:
            cmp.update_cmp(idx)
            try:
                candidates.append(" ".join(cmp.code))
                seqs.append(" ".join(cmp.code[cmp.cmptab[idx][0]: cmp.cmptab[idx][2]]))
            finally:
                cmp.update_cmp(idx)

        # 2. Select top-k according to log-perplexity under the MLM.
        targets = mlm.tokenize(candidates, cut_and_pad=True, ret_id=True)
        logits = mlm.run(candidates, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        # NOTE(review): ``topk`` keeps the LARGEST log-perplexities; confirm this
        # is intended rather than ``largest=False``.
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))
        order = selected_idx.tolist()
        ret = [vulnerable[i] for i in order]
        seq = [seqs[i] for i in order]
        cand = [candidates[i] for i in order]
        return ret, seq, cand
                
class CodeBERT_Attack_WS(object):
    """Whitespace-insertion attack guided by a CodeBERT classifier and an MLM.

    A candidate perturbation prepends one string from ``self.ws`` to a single
    token of the snippet.
    """

    def __init__(self, ws=None):
        """Store the whitespace strings to try.

        Args:
            ws: sequence of strings to prepend; defaults to space, newline and
                tab. It is deep-copied, so later changes by the caller are not
                seen.
        """
        # ``None`` sentinel instead of a list default, so no list object is
        # shared between calls.
        if ws is None:
            ws = [" ", "\n", "\t"]
        self.ws = copy.deepcopy(ws)

    def find_vulnerable_tokens(self, cls, ws, ground_truth_label,
                               other_ws=None, n_vul=5, n_mask=15):
        """Rank token positions by how much masking them lowers the true-label logit.

        Args:
            cls: classifier wrapper; ``cls.run(list_of_str)`` is indexed as
                ``[:, ground_truth_label]`` -- presumably (rows, n_labels).
            ws: token structure with ``gen_mask(n_mask)`` (position -> masked
                token list) and ``code`` (token list).
            ground_truth_label: column of the logits that is inspected.
            other_ws: optional second snippet. When given, every input is sent
                as two consecutive batch entries (snippet, other snippet);
                presumably ``cls.run`` yields one row per pair -- confirm.
            n_vul: maximum number of positions returned.
            n_mask: forwarded to ``ws.gen_mask``.

        Returns:
            dict mapping each selected position to its masked token list, for
            positions whose masking strictly lowers the logit.
        """
        masked_code = ws.gen_mask(n_mask)
        masked_idxs = list(masked_code.keys())
        if other_ws is None:
            batch_ = [masked_code[i] for i in masked_idxs] + [ws.code]  # The last one is the original
            stride = 1
        else:
            batch_ = []
            for i in masked_idxs:
                batch_.append(masked_code[i])
                batch_.append(other_ws.code)
            batch_.append(ws.code)  # The last pair is the original
            batch_.append(other_ws.code)
            stride = 2  # batch entries per classifier row
        batch = [" ".join(s) for s in batch_]
        logits = cls.run(batch)[:, ground_truth_label]
        score = logits[-1] - logits[:-1]  # Importance score of each token position
        n_selected = int(min(len(score), n_vul))
        _, idxs = score.topk(n_selected)
        ret = {}
        for i in idxs.tolist():
            if score[i] > 0:
                ret[masked_idxs[i]] = batch_[stride * i]
        return ret

    def generate_candidates(self, ws, mlm, vulnerable,
                            n_candidate=10, batch_size=32,
                            criterion=None, len_threshold=512, other_ws=None):
        """Score whitespace insertions by MLM log-perplexity and keep the top-k.

        Args:
            ws: token structure with ``code`` (token list).
            mlm: masked-LM wrapper exposing ``device``, ``tokenize`` and ``run``.
            vulnerable: mapping whose keys are the token positions to perturb.
            n_candidate: maximum number of candidates returned.
            batch_size: MLM batch size.
            criterion: per-token loss; defaults to unreduced cross-entropy.
            len_threshold: accepted for interface compatibility; not used.
            other_ws: optional second snippet appended after each candidate in
                the returned sequences.

        Returns:
            ``(pairs, seqs)`` -- ``(position, whitespace)`` for each selected
            candidate and the perturbed snippet (followed by ``other_ws``'s
            snippet when given) -- or ``(None, None)`` when there is nothing
            to score.
        """
        # 1. Generate all candidates (white space inserted).
        candidates, idxs, wss = [], [], []
        for v in vulnerable.keys():
            for w in self.ws:
                tokens = copy.deepcopy(ws.code)
                tokens[v] = w + tokens[v]
                candidates.append(" ".join(tokens))
                idxs.append(v)
                wss.append(w)
        # The MLM cannot run on an empty batch.
        if len(candidates) <= 0:
            return None, None

        if criterion is None:
            criterion = CrossEntropyLoss(reduction="none")
        device = mlm.device

        # 2. Select top-k according to log-perplexity under the MLM (one pass).
        targets = mlm.tokenize(candidates, cut_and_pad=True, ret_id=True)
        logits = mlm.run(candidates, batch_size)
        targets = torch.tensor(targets).to(device)[:, :logits.shape[1]]
        log_ppl = criterion(logits.permute([0, 2, 1]), targets).mean(dim=-1)
        # NOTE(review): ``topk`` keeps the LARGEST log-perplexities; confirm this
        # is intended rather than ``largest=False``.
        _, selected_idx = log_ppl.topk(min(n_candidate, len(log_ppl)))

        ret, seq = [], []
        for i in selected_idx.tolist():
            ret.append((idxs[i], wss[i]))
            seq.append(candidates[i])
            if other_ws is not None:
                seq.append(" ".join(other_ws.code))
        return ret, seq
    
    
    
if __name__ == "__main__":

    from uid import UIDStruct
    from cmp import CMPStruct
    from codebert import codebert_mlm, codebert_cls
    from utils import normalize, denormalize
    from cparser import CCode

    import argparse
    import os
    import pickle
    import json
    import time

    # Smoke test: run the comparison attack on every correctly classified test
    # example and check that the rewritten code still parses via CCode.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj/checkpoint-9000-1.0555",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")
    parser.add_argument('--testset', type=str,
                        default="../data/test.pkl")

    opt = parser.parse_args()

    if int(opt.gpu) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")
    mlm_model = codebert_mlm(opt.mlm_path, device)
    vocab_path = mlm_model.tokenizer.vocab_files_names["vocab_file"]
    # The vocabulary holds non-ASCII BPE markers (e.g. 'Ġ'); decode it as UTF-8
    # regardless of the platform's default encoding.
    with open(os.path.join(opt.mlm_path, vocab_path), "r", encoding="utf-8") as f:
        txt2idx = json.load(f)
    tmp = sorted(txt2idx.items(), key=lambda it: it[1])
    idx2txt = [it[0] for it in tmp]
    # Sanity check: subtoken ids must be contiguous, 0 .. len(vocab) - 1.
    assert txt2idx[idx2txt[-1]] == len(idx2txt) - 1, \
        "\n"+idx2txt[-1]+"\n"+str(txt2idx[idx2txt[-1]])+"\n"+str(len(idx2txt)-1)
    cls_model = codebert_cls(opt.cls_path, device)
    atk = CodeBERT_Attack_CMP()
    #atk = CodeBERT_Attack_UID()
    # NOTE: pickle can execute arbitrary code -- only load trusted test sets.
    with open(opt.testset, "rb") as f:
        d = pickle.load(f)

    len_threshold = -1  # <= 0 disables the length filter
    times_per_example = 5
    time_st = time.time()

    for i in range(len(d['src'])):
        if len_threshold > 0 and len(d['src'][i]) >= len_threshold:
            continue
        s = [t.strip() for t in d['src'][i]]
        s_norm = normalize(s)
        y = d['label'][i] - 1  # labels in the test set are 1-based
        logits = cls_model.run([" ".join(s_norm)])
        # Only attack examples the classifier gets right.
        if logits[0].argmax().item() != y:
            continue
        cmp = CMPStruct(s, mask=mlm_model.tokenizer.unk_token)
        #uid = UIDStruct(s, mask=mlm_model.tokenizer.unk_token)
        for _ in range(times_per_example):
            vulnerables = atk.find_vulnerable_cmps(cls_model, cmp, y)
            #vulnerables = atk.find_vulnerable_uids(cls_model, uid, y)
            if len(vulnerables) <= 0:
                continue
            candidates, _, _ = atk.generate_candidates(cmp, mlm_model, vulnerables, n_candidate=10)
            #candidates, _ = atk.generate_candidates(uid, mlm_model, list(vulnerables.keys())[0],
            #                                        idx2txt, n_candidate=10, smoothing=0.1)
            if candidates is None:
                continue
            cmp.update_cmp(candidates[0])
            #uid.update_sym(list(vulnerables.keys())[0], candidates[0])
        assert CCode(" ".join(denormalize(cmp.code))), \
            "\n"+" ".join(s_norm)+"\n"+" ".join(cmp.code)
        print ("\r%.1f min | %6s / %6s |" % ((time.time()-time_st)/60., str(i), str(len(d['src']))), end="\r")

    print ("\n%.1f min" % ((time.time()-time_st)/60.))