# -*- coding: utf-8 -*-
"""
Created on Mon Dec 28 13:53:05 2020

@author: DrLC

Plot the cumulative attack success rate against the attack iteration,
one curve per attack log.
"""

import pickle
import copy
import tqdm
import argparse
import torch
import os
import numpy
import matplotlib.pyplot as plt


def read_log(lines):
    """Split raw log lines into per-sample blocks.

    A new block starts at every line whose first whitespace-separated token
    is ``"Attack"``. Lines before the first such header are discarded, and
    blank lines are skipped.

    Args:
        lines: iterable of raw log lines (e.g. from ``file.readlines()``).

    Returns:
        list of blocks; each block is a list of stripped lines whose first
        element is its ``"Attack"`` header. Empty if no header is found.
    """
    blocks = []
    for line in lines:
        words = line.split()
        if not words:
            # Blank line: nothing to record (split()[0] would raise here).
            continue
        if words[0] == "Attack":
            blocks.append([])
        elif not blocks:
            # Preamble before the first header. Use the same first-token test
            # as the block split, so a preamble line that merely contains
            # "Attack" cannot start a block.
            continue
        blocks[-1].append(line.strip())
    return blocks


def parse_rep_log_block(lines):
    """Extract the token replacements recorded in one attack block.

    Each line containing a standalone ``=>`` token contributes the pair
    ``(token before "=>", token after "=>")``. Commas are stripped from both
    ends of every token first.

    Args:
        lines: one block as produced by :func:`read_log`.

    Returns:
        list of ``(old, new)`` tuples in log order, or ``None`` as soon as
        the block contains a bare ``FAIL!`` line or a ``WRONG. SKIP!`` line.
    """
    rep = []
    for line in lines:
        words = [w.strip(',') for w in line.split()]
        if words == ["FAIL!"] or ("WRONG." in words and "SKIP!" in words):
            return None
        if "=>" not in words:
            continue
        idx = words.index("=>")
        # A "=>" with no token on one side is malformed. Skip it rather than
        # let words[idx - 1] wrap to words[-1] when idx == 0.
        if idx == 0 or idx == len(words) - 1:
            continue
        rep.append((words[idx - 1], words[idx + 1]))
    return rep


def is_fp(token):
    """Return True if *token* looks like a floating-point literal.

    An optional trailing ``f``/``F``/``l``/``L`` suffix (e.g. ``1.5f``) is
    ignored. The token must start with a digit or ``.`` and contain ``.`` or
    an exponent. This rejects identifiers such as ``inf``/``nan``, which
    ``float()`` would otherwise accept, and plain integers.
    """
    body = token.rstrip("fFlL")
    if not body or not (body[0].isdigit() or body[0] == "."):
        return False
    if "." not in body and "e" not in body.lower():
        return False
    try:
        float(body)
    except ValueError:
        return False
    return True


def normalize(seq):
    """Replace literal tokens in *seq* with type placeholders.

    Tokens containing ``'`` become ``<char>``, tokens containing ``"``
    become ``<str>``, all-digit or ``0x``-prefixed tokens become ``<int>``,
    tokens accepted by :func:`is_fp` become ``<fp>``. All other tokens are
    kept unchanged.

    Returns:
        a new list; *seq* is not modified.
    """
    norm = []
    for t in seq:
        if "'" in t:
            norm.append("<char>")
        elif '"' in t:
            norm.append("<str>")
        elif t.isdigit() or t[:2] == "0x":
            norm.append("<int>")
        elif is_fp(t):
            norm.append("<fp>")
        else:
            norm.append(t)
    return norm


# Representative literal emitted for each placeholder by denormalize().
_PLACEHOLDER_LITERALS = {
    "<char>": "'0'",
    "<str>": "\"0\"",
    "<int>": "0",
    "<fp>": "0.0",
}


def denormalize(seq):
    """Map placeholders produced by :func:`normalize` back to sample literals.

    ``<char>`` -> ``'0'``, ``<str>`` -> ``"0"``, ``<int>`` -> ``0`` and
    ``<fp>`` -> ``0.0``. All other tokens are kept unchanged.

    Returns:
        a new list; *seq* is not modified.
    """
    return [_PLACEHOLDER_LITERALS.get(t, t) for t in seq]


if __name__ == "__main__":
    # Alternative configuration (LSTM logs), kept for reference:
    # logs = ["mhm_attack_lstm.log", "mhm_attack_lstm_large.log",
    #         "cba_attack_lstm.log", "sacba_attack_lstm.log"]
    # labels = ["MHM", "MHM(50)", "CBA", "CBA-SA"]
    # line_types = ['r:', 'y-.', 'g-*', 'b-o']
    # fig_path = "E:\\CodeBERT Attack Paper\\lstm_succ_rate.png"
    # yticks_max = 50
    # yticks_each = 10

    logs = ["mhm_attack_cbmlm.log", "cba_attack_cbmlm.log",
            "sacba_attack_cbmlm.log"]
    labels = ["MHM", "CBA", "CBA-SA"]
    line_types = ['r:', 'g-*', 'b-o']
    fig_path = "E:\\CodeBERT Attack Paper\\cbmlm_succ_rate.png"
    yticks_max = 15
    yticks_each = 5
    xticks_each = 5
    max_iter = 20
    label_size = 24
    tick_size = 17
    legend_size = 17

    for log, label, line_type in zip(logs, labels, line_types):
        with open(log, "r") as f:
            lines = f.readlines()
        blocks = read_log(lines)

        # Sanity-check the expected log layout before trusting the parse.
        for b in blocks:
            head = b[0].split()
            tail = b[-1].split()
            assert head[0] == "Attack"
            assert head[2] == '/'
            assert head[6] == 'Class'
            # head[3] must equal the number of blocks, or half of it.
            assert (float(head[3]) == len(blocks)
                    or float(head[3]) * 2 == len(blocks))
            if b[1] == 'WRONG. SKIP!':
                continue
            assert b[-2] in ["FAIL!", "SUCC!"]
            assert tail[0] == 'Succ'
            assert tail[1] == '%'

        reps = [parse_rep_log_block(b) for b in blocks]
        if not reps:
            # Avoid dividing by zero below; nothing to plot for this log.
            print("No attack blocks found in", log, "- skipping.")
            continue

        # succ[j] counts the blocks that succeeded within j iterations: a
        # block listing k replacements counts toward every succ[j] with
        # j >= k. Blocks parsed as None (FAIL!/SKIP) never count.
        succ = [0] * (max_iter + 1)
        for rep in reps:
            if rep is None:
                continue
            for j in range(len(rep), max_iter + 1):
                succ[j] += 1
        # The rate's denominator includes failed and skipped blocks.
        succ = [i / len(reps) * 100 for i in succ]
        plt.plot(range(max_iter + 1), succ, line_type, label=label)

    xticks = [i * xticks_each for i in range(max_iter // xticks_each + 1)]
    yticks = [i * yticks_each for i in range(yticks_max // yticks_each + 1)]
    plt.xticks(xticks, xticks, fontsize=tick_size)
    plt.yticks(yticks, yticks, fontsize=tick_size)
    plt.xlabel("Iteration", fontsize=label_size)
    plt.ylabel("Succ (%)", fontsize=label_size)
    plt.legend(loc="lower right", fontsize=legend_size)
    plt.savefig(fig_path, bbox_inches='tight')
    plt.show()