# -*- coding: utf-8 -*-
"""
Created on Mon Nov 2 16:23:52 2020

@author: DrLC
"""

from utils import is_uid, is_special_id, is_java_uid, is_java_special_id
from torch.nn import Softmax

import random
import torch
import copy


class MHM_Baseline(object):
    """Identifier-renaming search against a classifier.

    Each step picks one identifier of the program at random, builds variants
    in which that identifier is replaced by sampled candidate names, queries
    the target classifier on all variants and either

    * reports success as soon as one variant is predicted differently from
      ``label``,
    * accepts the rename that lowers the probability of ``label`` the most
      (subject to a random acceptance test on the ratio ``alpha``), or
    * rejects the step.
    """

    def __init__(self, uids, lang="c", prob_threshold=None):
        """Store the candidate pool and select the identifier predicates.

        Args:
            uids: Pool of candidate identifier names; it is passed to
                ``random.sample`` so it must be acceptable to that function.
            lang: ``"c"`` or ``"java"`` (case-insensitive); selects which
                ``is_uid`` / ``is_special_id`` helpers from ``utils`` are used.
            prob_threshold: If ``None`` the prediction is the arg-max class.
                Otherwise the prediction is ``1`` when the probability in
                column 1 exceeds this value and ``0`` otherwise.

        Raises:
            AssertionError: If ``lang`` is not a supported language.
        """
        self.all_uids = uids
        self.prob_threshold = prob_threshold
        self.softmax = Softmax(dim=-1)
        checkers = {
            "C": (is_uid, is_special_id),
            "JAVA": (is_java_uid, is_java_special_id),
        }
        try:
            self.is_uid, self.is_special_id = checkers[lang.upper()]
        except KeyError:
            # Raised explicitly (rather than ``assert False``) so the check
            # survives ``python -O``; the exception type is kept for callers.
            raise AssertionError("unsupported language: %r" % (lang,)) from None

    def mcmc(self, uid, label, classifier, n_candi=30,
             max_iter=100, prob_threshold=0.999, other_uid=None):
        """Run up to ``max_iter`` rename steps until the prediction changes.

        Args:
            uid: Program wrapper. This method reads ``uid.sym2pos`` and calls
                ``uid.update_sym(old, new)``, whose truthy return value is
                treated as "rename applied".
            label: The class the classifier should stop predicting.
            classifier: Object whose ``run(list_of_str)`` output is fed to a
                softmax over the last dimension.
            n_candi: Number of names sampled from the pool per step.
            max_iter: Maximum number of steps.
            prob_threshold: Minimum acceptance ratio ``alpha`` for a
                non-successful step to be accepted.
            other_uid: Optional second program whose code is interleaved after
                every variant (presumably for pair-input classifiers -- TODO
                confirm against the classifier implementation).

        Returns:
            ``{'succ': True, 'seqs': <str>, 'iter': <int>}`` on success,
            otherwise ``{'succ': False, 'seqs': None, 'iter': None}``.  When
            the program has no renamable identifier the failure dict also
            carries the legacy keys ``'tokens'`` and ``'raw_tokens'`` (both
            ``None``).
        """
        if not uid.sym2pos:
            # Keep the legacy keys and add the ones every other exit path
            # provides, so callers can always read 'seqs' / 'iter'.
            return {'succ': False, 'seqs': None, 'iter': None,
                    'tokens': None, 'raw_tokens': None}

        for iteration in range(1, 1 + max_iter):
            res = self.__replaceUID(uid=uid, label=label, tgt_cls=classifier,
                                    n_candi=n_candi,
                                    prob_threshold=prob_threshold,
                                    other_uid=other_uid)
            self.__printRes(_iter=iteration, _res=res, _prefix=" >> ")
            status = res['status'].lower()
            if status not in ('s', 'a'):
                continue
            # Accepted (including success): apply the rename to the program.
            # The call must not live inside an ``assert`` -- it would be
            # skipped entirely under ``python -O``.
            try:
                updated = uid.update_sym(res['old_uid'], res['new_uid'])
            except Exception:
                updated = False
            if not updated:
                return {'succ': False, 'seqs': None, 'iter': None}
            if status == 's':
                return {'succ': True, 'seqs': res['seqs'], 'iter': iteration}
        return {'succ': False, 'seqs': None, 'iter': None}

    def __replaceUID(self, uid, label, tgt_cls, n_candi=30,
                     prob_threshold=0.95, other_uid=None):
        """Propose renames for one random identifier and score them.

        Returns:
            A dict with keys ``status`` (``'s'`` success, ``'a'`` accepted,
            ``'r'`` rejected), ``alpha``, ``seqs``, ``old_uid``, ``new_uid``,
            ``old_prob``, ``new_prob``, ``old_pred`` and ``new_pred``.  For
            ``'s'`` the probability/prediction entries are torch tensors; for
            ``'a'`` / ``'r'`` they are numpy values.
        """
        selected_uid = random.sample(list(uid.sym2pos.keys()), 1)[0]
        token_seq = uid.code
        positions = uid.sym2pos[selected_uid]
        other_seq = None if other_uid is None else " ".join(other_uid.code)
        # With ``other_uid`` every variant is followed by the other program,
        # so variant k sits at index 2*k of ``candi_seq``.
        stride = 1 if other_uid is None else 2

        # First, generate the candidate set.  Index 0 is the unmodified
        # program; all candidates share the same (uniform) proposal
        # probability.
        candi_token = [selected_uid]
        candi_seq = [" ".join(token_seq)]
        if other_seq is not None:
            candi_seq.append(other_seq)
        # Never ask for more names than the pool holds.
        n_sample = min(n_candi, len(self.all_uids))
        for c in random.sample(self.all_uids, n_sample):
            if (not self.is_uid(c)) or self.is_special_id(c) or c in uid.sym2pos:
                continue
            candi_token.append(c)
            tmp_seq = list(token_seq)
            for pos in positions:
                tmp_seq[pos] = c
            candi_seq.append(" ".join(tmp_seq))
            if other_seq is not None:
                candi_seq.append(other_seq)

        # Then, feed all candidates to the target classifier.
        prob = self.softmax(tgt_cls.run(candi_seq))
        if self.prob_threshold is None:
            pred = torch.argmax(prob, dim=1)
        else:
            pred = (prob[:, 1] > self.prob_threshold).int()

        def build(status, alpha, idx, probs, preds):
            """Assemble the result dict for candidate ``idx``."""
            return {"status": status, "alpha": alpha,
                    "seqs": candi_seq[idx * stride],
                    "old_uid": selected_uid, "new_uid": candi_token[idx],
                    "old_prob": probs[0], "new_prob": probs[idx],
                    "old_pred": preds[0], "new_pred": preds[idx]}

        # Any variant predicted differently from ``label`` is a valid example.
        # Index 0 is included, so an input that is already predicted
        # differently reports success with old_uid == new_uid.
        for idx in range(len(candi_token)):
            if pred[idx] != label:
                return build("s", 1, idx, prob, pred)

        if len(candi_token) == 1:
            # No sampled name passed the filter: there is nothing to compare
            # against, so reject this step and let the caller resample.
            return build("r", 0.0, 0,
                         prob.cpu().detach().numpy(),
                         pred.cpu().detach().numpy())

        # If not successful, choose the strongest candidate, i.e. the variant
        # with the lowest probability for ``label`` (row 0 is excluded).
        candi_idx = int(torch.argmin(prob[1:, label]).item()) + 1

        # At last, compute the acceptance rate.
        prob = prob.cpu().detach().numpy()
        pred = pred.cpu().detach().numpy()
        alpha = (1 - prob[candi_idx][label] + 1e-10) / (1 - prob[0][label] + 1e-10)
        rejected = random.uniform(0, 1) > alpha or alpha < prob_threshold
        # Report the candidate that ``alpha`` was computed for, not whichever
        # candidate the success loop above happened to visit last.
        return build("r" if rejected else "a", alpha, candi_idx, prob, pred)

    def __printRes(self, _iter=None, _res=None, _prefix=" => "):
        """Print a one-line summary of a step result; unknown statuses print nothing."""
        tags = {'s': "SUCC!", 'r': "REJ.", 'a': "ACC!"}
        tag = tags.get(_res['status'].lower())
        if tag is None:
            return
        old_pred = int(_res['old_pred'])
        new_pred = int(_res['new_pred'])
        # Both probabilities shown are for the ORIGINAL predicted class.
        print("%s iter %d, %s %s => %s (%d => %d, %.5f => %.5f) a=%.3f" %
              (_prefix, _iter, tag, _res['old_uid'], _res['new_uid'],
               old_pred, new_pred,
               float(_res['old_prob'][old_pred]),
               float(_res['new_prob'][old_pred]),
               _res['alpha']), flush=True)


if __name__ == "__main__":
    pass