CodeBERT-Attack / oj-attack / log / process_log.py
process_log.py
Raw
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 28 13:53:05 2020

@author: DrLC
"""

import argparse
import copy
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy
import torch
import tqdm

def read_log(lines):
    """Group raw log lines into blocks, one block per ``Attack`` header.

    A block begins at a line whose first whitespace-separated token is
    exactly ``Attack`` and runs up to (but not including) the next such
    line.  Everything before the first header is discarded, and blank
    lines are skipped.

    Args:
        lines: Iterable of raw log lines (e.g. from ``file.readlines()``).

    Returns:
        A list of blocks; each block is a list of whitespace-stripped lines
        whose first entry is the ``Attack`` header.  The list is empty when
        the input contains no header line.
    """
    blocks = []
    for line in lines:
        words = line.split()
        if not words:
            # Blank line: nothing to classify and nothing worth storing.
            continue
        if words[0] == "Attack":
            blocks.append([])
        elif not blocks:
            # Still in the preamble before the first "Attack" header.
            continue
        blocks[-1].append(line.strip())
    return blocks
    
def parse_rep_log_block(lines):
    """Extract the ``left => right`` pairs recorded in one log block.

    Each line is split on whitespace and commas are stripped from both ends
    of every word.  For a line containing a ``=>`` word, the words directly
    before and after the first ``=>`` are collected as a pair (presumably
    an "old => new" token replacement -- confirm against the log writer).

    Args:
        lines: The lines of a single block, as produced by ``read_log``.

    Returns:
        ``None`` as soon as a line consisting solely of ``FAIL!`` or a line
        containing both ``WRONG.`` and ``SKIP!`` is met; otherwise the list
        of ``(before, after)`` tuples in the order they appear (possibly
        empty).
    """
    rep = []
    for line in lines:
        words = [w.strip(',') for w in line.split()]
        if words == ["FAIL!"] or ("WRONG." in words and "SKIP!" in words):
            return None
        if "=>" not in words:
            continue
        idx = words.index("=>")
        # A "=>" with no word on one side is malformed: idx - 1 would wrap
        # around to the last word and idx + 1 would run off the end.
        if idx == 0 or idx == len(words) - 1:
            continue
        rep.append((words[idx - 1], words[idx + 1]))
    return rep

# Unsigned C-style floating-point literal: it needs a decimal point and/or an
# exponent, and may carry a single f/F/l/L suffix.
_FP_LITERAL = re.compile(
    r"(?:(?:\d+\.\d*|\.\d+)(?:[eE][+-]?\d+)?|\d+[eE][+-]?\d+)[fFlL]?"
)


def _is_fp(token):
    """Return True when ``token`` is an unsigned floating-point literal."""
    return _FP_LITERAL.fullmatch(token) is not None


def normalize(seq):
    """Replace literal tokens in ``seq`` with placeholder tokens.

    The checks are applied in this order, first match wins:

    * contains a single quote      -> ``<char>``
    * contains a double quote      -> ``<str>``
    * all digits, or starts with ``0x`` / ``0X`` -> ``<int>``
    * floating-point literal       -> ``<fp>``
    * anything else is kept unchanged.

    Args:
        seq: Iterable of token strings.

    Returns:
        A new list of the same length with literals replaced.
    """
    norm = []
    for t in seq:
        if "'" in t:
            norm.append("<char>")
        elif '"' in t:
            norm.append("<str>")
        elif t.isdigit() or t[:2] in ("0x", "0X"):
            norm.append("<int>")
        elif _is_fp(t):
            norm.append("<fp>")
        else:
            norm.append(t)
    return norm

def denormalize(seq):
    """Swap placeholder tokens for a fixed concrete literal.

    ``<char>`` becomes ``'0'``, ``<str>`` becomes ``"0"``, ``<int>`` becomes
    ``0`` and ``<fp>`` becomes ``0.0``; every other token is passed through
    untouched.

    Args:
        seq: Iterable of token strings.

    Returns:
        A new list of the same length with placeholders substituted.
    """
    literal_for = {
        "<char>": "'0'",
        "<str>": "\"0\"",
        "<int>": "0",
        "<fp>": "0.0",
    }
    return [literal_for.get(token, token) for token in seq]

if __name__ == "__main__":

    # Alternative plot configuration, kept as an inert string literal.
    '''
    logs = ["mhm_attack_lstm.log",
            "mhm_attack_lstm_large.log",
            "cba_attack_lstm.log",
            "sacba_attack_lstm.log"]
    labels = ["MHM", "MHM(50)", "CBA", "CBA-SA"]
    line_types = ['r:', 'y-.', 'g-*', 'b-o']
    fig_path = "E:\\CodeBERT Attack Paper\\lstm_succ_rate.png"
    yticks_max = 50
    yticks_each = 10
    '''
    # Active configuration: one log file, legend label and line style per curve.
    logs = ["mhm_attack_cbmlm.log",
            "cba_attack_cbmlm.log",
            "sacba_attack_cbmlm.log"]
    labels = ["MHM", "CBA", "CBA-SA"]
    line_types = ['r:', 'g-*', 'b-o']
    fig_path = "E:\\CodeBERT Attack Paper\\cbmlm_succ_rate.png"
    yticks_max = 15
    yticks_each = 5

    # Plot geometry and font sizes.
    max_iter = 20
    label_size = 24
    tick_size = 17
    legend_size = 17

    for log_path, curve_label, curve_style in zip(logs, labels, line_types):
        with open(log_path, "r") as handle:
            raw_lines = handle.readlines()

        blocks = read_log(raw_lines)
        n_blocks = len(blocks)

        # Sanity-check the layout of every block before trusting the parse.
        for block in blocks:
            header = block[0].split()
            footer = block[-1].split()
            assert header[0] == "Attack"
            assert header[2] == '/'
            assert header[6] == 'Class'
            # The fourth header token must equal the number of blocks, or
            # exactly half of it.
            assert float(header[3]) == n_blocks or float(header[3]) * 2 == n_blocks
            if block[1] != 'WRONG. SKIP!':
                # Non-skipped blocks end with a verdict line followed by a
                # "Succ % ..." summary line.
                assert block[-2] in ["FAIL!", "SUCC!"]
                assert footer[0] == 'Succ'
                assert footer[1] == '%'

        replacements = [parse_rep_log_block(block) for block in blocks]

        # hits[k] = number of blocks that yielded a pair list of length <= k
        # (blocks parsed as None are never counted).
        hits = [
            sum(1 for r in replacements if r is not None and len(r) <= step)
            for step in range(max_iter + 1)
        ]
        # Percentage over ALL blocks, including the ones parsed as None.
        rates = [count / len(replacements) * 100 for count in hits]

        plt.plot(range(max_iter + 1), rates, curve_style, label=curve_label)

    x_ticks = [k * 5 for k in range(max_iter // 5 + 1)]
    y_ticks = [k * yticks_each for k in range(yticks_max // yticks_each + 1)]
    plt.xticks(x_ticks, x_ticks, fontsize=tick_size)
    plt.yticks(y_ticks, y_ticks, fontsize=tick_size)
    plt.xlabel("Iteration", fontsize=label_size)
    plt.ylabel("Succ (%)", fontsize=label_size)
    plt.legend(loc="lower right", fontsize=legend_size)
    plt.savefig(fig_path, bbox_inches = 'tight')
    plt.show()