{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "import torch\n", "from collections import Counter" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def get_conf_scores(emb_dir):\n", "    \"\"\"Load cached model outputs for one split and print summary statistics.\n", "\n", "    Prints the mean max-softmax confidence, the accuracy, and the fraction of\n", "    samples predicted as each class, then returns the per-sample confidences.\n", "\n", "    Args:\n", "        emb_dir: path to a pickle holding a dict with keys 'act', 'pred' and\n", "            'tar'. Presumably 'act' is an (N, num_classes) logit tensor and\n", "            'pred'/'tar' are (N,) label tensors -- TODO confirm with the producer.\n", "\n", "    Returns:\n", "        list of float: the max softmax probability of every sample.\n", "    \"\"\"\n", "    # NOTE(review): pickle.load can execute arbitrary code; only load trusted cache files.\n", "    with open(emb_dir, 'rb') as f:  # context manager closes the file handle\n", "        res = pickle.load(f)\n", "    act, pred, tar = res['act'], res['pred'], res['tar']\n", "    sm = torch.nn.functional.softmax(act, dim=1)\n", "    conf = sm.max(1)[0]  # per-sample max softmax probability, reused below\n", "    acc = (pred == tar).sum() / len(pred)\n", "\n", "    print('averaged conf:', conf.mean().item())\n", "    print('acc:', acc.item())\n", "    pred_counts = Counter(pred.tolist())\n", "    # Fraction of samples predicted as each class index 0 .. act.shape[1] - 1.\n", "    print('label dist:', [pred_counts[i] / len(tar) for i in range(act.shape[1])])\n", "    print()\n", "\n", "    return conf.tolist()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Run configuration: the three pickles below are loaded by get_conf_scores.\n", "ds = 'Nonliving-26'\n", "epoch = 450\n", "n_test_sample = 2600\n", "cache_dir = f'../cache/{ds}/resnet50_1-{epoch}/scratch'\n", "iid_emb = f'{cache_dir}/id_m1-{epoch}_d1.pkl'\n", "novel_emb = f'{cache_dir}/od_pnovel_m1-{epoch}_cclean-0_n{n_test_sample}.pkl'\n", "corr_emb = f'{cache_dir}/od_psame_m1-{epoch}_cfrost-5_n{n_test_sample}.pkl'" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "averaged conf: 0.7038018107414246\n", "acc: 0.6948999762535095\n", "label dist: [0.0361, 0.041, 0.0383, 0.0329, 0.0369, 0.0335, 0.0382, 0.0317, 0.051, 0.0409, 0.0375, 0.0463, 0.0371, 0.0366, 0.0431, 0.043, 0.0284, 0.0373, 0.0354, 0.0398, 0.0422, 0.0392, 0.0292, 0.0365, 0.0371, 0.0508]\n", "\n", "averaged conf: 0.5417004227638245\n", "acc: 0.3407692313194275\n", "label dist: [0.03230769230769231, 0.038461538461538464, 0.05, 0.023846153846153847, 0.02576923076923077, 0.03769230769230769, 0.04730769230769231, 0.022692307692307692, 0.03576923076923077, 0.03923076923076923, 0.028076923076923076, 0.05423076923076923, 0.035, 
0.036153846153846154, 0.03923076923076923, 0.04153846153846154, 0.03769230769230769, 0.04692307692307692, 0.03769230769230769, 0.041923076923076924, 0.04423076923076923, 0.038461538461538464, 0.028846153846153848, 0.03807692307692308, 0.03807692307692308, 0.06076923076923077]\n", "\n", "averaged conf: 0.4134865999221802\n", "acc: 0.21230769157409668\n", "label dist: [0.02346153846153846, 0.008846153846153846, 0.07923076923076923, 0.023846153846153847, 0.03576923076923077, 0.008461538461538461, 0.010384615384615384, 0.001153846153846154, 0.013076923076923076, 0.011538461538461539, 0.025, 0.18807692307692309, 0.07769230769230769, 0.005, 0.004230769230769231, 0.04230769230769231, 0.002307692307692308, 0.01730769230769231, 0.047692307692307694, 0.051153846153846154, 0.2803846153846154, 0.0015384615384615385, 0.018461538461538463, 0.01, 0.0, 0.013076923076923076]\n", "\n" ] } ], "source": [ "iid_sm = get_conf_scores(iid_emb)\n", "novel_sm = get_conf_scores(novel_emb)\n", "corr_sm = get_conf_scores(corr_emb)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "import ot" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import altair as alt\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# alt.data_transformers.enable('default', max_rows=None)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# One shared set of bin edges [0.0, 0.1, ..., 1.0] so the three histograms line up.\n", "conf_bins = [i / 10 for i in range(11)]\n", "iid_hist, _ = np.histogram(iid_sm, bins=conf_bins)\n", "corr_hist, _ = np.histogram(corr_sm, bins=conf_bins)\n", "novel_hist, _ = np.histogram(novel_sm, bins=conf_bins)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Normalize bin counts to the fraction of samples falling in each bin.\n", "iid_dens = iid_hist / iid_hist.sum()\n", "corr_dens = corr_hist / corr_hist.sum()\n", "novel_dens = novel_hist / novel_hist.sum()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "def transform_data(dens, group, n_bins=10):\n", "    \"\"\"Turn per-bin fractions into long-format records for Altair.\n", "\n", "    Args:\n", "        dens: per-bin fractions, one entry per equal-width bin over [0, 1].\n", "        group: value stored in the 'group' field of every record.\n", "        n_bins: number of equal-width bins over [0, 1]; the default 10 matches\n", "            conf_bins above.\n", "\n", "    Returns:\n", "        list of dict with keys 'percent' (the bin's fraction of samples, not a\n", "        percentage), 'conf' (bin centre) and 'group'.\n", "    \"\"\"\n", "    return [\n", "        {'percent': frac, 'conf': ind / n_bins + 0.5 / n_bins, 'group': group}\n", "        for ind, frac in enumerate(dens)\n", "    ]" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "dens = transform_data(iid_dens, 'iid') + transform_data(corr_dens, 'corr') + transform_data(novel_dens, 'novel')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/workspace/lu35/anaconda3/envs/ood/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n", " for col_name, dtype in df.dtypes.iteritems():\n" ] }, { "data": { "text/html": [ "\n", "<div id=\"altair-viz-2010ba3f69d945d79d8332f220650e0c\"></div>\n", "<script type=\"text/javascript\">\n", " var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n", " (function(spec, embedOpt){\n", " let outputDiv = document.currentScript.previousElementSibling;\n", " if (outputDiv.id !== \"altair-viz-2010ba3f69d945d79d8332f220650e0c\") {\n", " outputDiv = document.getElementById(\"altair-viz-2010ba3f69d945d79d8332f220650e0c\");\n", " }\n", " const paths = {\n", " \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n", " \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n", " \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n", " \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n", " };\n", "\n", " function maybeLoadScript(lib, version) {\n", " var key = `${lib.replace(\"-\", \"\")}_version`;\n", " return (VEGA_DEBUG[key] == version) ?\n", " Promise.resolve(paths[lib]) :\n", " new Promise(function(resolve, reject) {\n", " var s = document.createElement('script');\n", " document.getElementsByTagName(\"head\")[0].appendChild(s);\n", " s.async = true;\n", " s.onload = () => {\n", " VEGA_DEBUG[key] = version;\n", " return 
resolve(paths[lib]);\n", " };\n", " s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n", " s.src = paths[lib];\n", " });\n", " }\n", "\n", " function showError(err) {\n", " outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n", " throw err;\n", " }\n", "\n", " function displayChart(vegaEmbed) {\n", " vegaEmbed(outputDiv, spec, embedOpt)\n", " .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n", " }\n", "\n", " if(typeof define === \"function\" && define.amd) {\n", " requirejs.config({paths});\n", " require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n", " } else {\n", " maybeLoadScript(\"vega\", \"5\")\n", " .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n", " .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n", " .catch(showError)\n", " .then(() => displayChart(vegaEmbed));\n", " }\n", " })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}}, \"data\": {\"name\": \"data-db6cb0b70751f803e464295837688a91\"}, \"mark\": \"bar\", \"encoding\": {\"color\": {\"field\": \"group\", \"type\": \"nominal\"}, \"column\": {\"field\": \"group\", \"type\": \"nominal\"}, \"x\": {\"bin\": true, \"field\": \"conf\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"percent\", \"type\": \"quantitative\"}}, \"height\": 200, \"width\": 200, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-db6cb0b70751f803e464295837688a91\": [{\"percent\": 0.0, \"conf\": 0.05, \"group\": \"iid\"}, {\"percent\": 0.041, \"conf\": 0.15000000000000002, \"group\": \"iid\"}, {\"percent\": 0.0899, \"conf\": 0.25, \"group\": \"iid\"}, {\"percent\": 0.086, \"conf\": 0.35, \"group\": \"iid\"}, {\"percent\": 0.083, \"conf\": 0.45, \"group\": \"iid\"}, {\"percent\": 0.073, \"conf\": 0.55, \"group\": \"iid\"}, 
{\"percent\": 0.0655, \"conf\": 0.65, \"group\": \"iid\"}, {\"percent\": 0.0678, \"conf\": 0.75, \"group\": \"iid\"}, {\"percent\": 0.0796, \"conf\": 0.8500000000000001, \"group\": \"iid\"}, {\"percent\": 0.4142, \"conf\": 0.9500000000000001, \"group\": \"iid\"}, {\"percent\": 0.0015384615384615385, \"conf\": 0.05, \"group\": \"corr\"}, {\"percent\": 0.15076923076923077, \"conf\": 0.15000000000000002, \"group\": \"corr\"}, {\"percent\": 0.2403846153846154, \"conf\": 0.25, \"group\": \"corr\"}, {\"percent\": 0.18846153846153846, \"conf\": 0.35, \"group\": \"corr\"}, {\"percent\": 0.1276923076923077, \"conf\": 0.45, \"group\": \"corr\"}, {\"percent\": 0.08807692307692308, \"conf\": 0.55, \"group\": \"corr\"}, {\"percent\": 0.0673076923076923, \"conf\": 0.65, \"group\": \"corr\"}, {\"percent\": 0.05076923076923077, \"conf\": 0.75, \"group\": \"corr\"}, {\"percent\": 0.046153846153846156, \"conf\": 0.8500000000000001, \"group\": \"corr\"}, {\"percent\": 0.03884615384615384, \"conf\": 0.9500000000000001, \"group\": \"corr\"}, {\"percent\": 0.0003846153846153846, \"conf\": 0.05, \"group\": \"novel\"}, {\"percent\": 0.0926923076923077, \"conf\": 0.15000000000000002, \"group\": \"novel\"}, {\"percent\": 0.15846153846153846, \"conf\": 0.25, \"group\": \"novel\"}, {\"percent\": 0.14692307692307693, \"conf\": 0.35, \"group\": \"novel\"}, {\"percent\": 0.115, \"conf\": 0.45, \"group\": \"novel\"}, {\"percent\": 0.09653846153846155, \"conf\": 0.55, \"group\": \"novel\"}, {\"percent\": 0.08153846153846153, \"conf\": 0.65, \"group\": \"novel\"}, {\"percent\": 0.06384615384615384, \"conf\": 0.75, \"group\": \"novel\"}, {\"percent\": 0.06653846153846153, \"conf\": 0.8500000000000001, \"group\": \"novel\"}, {\"percent\": 0.17807692307692308, \"conf\": 0.9500000000000001, \"group\": \"novel\"}]}}, {\"mode\": \"vega-lite\"});\n", "</script>" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", 
"alt.Chart(pd.DataFrame(dens)).mark_bar().encode(\n", "    # The data is already binned by np.histogram over [0, 1] with width 0.1; pin\n", "    # Vega-Lite to the same edges instead of letting it pick its own bins.\n", "    x=alt.X('conf:Q', bin=alt.Bin(extent=[0, 1], step=0.1), title='max softmax confidence'),\n", "    y=alt.Y('percent:Q', title='fraction of samples'),\n", "    column='group:N',\n", "    color='group:N'\n", ").properties(\n", "    width=200,\n", "    height=200\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "ood", "language": "python", "name": "ood" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }