# CodeBERT-Attack / oj-attack / mhm.py
# -*- coding: utf-8 -*-
"""
Created on Mon Nov  2 16:23:52 2020

@author: DrLC
"""

from utils import is_uid, is_special_id, is_java_uid, is_java_special_id
from torch.nn import Softmax

import random
import torch
import copy

class MHM_Baseline(object):
    """Metropolis-Hastings style identifier-renaming attack (MHM baseline).

    Each iteration picks one identifier of the victim program, proposes
    replacement names sampled from ``all_uids``, queries the target classifier
    on every rewritten program, and then either stops (prediction flipped),
    accepts the strongest proposal, or rejects it.
    """

    def __init__(self, uids, lang="c", prob_threshold=None):
        """
        :param uids: pool of candidate identifier names; drawn from with
            ``random.sample``, so it must be a sequence.
        :param lang: "c" or "java" (case-insensitive); selects which
            identifier-validity predicates are used.
        :param prob_threshold: if None, the predicted class is the argmax of
            the softmax output; otherwise the prediction is 1 when the
            class-1 probability exceeds this value and 0 otherwise.
        :raises ValueError: for an unsupported ``lang``.
        """
        self.all_uids = uids
        self.prob_threshold = prob_threshold
        self.softmax = Softmax(dim=-1)
        lang = lang.upper()
        if lang == "C":
            self.is_uid = is_uid
            self.is_special_id = is_special_id
        elif lang == "JAVA":
            self.is_uid = is_java_uid
            self.is_special_id = is_java_special_id
        else:
            # An explicit raise (rather than ``assert False``) still fires under
            # ``python -O``; otherwise the object would be left half-built.
            raise ValueError("unsupported language: %r" % lang)

    def mcmc(self, uid, label, classifier, n_candi=30,
             max_iter=100, prob_threshold=0.999, other_uid=None):
        """Run the renaming attack for at most ``max_iter`` iterations.

        :param uid: program wrapper; assumed to expose ``sym2pos`` (identifier
            -> token positions), ``code`` (token list) and
            ``update_sym(old, new)`` returning a truthy value on success --
            TODO confirm against the utils implementation.
        :param label: class the attack tries to move the prediction away from.
        :param classifier: target model exposing ``run(list_of_code_strings)``.
        :param n_candi: number of replacement names sampled per iteration.
        :param max_iter: iteration budget.
        :param prob_threshold: minimum acceptance ratio for a non-successful
            proposal to be accepted.
        :param other_uid: optional second program whose code is paired with
            every query.
        :returns: ``{'succ': bool, 'seqs': str | None, 'iter': int | None}``,
            where ``seqs`` is the adversarial program text and ``iter`` the
            iteration that produced it. The empty-program early exit also
            carries the legacy ``'tokens'``/``'raw_tokens'`` keys.
        """
        failure = {'succ': False, 'seqs': None, 'iter': None}
        if len(uid.sym2pos) <= 0:
            # Nothing to rename. 'tokens'/'raw_tokens' are kept for callers of
            # the old return shape; 'seqs'/'iter' match every other exit path.
            return {'succ': False, 'tokens': None, 'raw_tokens': None,
                    'seqs': None, 'iter': None}
        for iteration in range(1, 1 + max_iter):
            res = self.__replaceUID(uid=uid,
                                    label=label,
                                    tgt_cls=classifier,
                                    n_candi=n_candi,
                                    prob_threshold=prob_threshold,
                                    other_uid=other_uid)
            self.__printRes(_iter=iteration, _res=res, _prefix="  >> ")
            status = res['status'].lower()
            if status not in ('s', 'a'):   # rejected: keep the current program
                continue
            # Call update_sym outside of an ``assert``: asserts are stripped
            # under ``python -O`` and the rename would silently never happen.
            try:
                updated = uid.update_sym(res['old_uid'], res['new_uid'])
            except Exception:
                return failure
            if not updated:
                return failure
            if status == 's':
                return {'succ': True, 'seqs': res['seqs'], 'iter': iteration}
        return failure

    def __replaceUID(self, uid, label, tgt_cls,
                     n_candi=30, prob_threshold=0.95, other_uid=None):
        """Propose renamings of one randomly chosen identifier.

        :returns: dict with ``status`` ('s' success, 'a' accepted,
            'r' rejected), ``alpha``, ``seqs`` (program text of the reported
            candidate), ``old_uid``/``new_uid``, ``old_prob``/``new_prob`` and
            ``old_pred``/``new_pred``. Probability row 0 is always the
            unmodified program.
        """
        selected_uid = random.sample(list(uid.sym2pos.keys()), 1)[0]
        token_seq = uid.code
        positions = uid.sym2pos[selected_uid]
        existing = uid.sym2pos.keys()
        # In pair mode every program string is followed by the partner's code,
        # so candidate k sits at candi_seq[k * stride] while its probabilities
        # sit at row k (presumably tgt_cls.run consumes consecutive pairs --
        # confirm against the classifier wrapper).
        stride = 1 if other_uid is None else 2
        partner = None if other_uid is None else " ".join(other_uid.code)

        # Candidate set; index 0 is the unmodified program. Proposals are
        # uniform over the sampled names.
        candi_token = [selected_uid]
        candi_seq = [" ".join(token_seq)]
        if partner is not None:
            candi_seq.append(partner)
        # random.sample raises ValueError if asked for more items than exist.
        n_sample = min(n_candi, len(self.all_uids))
        for c in random.sample(self.all_uids, n_sample):
            if self.is_uid(c) and (not self.is_special_id(c)) and (c not in existing):
                candi_token.append(c)
                tmp_seq = list(token_seq)
                for pos in positions:
                    tmp_seq[pos] = c
                candi_seq.append(" ".join(tmp_seq))
                if partner is not None:
                    candi_seq.append(partner)

        # Query the target classifier on all candidates at once.
        prob = self.softmax(tgt_cls.run(candi_seq))
        if self.prob_threshold is None:
            pred = torch.argmax(prob, dim=1)
        else:
            pred = (prob[:, 1] > self.prob_threshold).int()

        def pack(status, alpha, k, prob, pred):
            # Result dict describing candidate k (all fields refer to the SAME k).
            return {"status": status, "alpha": alpha, "seqs": candi_seq[k * stride],
                    "old_uid": selected_uid, "new_uid": candi_token[k],
                    "old_prob": prob[0], "new_prob": prob[k],
                    "old_pred": pred[0], "new_pred": pred[k]}

        for k in range(len(candi_token)):
            if pred[k] != label:   # prediction flipped: adversarial example found
                return pack("s", 1, k, prob, pred)

        if len(candi_token) == 1:
            # No valid replacement name was sampled; argmin over an empty
            # slice would raise, so just reject and keep the identifier.
            return pack("r", 0.0, 0,
                        prob.cpu().detach().numpy(), pred.cpu().detach().numpy())

        # Otherwise pick the candidate that most lowers P(label) ...
        candi_idx = int((torch.argmin(prob[1:, label]) + 1).item())
        prob = prob.cpu().detach().numpy()
        pred = pred.cpu().detach().numpy()
        # ... and decide with the acceptance ratio.
        alpha = (1 - prob[candi_idx][label] + 1e-10) / (1 - prob[0][label] + 1e-10)
        if random.uniform(0, 1) > alpha or alpha < prob_threshold:
            return pack("r", alpha, candi_idx, prob, pred)
        return pack("a", alpha, candi_idx, prob, pred)

    def __printRes(self, _iter=None, _res=None, _prefix="  => "):
        """Print a one-line progress report for a ``__replaceUID`` result.

        Unknown status codes print nothing.
        """
        tag = {'s': "SUCC!", 'r': "REJ.", 'a': "ACC!"}.get(_res['status'].lower())
        if tag is None:
            return
        old_pred = _res['old_pred']
        # Both probabilities are reported for the originally predicted class.
        print("%s iter %d, %s %s => %s (%d => %d, %.5f => %.5f) a=%.3f" %
              (_prefix, _iter, tag, _res['old_uid'], _res['new_uid'],
               old_pred, _res['new_pred'],
               _res['old_prob'][old_pred], _res['new_prob'][old_pred],
               _res['alpha']), flush=True)
            
if __name__ == "__main__":
    # This module only defines MHM_Baseline; running it directly does nothing.
    ...