FOT-OOD / notebooks / visualize_score_dist.ipynb
visualize_score_dist.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import altair as alt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory that holds one `<metric>_costs/` folder per metric.\n",
    "cache_dir = '../cache/ImageNet/resnet50_0-10'\n",
    "\n",
    "# Names substituted into `pretrained_<name>.json` when results are loaded.\n",
    "corruptions = [\n",
    "    'brightness', 'defocus_blur', 'elastic_transform', 'fog',\n",
    "    'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur',\n",
    "    'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',\n",
    "    'saturate', 'shot_noise', 'snow', 'spatter',\n",
    "    'speckle_noise', 'zoom_blur', 'contrast', 'collection',\n",
    "]\n",
    "\n",
    "# Metrics to compare; each one maps to a `<metric>_costs` directory.\n",
    "metrics = ['ATC-MC', 'COTT-MC']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gather_results(cache_dir, corruptions, metrics, n_samples=1000, seed=None):\n",
    "    \"\"\"Load cached per-sample costs and flatten them into one long DataFrame.\n",
    "\n",
    "    For every (corruption, metric) pair this reads\n",
    "    ``{cache_dir}/{metric}_costs/pretrained_{corr}.json``, which must hold a\n",
    "    list of records with the keys ``'costs'``, ``'t'`` and ``'ood error'``.\n",
    "    The position of a record inside that list is stored as its ``sev`` value.\n",
    "\n",
    "    Args:\n",
    "        cache_dir: Directory containing the ``{metric}_costs`` folders.\n",
    "        corruptions: Re-iterable collection of corruption names.\n",
    "        metrics: Re-iterable collection of metric names; every name must\n",
    "            contain 'ATC' or 'COTT' ('ATC' wins if both are present).\n",
    "        n_samples: Number of costs drawn with replacement per record.\n",
    "        seed: Optional seed for the sampling. ``None`` keeps drawing from the\n",
    "            global ``random`` state, as before.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame with the columns ``corr``, ``sev``, ``cost``, ``mae``\n",
    "        and ``metric``; ``n_samples`` rows per record, all rows of one record\n",
    "        sharing the same ``mae``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a metric name contains neither 'ATC' nor 'COTT'.\n",
    "    \"\"\"\n",
    "    # A private generator makes the sampling reproducible without touching global state.\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "\n",
    "    data = []\n",
    "    for corr in corruptions:\n",
    "        for metric in metrics:\n",
    "            is_atc = 'ATC' in metric\n",
    "            if not is_atc and 'COTT' not in metric:\n",
    "                # Previously this fell through and reused (or never bound) `costs`.\n",
    "                raise ValueError(f'unsupported metric {metric!r}: name must contain ATC or COTT')\n",
    "\n",
    "            path = f'{cache_dir}/{metric}_costs/pretrained_{corr}.json'\n",
    "            with open(path, 'r') as f:\n",
    "                records = json.load(f)\n",
    "\n",
    "            for sev, record in enumerate(records):\n",
    "                raw_costs = np.array(record['costs'])\n",
    "                # ATC costs are plotted as stored; COTT costs are mapped through x / 2 + 1\n",
    "                # (presumably to put both metrics on a comparable axis - TODO confirm).\n",
    "                costs = raw_costs if is_atc else raw_costs / 2 + 1\n",
    "\n",
    "                # |fraction of raw costs below the record's threshold - recorded OOD error|\n",
    "                mae = abs((raw_costs < record['t']).sum() / len(costs) - record['ood error'])\n",
    "\n",
    "                data.extend(\n",
    "                    {'corr': corr, 'sev': sev, 'cost': cost, 'mae': mae, 'metric': metric}\n",
    "                    for cost in rng.choices(costs, k=n_samples)\n",
    "                )\n",
    "    return pd.DataFrame(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every cached (corruption, metric) result into one long-form frame for plotting.\n",
    "source = gather_results(\n",
    "    cache_dir=cache_dir,\n",
    "    corruptions=corruptions,\n",
    "    metrics=metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `source` holds many sampled rows per record; max_rows=None lifts the row cap\n",
    "# of Altair's default data transformer so the charts below accept the full frame.\n",
    "alt.data_transformers.enable(\n",
    "    'default',\n",
    "    max_rows=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cost density per metric, one panel per (corruption row, severity column).\n",
    "(\n",
    "    alt.Chart(data=source)\n",
    "    .transform_density('cost', groupby=['sev', 'corr', 'metric'], as_=['cost', 'density'])\n",
    "    .mark_area(opacity=0.8)\n",
    "    .encode(x='cost:Q', y='density:Q', color='metric:N')\n",
    "    .properties(width=100, height=100)\n",
    "    .facet(column='sev:N', row='corr:N')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimation error per metric, one panel per (corruption row, severity column).\n",
    "# `source` repeats the same `mae` on every sampled row of a (corr, sev, metric)\n",
    "# group, so plot one row per group: otherwise each bar is drawn once per sample,\n",
    "# which defeats the opacity setting and embeds the whole frame in the chart spec.\n",
    "alt.Chart(\n",
    "    data=source[['corr', 'sev', 'metric', 'mae']].drop_duplicates()\n",
    ").mark_bar(opacity=0.8).encode(\n",
    "    x='metric:N',\n",
    "    y='mae:Q',\n",
    "    color='metric:N'\n",
    ").properties(\n",
    "    width=20,\n",
    "    height=100\n",
    ").facet(\n",
    "    column='sev:N',\n",
    "    row='corr:N'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of panels per corruption: for each severity a cost-density plot\n",
    "# followed by a narrow bar chart of the estimation error of each metric.\n",
    "charts = alt.vconcat()\n",
    "for corr in corruptions:\n",
    "    # Filter once per corruption instead of re-scanning `source` for every severity.\n",
    "    corr_source = source[source['corr'] == corr]\n",
    "    # Shared upper bound so the error bars of one corruption use the same y scale.\n",
    "    y_max = corr_source['mae'].max()\n",
    "\n",
    "    hcharts = alt.hconcat(title=corr)\n",
    "    for i in range(5):\n",
    "        sub_source = corr_source[corr_source['sev'] == i]\n",
    "        if len(sub_source) == 0:\n",
    "            # No data at this severity: stop adding panels for this corruption.\n",
    "            break\n",
    "\n",
    "        density = alt.Chart(data=sub_source).transform_density(\n",
    "            'cost',\n",
    "            groupby=['sev', 'corr', 'metric'],\n",
    "            as_=['cost', 'density'],\n",
    "        ).mark_area(opacity=0.8).encode(\n",
    "            x=alt.X('cost:Q'),\n",
    "            y=alt.Y('density:Q'),\n",
    "            color='metric:N'\n",
    "        ).properties(\n",
    "            width=80,\n",
    "            height=80\n",
    "        )\n",
    "\n",
    "        # `mae` is constant within a (corr, sev, metric) group; keep one row per\n",
    "        # group so each bar is drawn once instead of once per sampled cost.\n",
    "        mae_rows = sub_source[['corr', 'sev', 'metric', 'mae']].drop_duplicates()\n",
    "        mae = alt.Chart(data=mae_rows).mark_bar(opacity=0.8).encode(\n",
    "            x=alt.X('metric:N', axis=alt.Axis(labels=False)),\n",
    "            y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format='.3f')),\n",
    "            color='metric:N'\n",
    "        ).properties(\n",
    "            width=20,\n",
    "            height=80\n",
    "        )\n",
    "\n",
    "        hcharts |= density\n",
    "        hcharts |= mae\n",
    "\n",
    "    charts &= hcharts\n",
    "\n",
    "charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ood",
   "language": "python",
   "name": "ood"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}