FOT-OOD / notebooks / visualize_confidence.ipynb
visualize_confidence.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_conf_scores(emb_dir):\n",
    "    \"\"\"Load cached model outputs from a pickle file and summarize softmax confidence.\n",
    "\n",
    "    Prints the mean top-1 softmax probability, the accuracy (fraction of\n",
    "    entries where 'pred' equals 'tar'), and the fraction of predictions\n",
    "    assigned to each class index, followed by a blank line.\n",
    "\n",
    "    Args:\n",
    "        emb_dir: Path to a pickle file (despite the name it is opened as a\n",
    "            file, not a directory) containing a dict with the keys 'act',\n",
    "            'pred' and 'tar'. 'act' is softmaxed over dim=1 and its second\n",
    "            dimension is used as the number of classes; presumably it holds\n",
    "            (n_samples, n_classes) logits -- TODO confirm with the cache writer.\n",
    "\n",
    "    Returns:\n",
    "        A Python list with the maximum softmax probability of each row of 'act'.\n",
    "    \"\"\"\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load caches\n",
    "    # that were produced by a trusted run.\n",
    "    # The context manager closes the file handle, which was previously leaked.\n",
    "    with open(emb_dir, 'rb') as cache_file:\n",
    "        res = pickle.load(cache_file)\n",
    "    act, pred, tar = res['act'], res['pred'], res['tar']\n",
    "    sm = torch.nn.functional.softmax(act, dim=1)\n",
    "    # Top-1 probability per row; computed once and reused for the mean and the return value.\n",
    "    conf = sm.max(1)[0]\n",
    "    acc = (pred == tar).sum() / len(pred)\n",
    "\n",
    "    print('averaged conf:', conf.mean().item())\n",
    "    print('acc:', acc.item())\n",
    "    pred_counts = Counter(pred.tolist())\n",
    "    print('label dist:', [pred_counts[i] / len(tar) for i in range(act.shape[1])])\n",
    "    print()\n",
    "\n",
    "    return conf.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment selection; these values are interpolated into the cache paths below.\n",
    "ds='Nonliving-26'\n",
    "epoch=450\n",
    "n_test_sample=2600\n",
    "# Pickle caches that are passed to get_conf_scores in the next cell.\n",
    "# NOTE(review): the tags presumably mean in-distribution ('id'), novel-subpopulation\n",
    "# ('od_pnovel', clean) and same-subpopulation frost-corrupted ('od_psame', cfrost-5)\n",
    "# test sets -- confirm against the code that writes ../cache.\n",
    "iid_emb = f'../cache/{ds}/resnet50_1-{epoch}/scratch/id_m1-{epoch}_d1.pkl'\n",
    "novel_emb = f'../cache/{ds}/resnet50_1-{epoch}/scratch/od_pnovel_m1-{epoch}_cclean-0_n{n_test_sample}.pkl'\n",
    "corr_emb = f'../cache/{ds}/resnet50_1-{epoch}/scratch/od_psame_m1-{epoch}_cfrost-5_n{n_test_sample}.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "averaged conf: 0.7038018107414246\n",
      "acc: 0.6948999762535095\n",
      "label dist: [0.0361, 0.041, 0.0383, 0.0329, 0.0369, 0.0335, 0.0382, 0.0317, 0.051, 0.0409, 0.0375, 0.0463, 0.0371, 0.0366, 0.0431, 0.043, 0.0284, 0.0373, 0.0354, 0.0398, 0.0422, 0.0392, 0.0292, 0.0365, 0.0371, 0.0508]\n",
      "\n",
      "averaged conf: 0.5417004227638245\n",
      "acc: 0.3407692313194275\n",
      "label dist: [0.03230769230769231, 0.038461538461538464, 0.05, 0.023846153846153847, 0.02576923076923077, 0.03769230769230769, 0.04730769230769231, 0.022692307692307692, 0.03576923076923077, 0.03923076923076923, 0.028076923076923076, 0.05423076923076923, 0.035, 0.036153846153846154, 0.03923076923076923, 0.04153846153846154, 0.03769230769230769, 0.04692307692307692, 0.03769230769230769, 0.041923076923076924, 0.04423076923076923, 0.038461538461538464, 0.028846153846153848, 0.03807692307692308, 0.03807692307692308, 0.06076923076923077]\n",
      "\n",
      "averaged conf: 0.4134865999221802\n",
      "acc: 0.21230769157409668\n",
      "label dist: [0.02346153846153846, 0.008846153846153846, 0.07923076923076923, 0.023846153846153847, 0.03576923076923077, 0.008461538461538461, 0.010384615384615384, 0.001153846153846154, 0.013076923076923076, 0.011538461538461539, 0.025, 0.18807692307692309, 0.07769230769230769, 0.005, 0.004230769230769231, 0.04230769230769231, 0.002307692307692308, 0.01730769230769231, 0.047692307692307694, 0.051153846153846154, 0.2803846153846154, 0.0015384615384615385, 0.018461538461538463, 0.01, 0.0, 0.013076923076923076]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Score the three caches in a fixed order (iid, novel, corr) so the printed\n",
    "# summaries appear in that order.\n",
    "iid_sm, novel_sm, corr_sm = [\n",
    "    get_conf_scores(cache_path)\n",
    "    for cache_path in (iid_emb, novel_emb, corr_emb)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# alt.data_transformers.enable('default', max_rows=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eleven edges 0.0, 0.1, ..., 1.0 -> ten equal-width confidence bins shared by all groups.\n",
    "conf_bin_edges = [edge/10 for edge in range(11)]\n",
    "# np.histogram returns (counts, edges); only the counts are kept.\n",
    "iid_hist = np.histogram(iid_sm, bins=conf_bin_edges)[0]\n",
    "corr_hist = np.histogram(corr_sm, bins=conf_bin_edges)[0]\n",
    "novel_hist = np.histogram(novel_sm, bins=conf_bin_edges)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert raw bin counts into the fraction of each group's samples per bin.\n",
    "iid_dens = iid_hist / iid_hist.sum()\n",
    "corr_dens = corr_hist / corr_hist.sum()\n",
    "novel_dens = novel_hist / novel_hist.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_data(dens, group):\n",
    "    \"\"\"Convert a sequence of per-bin values into a list of chart records.\n",
    "\n",
    "    Every entry of ``dens`` becomes a dict with the keys 'percent' (the entry\n",
    "    itself), 'conf' (position / 10 + 0.05, i.e. the centre of a 0.1-wide bin\n",
    "    when bins start at 0) and 'group' (the label passed in).\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for bin_idx, frac in enumerate(dens):\n",
    "        records.append({'percent': frac, 'conf': bin_idx/10+0.05, 'group': group})\n",
    "    return records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the records of all three groups into one flat list (iid, then corr, then novel).\n",
    "dens = [\n",
    "    record\n",
    "    for group_dens, group_name in ((iid_dens, 'iid'), (corr_dens, 'corr'), (novel_dens, 'novel'))\n",
    "    for record in transform_data(group_dens, group_name)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/workspace/lu35/anaconda3/envs/ood/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n",
      "  for col_name, dtype in df.dtypes.iteritems():\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-2010ba3f69d945d79d8332f220650e0c\"></div>\n",
       "<script type=\"text/javascript\">\n",
       "  var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
       "  (function(spec, embedOpt){\n",
       "    let outputDiv = document.currentScript.previousElementSibling;\n",
       "    if (outputDiv.id !== \"altair-viz-2010ba3f69d945d79d8332f220650e0c\") {\n",
       "      outputDiv = document.getElementById(\"altair-viz-2010ba3f69d945d79d8332f220650e0c\");\n",
       "    }\n",
       "    const paths = {\n",
       "      \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
       "      \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
       "      \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
       "      \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
       "    };\n",
       "\n",
       "    function maybeLoadScript(lib, version) {\n",
       "      var key = `${lib.replace(\"-\", \"\")}_version`;\n",
       "      return (VEGA_DEBUG[key] == version) ?\n",
       "        Promise.resolve(paths[lib]) :\n",
       "        new Promise(function(resolve, reject) {\n",
       "          var s = document.createElement('script');\n",
       "          document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "          s.async = true;\n",
       "          s.onload = () => {\n",
       "            VEGA_DEBUG[key] = version;\n",
       "            return resolve(paths[lib]);\n",
       "          };\n",
       "          s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
       "          s.src = paths[lib];\n",
       "        });\n",
       "    }\n",
       "\n",
       "    function showError(err) {\n",
       "      outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
       "      throw err;\n",
       "    }\n",
       "\n",
       "    function displayChart(vegaEmbed) {\n",
       "      vegaEmbed(outputDiv, spec, embedOpt)\n",
       "        .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
       "    }\n",
       "\n",
       "    if(typeof define === \"function\" && define.amd) {\n",
       "      requirejs.config({paths});\n",
       "      require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
       "    } else {\n",
       "      maybeLoadScript(\"vega\", \"5\")\n",
       "        .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
       "        .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
       "        .catch(showError)\n",
       "        .then(() => displayChart(vegaEmbed));\n",
       "    }\n",
       "  })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}}, \"data\": {\"name\": \"data-db6cb0b70751f803e464295837688a91\"}, \"mark\": \"bar\", \"encoding\": {\"color\": {\"field\": \"group\", \"type\": \"nominal\"}, \"column\": {\"field\": \"group\", \"type\": \"nominal\"}, \"x\": {\"bin\": true, \"field\": \"conf\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"percent\", \"type\": \"quantitative\"}}, \"height\": 200, \"width\": 200, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-db6cb0b70751f803e464295837688a91\": [{\"percent\": 0.0, \"conf\": 0.05, \"group\": \"iid\"}, {\"percent\": 0.041, \"conf\": 0.15000000000000002, \"group\": \"iid\"}, {\"percent\": 0.0899, \"conf\": 0.25, \"group\": \"iid\"}, {\"percent\": 0.086, \"conf\": 0.35, \"group\": \"iid\"}, {\"percent\": 0.083, \"conf\": 0.45, \"group\": \"iid\"}, {\"percent\": 0.073, \"conf\": 0.55, \"group\": \"iid\"}, {\"percent\": 0.0655, \"conf\": 0.65, \"group\": \"iid\"}, {\"percent\": 0.0678, \"conf\": 0.75, \"group\": \"iid\"}, {\"percent\": 0.0796, \"conf\": 0.8500000000000001, \"group\": \"iid\"}, {\"percent\": 0.4142, \"conf\": 0.9500000000000001, \"group\": \"iid\"}, {\"percent\": 0.0015384615384615385, \"conf\": 0.05, \"group\": \"corr\"}, {\"percent\": 0.15076923076923077, \"conf\": 0.15000000000000002, \"group\": \"corr\"}, {\"percent\": 0.2403846153846154, \"conf\": 0.25, \"group\": \"corr\"}, {\"percent\": 0.18846153846153846, \"conf\": 0.35, \"group\": \"corr\"}, {\"percent\": 0.1276923076923077, \"conf\": 0.45, \"group\": \"corr\"}, {\"percent\": 0.08807692307692308, \"conf\": 0.55, \"group\": \"corr\"}, {\"percent\": 0.0673076923076923, \"conf\": 0.65, \"group\": \"corr\"}, {\"percent\": 0.05076923076923077, \"conf\": 0.75, \"group\": \"corr\"}, {\"percent\": 0.046153846153846156, \"conf\": 0.8500000000000001, \"group\": \"corr\"}, {\"percent\": 0.03884615384615384, \"conf\": 0.9500000000000001, \"group\": 
\"corr\"}, {\"percent\": 0.0003846153846153846, \"conf\": 0.05, \"group\": \"novel\"}, {\"percent\": 0.0926923076923077, \"conf\": 0.15000000000000002, \"group\": \"novel\"}, {\"percent\": 0.15846153846153846, \"conf\": 0.25, \"group\": \"novel\"}, {\"percent\": 0.14692307692307693, \"conf\": 0.35, \"group\": \"novel\"}, {\"percent\": 0.115, \"conf\": 0.45, \"group\": \"novel\"}, {\"percent\": 0.09653846153846155, \"conf\": 0.55, \"group\": \"novel\"}, {\"percent\": 0.08153846153846153, \"conf\": 0.65, \"group\": \"novel\"}, {\"percent\": 0.06384615384615384, \"conf\": 0.75, \"group\": \"novel\"}, {\"percent\": 0.06653846153846153, \"conf\": 0.8500000000000001, \"group\": \"novel\"}, {\"percent\": 0.17807692307692308, \"conf\": 0.9500000000000001, \"group\": \"novel\"}]}}, {\"mode\": \"vega-lite\"});\n",
       "</script>"
      ],
      "text/plain": [
       "alt.Chart(...)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# Bar chart of per-bin fraction vs. confidence, faceted into one column per\n",
    "# group and coloured by group.\n",
    "dens_df = pd.DataFrame(dens)\n",
    "conf_chart = alt.Chart(dens_df).mark_bar().encode(\n",
    "    x=alt.X('conf:Q', bin=True),\n",
    "    y=alt.Y('percent:Q'),\n",
    "    column='group:N',\n",
    "    color='group:N',\n",
    ")\n",
    "# The last expression is what the notebook renders.\n",
    "conf_chart.properties(width=200, height=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ood",
   "language": "python",
   "name": "ood"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}