FOT-OOD / notebooks / visualize_sensitivity_results.ipynb
visualize_sensitivity_results.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "import json\n",
    "import glob\n",
    "from functools import reduce\n",
    "import altair as alt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mae(dataset, model, metric, n_sample, seed):\n",
    "    \"\"\"Return the mean absolute gap between 'metric' and 'error' over stored results.\n",
    "\n",
    "    Every ``*.json`` file directly under\n",
    "    ``../results/{dataset}/{model}_{seed}/{metric}_{n_sample}`` is loaded and\n",
    "    its contents are pooled into a single list of records. Each record is\n",
    "    indexed with the keys ``'metric'`` and ``'error'``.\n",
    "    NOTE(review): assumes each file holds a JSON list of such records - confirm\n",
    "    against the script that writes the results.\n",
    "\n",
    "    Args:\n",
    "        dataset, model, metric, n_sample, seed: path components, interpolated\n",
    "            verbatim into the result directory name.\n",
    "\n",
    "    Returns:\n",
    "        Mean of ``abs(record['metric'] - record['error'])`` over all pooled\n",
    "        records, or NaN (after a warning naming the directory) when no\n",
    "        records were found.\n",
    "    \"\"\"\n",
    "    import warnings  # function-scope import keeps this cell self-contained\n",
    "\n",
    "    result_dir = f'../results/{dataset}/{model}_{seed}/{metric}_{n_sample}'\n",
    "\n",
    "    # glob returns files in filesystem order; sorting makes the pooled order,\n",
    "    # and therefore the floating-point mean, reproducible across machines.\n",
    "    results = []\n",
    "    for path in sorted(glob.glob(f'{result_dir}/*.json')):\n",
    "        with open(path, 'r') as f:\n",
    "            results.extend(json.load(f))\n",
    "\n",
    "    if not results:\n",
    "        # The mean of an empty array is NaN as well, but numpy's own warning\n",
    "        # does not say which configuration is missing.\n",
    "        warnings.warn(f'no results found in {result_dir}; MAE is NaN')\n",
    "        return np.float64('nan')\n",
    "\n",
    "    abs_errors = [abs(result['metric'] - result['error']) for result in results]\n",
    "    return np.array(abs_errors).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sens_results(metric, dataset, arch,\n",
    "                     n_samples=(10000, 5000, 2000, 1000, 500, 200, 100), seed='1'):\n",
    "    \"\"\"Collect one MAE record per sample size for a metric/dataset/architecture.\n",
    "\n",
    "    Args:\n",
    "        metric, dataset, arch: forwarded to ``get_mae`` and copied into every\n",
    "            returned record.\n",
    "        n_samples: sample sizes to sweep, in output order. The default is the\n",
    "            sweep that used to be hard-coded in this function.\n",
    "        seed: run identifier forwarded to ``get_mae``. The default ``'1'`` is\n",
    "            the value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with keys ``'metric'``, ``'dataset'``, ``'n_sample'``,\n",
    "        ``'arch'`` and ``'mae'``, one per entry of ``n_samples``.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        {\n",
    "            'metric': metric,\n",
    "            'dataset': dataset,\n",
    "            'n_sample': n,\n",
    "            'arch': arch,\n",
    "            'mae': get_mae(dataset, arch, metric, n, seed)\n",
    "        }\n",
    "        for n in n_samples\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimators and benchmarks to sweep; every (metric, dataset) pair is evaluated.\n",
    "metrics = ['EMD', 'ATC', 'REMD_1.25']\n",
    "datasets = ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']\n",
    "\n",
    "# Flat list of per-sample-size records for all pairs, using the resnet50 results.\n",
    "sens = []\n",
    "for metric in metrics:\n",
    "    for dataset in datasets:\n",
    "        sens.extend(get_sens_results(metric, dataset, 'resnet50'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-4b59fbffcc37474daa32cb84404743a9\"></div>\n",
       "<script type=\"text/javascript\">\n",
       "  var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
       "  (function(spec, embedOpt){\n",
       "    let outputDiv = document.currentScript.previousElementSibling;\n",
       "    if (outputDiv.id !== \"altair-viz-4b59fbffcc37474daa32cb84404743a9\") {\n",
       "      outputDiv = document.getElementById(\"altair-viz-4b59fbffcc37474daa32cb84404743a9\");\n",
       "    }\n",
       "    const paths = {\n",
       "      \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
       "      \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
       "      \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
       "      \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
       "    };\n",
       "\n",
       "    function maybeLoadScript(lib, version) {\n",
       "      var key = `${lib.replace(\"-\", \"\")}_version`;\n",
       "      return (VEGA_DEBUG[key] == version) ?\n",
       "        Promise.resolve(paths[lib]) :\n",
       "        new Promise(function(resolve, reject) {\n",
       "          var s = document.createElement('script');\n",
       "          document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "          s.async = true;\n",
       "          s.onload = () => {\n",
       "            VEGA_DEBUG[key] = version;\n",
       "            return resolve(paths[lib]);\n",
       "          };\n",
       "          s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
       "          s.src = paths[lib];\n",
       "        });\n",
       "    }\n",
       "\n",
       "    function showError(err) {\n",
       "      outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
       "      throw err;\n",
       "    }\n",
       "\n",
       "    function displayChart(vegaEmbed) {\n",
       "      vegaEmbed(outputDiv, spec, embedOpt)\n",
       "        .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
       "    }\n",
       "\n",
       "    if(typeof define === \"function\" && define.amd) {\n",
       "      requirejs.config({paths});\n",
       "      require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
       "    } else {\n",
       "      maybeLoadScript(\"vega\", \"5\")\n",
       "        .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
       "        .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
       "        .catch(showError)\n",
       "        .then(() => displayChart(vegaEmbed));\n",
       "    }\n",
       "  })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}, \"axis\": {\"labelFontSize\": 16, \"titleFontSize\": 18}, \"legend\": {\"labelFontSize\": 18, \"titleFontSize\": 20}}, \"data\": {\"values\": [{\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.014469379930368452}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.015004043102478862}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.016281259801561865}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.020989433447233103}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04018579669484187}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04615530701310061}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045474082633063746}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.01992340799863755}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.020585536769618707}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.024342450141671763}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03338057958699007}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.0534508796307282}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.10063707337972706}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.19019687173354596}, {\"metric\": 
\"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.032434747111411456}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.029363516945500913}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.02454799185399754}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03127610869061543}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04364479950047747}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.09441336930069869}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.20636375601731388}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0316802075425163}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.03201457563489675}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.031520850390195844}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.028937513527770835}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.03106251671910286}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.03927082135031621}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045416658793886504}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03524166695301732}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 
0.03512082161357005}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.034406267752870916}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03490626962669194}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.03997918866078059}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.043802075817560154}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.047083325767889615}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03485526348164207}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.034565782804316596}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.03525658435531352}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03613158635089272}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.042947374909331926}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04743420748922386}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.05434210061634842}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.04440492593811391}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.04269366873211624}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.042810486328272036}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", 
\"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.04571517167147061}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.051640504498346416}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.05429068888781783}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.05624488137029399}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0934960299973343}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.09297857507000007}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.08642741723291288}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.0810600697000193}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.07525555161545947}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08454419112086059}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12893080543225685}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.14702108353799165}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.14431254023133747}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.1329430498002896}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.12127065566562693}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, 
\"arch\": \"resnet50\", \"mae\": 0.106019251174766}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08902080514337413}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12326658876760059}]}, \"facet\": {\"column\": {\"field\": \"dataset\", \"header\": {\"labelFontSize\": 18}, \"type\": \"nominal\"}}, \"spec\": {\"mark\": {\"type\": \"line\", \"point\": {\"fill\": \"white\", \"filled\": false, \"size\": 50}}, \"encoding\": {\"color\": {\"field\": \"metric\", \"type\": \"nominal\"}, \"x\": {\"axis\": {\"values\": [100, 1000, 10000]}, \"field\": \"n_sample\", \"scale\": {\"reverse\": false, \"type\": \"log\"}, \"title\": \"Number of Samples\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"mae\", \"scale\": {\"domain\": [0, 0.22]}, \"title\": \"Mean Absolute Error\", \"type\": \"quantitative\"}}, \"height\": 150, \"width\": 150}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\"}, {\"mode\": \"vega-lite\"});\n",
       "</script>"
      ],
      "text/plain": [
       "alt.FacetChart(...)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encodings: log-scaled sample size on x (ticks at each decade), MAE on a\n",
    "# fixed [0, 0.22] range on y, one line colour per metric.\n",
    "x_enc = alt.X(\n",
    "    'n_sample:Q',\n",
    "    title='Number of Samples',\n",
    "    scale=alt.Scale(type='log', reverse=False),\n",
    "    axis=alt.Axis(values=[100, 1000, 10000])\n",
    ")\n",
    "y_enc = alt.Y(\n",
    "    'mae:Q',\n",
    "    scale=alt.Scale(domain=[0, 0.22]),\n",
    "    title='Mean Absolute Error'\n",
    ")\n",
    "color_enc = alt.Color('metric:N')\n",
    "\n",
    "# Line mark with a point overlay at every record.\n",
    "point_marks = alt.OverlayMarkDef(size=50, filled=False, fill='white')\n",
    "panel = alt.Chart(alt.Data(values=sens)).mark_line(point=point_marks)\n",
    "panel = panel.encode(x=x_enc, y=y_enc, color=color_enc)\n",
    "panel = panel.properties(width=150, height=150)\n",
    "\n",
    "# One 150x150 panel per dataset, laid out as columns.\n",
    "faceted = panel.facet(\n",
    "    column=alt.Column('dataset:N', header=alt.Header(labelFontSize=18))\n",
    ")\n",
    "\n",
    "# The last expression is the cell's displayed output.\n",
    "faceted.configure_axis(\n",
    "    labelFontSize=16,\n",
    "    titleFontSize=18,\n",
    ").configure_legend(\n",
    "    titleFontSize=20,\n",
    "    labelFontSize=18\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3f1da77dfecb8ea08abe23a06d7ff88ed6156499ac883f4692cab9e51a9d5766"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}