{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import random\n", "import numpy as np\n", "import pandas as pd\n", "import altair as alt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cache_dir = '../cache/ImageNet/resnet50_0-10'\n", "corruptions = ['brightness', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', \n", " 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',\n", " 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur', 'contrast', 'collection']\n", "metrics = ['ATC-MC', 'COTT-MC']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gather_results(cache_dir, corruptions, metrics):\n", " data = []\n", " for corr in corruptions:\n", " for metric in metrics:\n", " file_dir = f'{cache_dir}/{metric}_costs/pretrained_{corr}.json'\n", " with open(file_dir, 'r') as f:\n", " records = json.load(f)\n", " \n", " for sev, record in enumerate(records):\n", " if 'ATC' in metric:\n", " costs = np.array(record['costs'])\n", " elif 'COTT' in metric:\n", " costs = np.array(record['costs']) / 2 + 1\n", " \n", " mae = abs( (np.array(record['costs']) < record['t']).sum() / len(costs) - record['ood error'])\n", " \n", " data.extend(\n", " [{'corr': corr, 'sev': sev, 'cost': cost, 'mae': mae, 'metric': metric} for cost in random.choices(costs, k=1000)]\n", " )\n", " return pd.DataFrame(data)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = gather_results(cache_dir, corruptions, metrics)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "alt.data_transformers.enable('default', max_rows=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "alt.Chart(data=source).transform_density(\n", " 'cost',\n", " groupby=['sev', 'corr', 'metric'],\n", " as_=['cost', 'density'],\n", ").mark_area(opacity=0.8).encode(\n", " x=\"cost:Q\",\n", " y='density:Q',\n", " color='metric:N'\n", ").properties(\n", " width=100,\n", " height=100\n", ").facet(\n", " column='sev:N',\n", " row='corr:N'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "alt.Chart(data=source).mark_bar(opacity=0.8).encode(\n", " x=\"metric:N\",\n", " y='mae:Q',\n", " color='metric:N'\n", ").properties(\n", " width=20,\n", " height=100\n", ").facet(\n", " column='sev:N',\n", " row='corr:N'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "charts = alt.vconcat()\n", "for corr in corruptions:\n", " hcharts = alt.hconcat(title=corr)\n", " for i in range(5):\n", " sub_source = source[(source['sev'] == i) & (source['corr'] == corr)]\n", " if len(sub_source) == 0:\n", " break\n", " \n", " density = alt.Chart(data=sub_source).transform_density(\n", " 'cost',\n", " groupby=['sev', 'corr', 'metric'],\n", " as_=['cost', 'density'],\n", " ).mark_area(opacity=0.8).encode(\n", " x=alt.X(\"cost:Q\"),\n", " y=alt.Y('density:Q'),\n", " color='metric:N'\n", " ).properties(\n", " width=80,\n", " height=80\n", " )\n", " \n", " y_max = source[(source['corr'] == corr)]['mae'].max()\n", " mae = alt.Chart(data=sub_source).mark_bar(opacity=0.8).encode(\n", " x=alt.X(\"metric:N\", axis=alt.Axis(labels=False)),\n", " y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format='.3f')),\n", " color='metric:N'\n", " ).properties(\n", " width=20,\n", " height=80\n", " )\n", " \n", " hcharts |= density\n", " hcharts |= mae\n", "\n", " charts &= hcharts\n", "\n", "charts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "ood", "language": "python", "name": "ood" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }