{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"import altair as alt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Root directory of the cached cost files read by gather_results.\n",
"cache_dir = '../cache/ImageNet/resnet50_0-10'\n",
"\n",
"# Corruption names; each one selects a `pretrained_{corr}.json` file.\n",
"corruptions = [\n",
"    'brightness',\n",
"    'defocus_blur',\n",
"    'elastic_transform',\n",
"    'fog',\n",
"    'frost',\n",
"    'gaussian_blur',\n",
"    'gaussian_noise',\n",
"    'glass_blur',\n",
"    'impulse_noise',\n",
"    'jpeg_compression',\n",
"    'motion_blur',\n",
"    'pixelate',\n",
"    'saturate',\n",
"    'shot_noise',\n",
"    'snow',\n",
"    'spatter',\n",
"    'speckle_noise',\n",
"    'zoom_blur',\n",
"    'contrast',\n",
"    'collection',\n",
"]\n",
"\n",
"# Metric names; each one selects a `{metric}_costs` sub-directory.\n",
"metrics = ['ATC-MC', 'COTT-MC']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gather_results(cache_dir, corruptions, metrics, n_samples=1000, seed=None):\n",
"    \"\"\"Load cached per-sample costs into a long-form DataFrame for plotting.\n",
"\n",
"    For every (corruption, metric) pair this reads\n",
"    ``{cache_dir}/{metric}_costs/pretrained_{corr}.json``. The file is expected\n",
"    to hold a list of records, each with the keys ``'costs'``, ``'t'`` and\n",
"    ``'ood error'``. The position of a record in that list becomes its\n",
"    0-based severity index ``sev``.\n",
"\n",
"    Args:\n",
"        cache_dir: Directory containing the ``{metric}_costs`` sub-folders.\n",
"        corruptions: Iterable of corruption names.\n",
"        metrics: Iterable of metric names; each must contain 'ATC' or 'COTT'.\n",
"        n_samples: Number of cost values drawn with replacement per record.\n",
"        seed: Optional seed for a private RNG. When ``None`` the state of the\n",
"            global ``random`` module is used.\n",
"\n",
"    Returns:\n",
"        pandas.DataFrame with columns ``corr``, ``sev``, ``cost``, ``mae`` and\n",
"        ``metric``, holding ``n_samples`` rows per non-empty record.\n",
"\n",
"    Raises:\n",
"        ValueError: If a metric name contains neither 'ATC' nor 'COTT'.\n",
"    \"\"\"\n",
"    rng = random if seed is None else random.Random(seed)\n",
"\n",
"    data = []\n",
"    for corr in corruptions:\n",
"        for metric in metrics:\n",
"            # Decide the cost transform up front so an unknown metric fails loudly\n",
"            # instead of reusing `costs` left over from a previous iteration.\n",
"            if 'ATC' in metric:\n",
"                rescale = False\n",
"            elif 'COTT' in metric:\n",
"                rescale = True\n",
"            else:\n",
"                raise ValueError(f'Unsupported metric {metric!r}: name must contain ATC or COTT')\n",
"\n",
"            file_dir = f'{cache_dir}/{metric}_costs/pretrained_{corr}.json'\n",
"            with open(file_dir, 'r') as f:\n",
"                records = json.load(f)\n",
"\n",
"            for sev, record in enumerate(records):\n",
"                raw_costs = np.asarray(record['costs'])\n",
"                if raw_costs.size == 0:\n",
"                    # Nothing to sample, and the fraction below `t` is undefined.\n",
"                    continue\n",
"\n",
"                # COTT costs are mapped through x / 2 + 1; ATC costs are used as stored.\n",
"                costs = raw_costs / 2 + 1 if rescale else raw_costs\n",
"\n",
"                # The threshold `t` is compared against the raw, untransformed costs.\n",
"                frac_below_t = (raw_costs < record['t']).sum() / len(raw_costs)\n",
"                mae = abs(frac_below_t - record['ood error'])\n",
"\n",
"                data.extend(\n",
"                    {'corr': corr, 'sev': sev, 'cost': cost, 'mae': mae, 'metric': metric}\n",
"                    for cost in rng.choices(costs, k=n_samples)\n",
"                )\n",
"    return pd.DataFrame(data)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Seed the global RNG first: gather_results subsamples the costs with the\n",
"# `random` module, so without a seed every run would plot a different sample.\n",
"random.seed(0)\n",
"source = gather_results(cache_dir, corruptions, metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Remove Altair's row limit: `source` is embedded in every chart below and\n",
"# holds many sampled rows per (corruption, severity, metric) group.\n",
"alt.data_transformers.enable(\n",
"    'default',\n",
"    max_rows=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Density estimate of the sampled costs per metric, faceted with one row per\n",
"# corruption and one column per severity index.\n",
"(\n",
"    alt.Chart(data=source)\n",
"    .transform_density(\n",
"        'cost',\n",
"        groupby=['sev', 'corr', 'metric'],\n",
"        as_=['cost', 'density'],\n",
"    )\n",
"    .mark_area(opacity=0.8)\n",
"    .encode(x='cost:Q', y='density:Q', color='metric:N')\n",
"    .properties(width=100, height=100)\n",
"    .facet(column='sev:N', row='corr:N')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `source` repeats the same `mae` on every sampled row of a (corr, sev, metric)\n",
"# group. Keep one row per group so each bar is drawn from a single value;\n",
"# plotting the duplicates would overplot (or, depending on the Vega-Lite\n",
"# version, stack) them and embed the whole table in the chart spec.\n",
"mae_source = source[['corr', 'sev', 'metric', 'mae']].drop_duplicates()\n",
"\n",
"alt.Chart(data=mae_source).mark_bar(opacity=0.8).encode(\n",
"    x='metric:N',\n",
"    y='mae:Q',\n",
"    color='metric:N'\n",
").properties(\n",
"    width=20,\n",
"    height=100\n",
").facet(\n",
"    column='sev:N',\n",
"    row='corr:N'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One row of panels per corruption; each severity contributes a cost-density\n",
"# panel followed by a narrow bar panel with the per-metric error.\n",
"charts = alt.vconcat()\n",
"for corr in corruptions:\n",
"    corr_source = source[source['corr'] == corr]\n",
"    # Shared y-limit for all severities of this corruption. Hoisted out of the\n",
"    # severity loop because it does not depend on `i`.\n",
"    y_max = corr_source['mae'].max()\n",
"\n",
"    hcharts = alt.hconcat(title=corr)\n",
"    for i in range(5):\n",
"        sub_source = corr_source[corr_source['sev'] == i]\n",
"        if len(sub_source) == 0:\n",
"            # Severity indices are assumed contiguous from 0; stop at the first gap.\n",
"            break\n",
"\n",
"        density_chart = alt.Chart(data=sub_source).transform_density(\n",
"            'cost',\n",
"            groupby=['sev', 'corr', 'metric'],\n",
"            as_=['cost', 'density'],\n",
"        ).mark_area(opacity=0.8).encode(\n",
"            x=alt.X('cost:Q'),\n",
"            y=alt.Y('density:Q'),\n",
"            color='metric:N'\n",
"        ).properties(\n",
"            width=80,\n",
"            height=80\n",
"        )\n",
"\n",
"        # `mae` is constant within a (sev, metric) group, so plot one row per\n",
"        # metric instead of one bar per sampled cost.\n",
"        mae_rows = sub_source[['metric', 'mae']].drop_duplicates()\n",
"        mae_chart = alt.Chart(data=mae_rows).mark_bar(opacity=0.8).encode(\n",
"            x=alt.X('metric:N', axis=alt.Axis(labels=False)),\n",
"            y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format='.3f')),\n",
"            color='metric:N'\n",
"        ).properties(\n",
"            width=20,\n",
"            height=80\n",
"        )\n",
"\n",
"        hcharts |= density_chart\n",
"        hcharts |= mae_chart\n",
"\n",
"    charts &= hcharts\n",
"\n",
"charts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ood",
"language": "python",
"name": "ood"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}