In [None]:
import json
import random
import numpy as np
import pandas as pd
import altair as alt

In [None]:
# Directory holding the cached cost files; its name suggests ImageNet /
# ResNet-50 results (TODO confirm).
cache_dir = '../cache/ImageNet/resnet50_0-10'

# Each name selects a `pretrained_{corr}.json` file inside every metric's cache folder.
corruptions = [
    'brightness', 'defocus_blur', 'elastic_transform', 'fog',
    'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur',
    'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',
    'saturate', 'shot_noise', 'snow', 'spatter',
    'speckle_noise', 'zoom_blur', 'contrast', 'collection',
]

# Each name selects the `{metric}_costs` sub-directory of `cache_dir`.
metrics = ['ATC-MC', 'COTT-MC']

In [None]:
def gather_results(cache_dir, corruptions, metrics, n_samples=1000, seed=None):
    """Load cached per-severity cost records into one long-form DataFrame.

    For every ``(corruption, metric)`` pair this reads
    ``{cache_dir}/{metric}_costs/pretrained_{corr}.json``. The file must hold a
    list of records, and each record's position in the list is used as its
    severity index ``sev``. Each record must provide ``'costs'`` (a list of
    numbers), ``'t'`` (a threshold) and ``'ood error'``.

    Args:
        cache_dir: Root directory of the cached JSON files.
        corruptions: Corruption names, each used in a file name.
        metrics: Metric names. Each must contain ``'ATC'`` or ``'COTT'``. ATC
            costs are kept as-is; COTT costs are rescaled to ``costs / 2 + 1``
            for the ``cost`` column.
        n_samples: Number of costs drawn (with replacement) per record.
        seed: Optional seed for that sampling. ``None`` draws from the global
            ``random`` module state, exactly as before this parameter existed.

    Returns:
        DataFrame with columns ``corr``, ``sev``, ``cost``, ``mae`` and
        ``metric``, holding ``n_samples`` rows per record. ``mae`` is
        identical for all rows of a record.

    Raises:
        ValueError: If a metric name contains neither ``'ATC'`` nor ``'COTT'``.
            Previously such a name either crashed with a NameError or silently
            reused another metric's costs.
    """
    # Validate up front, before any file is read.
    for metric in metrics:
        if 'ATC' not in metric and 'COTT' not in metric:
            raise ValueError(
                f"Unsupported metric {metric!r}: expected a name containing 'ATC' or 'COTT'"
            )

    # The `random` module and a `random.Random` instance both expose `.choices`.
    rng = random if seed is None else random.Random(seed)

    data = []
    for corr in corruptions:
        for metric in metrics:
            path = f'{cache_dir}/{metric}_costs/pretrained_{corr}.json'
            with open(path, 'r') as f:
                records = json.load(f)

            for sev, record in enumerate(records):
                raw_costs = np.array(record['costs'])
                # 'ATC' is checked first, so a name containing both keeps the ATC scaling.
                costs = raw_costs if 'ATC' in metric else raw_costs / 2 + 1

                # The threshold is compared against the raw, unscaled costs.
                # The fraction below it is compared with the cached 'ood error'.
                frac_below_t = (raw_costs < record['t']).sum() / len(raw_costs)
                mae = abs(frac_below_t - record['ood error'])

                data.extend(
                    {'corr': corr, 'sev': sev, 'cost': cost, 'mae': mae, 'metric': metric}
                    for cost in rng.choices(costs, k=n_samples)
                )
    return pd.DataFrame(data)


In [None]:
# Flatten every cached (corruption, metric) record into one long-form frame.
source = gather_results(
    cache_dir=cache_dir,
    corruptions=corruptions,
    metrics=metrics,
)

In [None]:
# `source` holds many rows per record, so disable Altair's row limit for
# inline data (max_rows=None) on the 'default' transformer.
alt.data_transformers.enable(name='default', max_rows=None)

In [None]:
# Density of the sampled costs, one panel per (corruption, severity), with the
# metrics overlaid by colour.
alt.Chart(source).transform_density(
    density='cost',
    groupby=['sev', 'corr', 'metric'],
    as_=['cost', 'density'],
).mark_area(
    opacity=0.8,
).encode(
    x=alt.X('cost:Q'),
    y=alt.Y('density:Q'),
    color='metric:N',
).properties(
    width=100,
    height=100,
).facet(
    column='sev:N',
    row='corr:N',
)

In [None]:
# MAE per metric for every (corruption, severity).
# `mae` is identical for all sampled rows of a (corr, sev, metric) group, so
# keep one row per group. Charting every sampled row stacks one overlapping
# rectangle per row on each bar, which hides the requested 0.8 opacity and
# bloats the embedded data.
alt.Chart(
    data=source.drop_duplicates(subset=['corr', 'sev', 'metric'])
).mark_bar(opacity=0.8).encode(
    x="metric:N",
    y='mae:Q',
    color='metric:N'
).properties(
    width=20,
    height=100
).facet(
    column='sev:N',
    row='corr:N'
)

In [None]:
# One row per corruption. For each severity present in the data, show a
# density plot of the sampled costs followed by a bar chart of per-metric MAE.
charts = alt.vconcat()
for corr in corruptions:
    corr_source = source[source['corr'] == corr]
    # Loop-invariant per corruption. A shared y-range keeps the MAE bars
    # comparable across severities.
    y_max = corr_source['mae'].max()

    hcharts = alt.hconcat(title=corr)
    # Iterate the severities that were actually loaded instead of a
    # hard-coded range(5), which silently dropped any severity >= 5.
    for sev in sorted(corr_source['sev'].unique()):
        sub_source = corr_source[corr_source['sev'] == sev]

        density = alt.Chart(data=sub_source).transform_density(
            'cost',
            groupby=['sev', 'corr', 'metric'],
            as_=['cost', 'density'],
        ).mark_area(opacity=0.8).encode(
            x=alt.X("cost:Q"),
            y=alt.Y('density:Q'),
            color='metric:N'
        ).properties(
            width=80,
            height=80
        )

        # `mae` is identical for every sampled row of a metric at this
        # severity. Use one row per metric so each bar is a single mark
        # rather than many overlapping ones.
        mae = alt.Chart(data=sub_source.drop_duplicates(subset=['metric'])).mark_bar(opacity=0.8).encode(
            x=alt.X("metric:N", axis=alt.Axis(labels=False)),
            y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format='.3f')),
            color='metric:N'
        ).properties(
            width=20,
            height=80
        )

        hcharts |= density
        hcharts |= mae

    charts &= hcharts

charts