In [16]:
import glob
import json
import math
import warnings
from functools import reduce

import altair as alt
import numpy as np

In [17]:
def get_mae(dataset, model, metric, n_sample, seed, results_root='../results'):
    """Mean absolute difference between the 'metric' and 'error' fields of one run's records.

    Reads every ``*.json`` file in
    ``{results_root}/{dataset}/{model}_{seed}/{metric}_{n_sample}``, pools the
    records from all files, and returns the mean of ``|record['metric'] - record['error']|``.

    Assumes each JSON file holds a list of dicts with numeric 'metric' and
    'error' keys (the contents are passed straight to ``list.extend``) -- not
    validated here.

    Args:
        dataset: Dataset directory name (e.g. 'CIFAR-10').
        model: Model/architecture name; combined with ``seed`` as ``{model}_{seed}``.
        metric: Metric name; combined with ``n_sample`` as ``{metric}_{n_sample}``.
        n_sample: Sample-count component of the result directory name.
        seed: Seed component of the result directory name.
        results_root: Root results directory (defaults to the original '../results').

    Returns:
        The pooled mean absolute difference. Returns NaN, with a ``UserWarning``
        naming the directory, when no records are found. Callers still get NaN
        as before, but the missing run is now reported explicitly.
    """
    result_dir = f'{results_root}/{dataset}/{model}_{seed}/{metric}_{n_sample}'

    results = []
    # glob order is filesystem-dependent; sort so pooling (and float summation) is reproducible.
    for path in sorted(glob.glob(f'{result_dir}/*.json')):
        with open(path, 'r') as f:
            results.extend(json.load(f))

    if not results:
        warnings.warn(f'no result records found in {result_dir}', UserWarning, stacklevel=2)
        return float('nan')

    maes = [abs(record['metric'] - record['error']) for record in results]
    return np.array(maes).mean()

In [33]:
def get_sens_results(metric, dataset, arch,
                     n_samples=(10000, 5000, 2000, 1000, 500, 200, 100),
                     seed='1'):
    """Sweep sample sizes and collect one MAE record per size for a metric/dataset/arch.

    Args:
        metric: Metric name, forwarded to ``get_mae``.
        dataset: Dataset name, forwarded to ``get_mae``.
        arch: Architecture/model name, forwarded to ``get_mae`` as ``model``.
        n_samples: Sample counts to evaluate, in output order. The default is the
            sweep that was previously hard-coded.
        seed: Run seed forwarded to ``get_mae``. The default is the previously
            hard-coded '1'.

    Returns:
        A list of dicts with keys 'metric', 'dataset', 'n_sample', 'arch' and
        'mae', one per entry of ``n_samples``, in the same order.
    """
    return [
        {
            'metric': metric,
            'dataset': dataset,
            'n_sample': n,
            'arch': arch,
            'mae': get_mae(dataset, arch, metric, n, seed),
        }
        for n in n_samples
    ]

In [34]:
metrics = ['EMD', 'ATC', 'REMD_1.25']
datasets = ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']

# Pool ResNet-50 sensitivity records for every (metric, dataset) pair, metric-major order.
sens = []
for metric, dataset in ((m, d) for m in metrics for d in datasets):
    sens.extend(get_sens_results(metric, dataset, 'resnet50'))

In [35]:
# MAE vs. number of evaluation samples (log-scaled x axis): one line per metric,
# one column panel per dataset.
# NOTE(review): the y domain is pinned to [0, 0.22]. Vega-Lite does not clip marks
# by default, so any MAE above 0.22 would be drawn outside the panel -- confirm the range.
(
    alt.Chart(alt.Data(values=sens))
    .mark_line(point=alt.OverlayMarkDef(size=50, filled=False, fill='white'))
    .encode(
        x=alt.X(
            'n_sample:Q',
            title='Number of Samples',
            scale=alt.Scale(type='log', reverse=False),
            axis=alt.Axis(values=[100, 1000, 10000]),
        ),
        y=alt.Y(
            'mae:Q',
            title='Mean Absolute Error',
            scale=alt.Scale(domain=[0, 0.22]),
        ),
        color=alt.Color('metric:N'),
    )
    .properties(height=150, width=150)
    .facet(column=alt.Column('dataset:N', header=alt.Header(labelFontSize=18)))
    .configure_axis(titleFontSize=18, labelFontSize=16)
    .configure_legend(labelFontSize=18, titleFontSize=20)
)