{ "cells": [ { "cell_type": "code", "execution_count": 153, "id": "c542d51c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import json\n", "import glob" ] }, { "cell_type": "code", "execution_count": 154, "id": "1592981c", "metadata": {}, "outputs": [], "source": [ "def load_results(dataset, model, pretrained, ckpt, metric, n_sample, seed):\n", " if pretrained:\n", " result_dir = f'../results/{dataset}/pretrained/{model}_{seed}-{ckpt}/{metric}_{n_sample}'\n", " else:\n", " result_dir = f'../results/{dataset}/scratch/{model}_{seed}-{ckpt}/{metric}_{n_sample}'\n", " \n", " result_fs = glob.glob(f'{result_dir}/*.json')\n", " results = []\n", " for file in result_fs:\n", " with open(file, 'r') as f:\n", " data = json.load(f)\n", " print(file, len(data))\n", " results.extend(data)\n", " \n", " return results" ] }, { "cell_type": "code", "execution_count": 155, "id": "54321e39", "metadata": {}, "outputs": [], "source": [ "n_sample_dict = {\n", " 'CIFAR-10': -1, \n", " 'CIFAR-100': -1, \n", " 'Living-17': -1,\n", " 'Nonliving-26': -1,\n", " 'Entity-13': -1,\n", " 'Entity-30': -1,\n", " 'ImageNet': -1,\n", " 'RxRx1': -1,\n", " 'FMoW': -1,\n", " 'Amazon': -1,\n", " 'CivilComments': -1\n", "}\n", "\n", "n_epoch_dict = {\n", " 'CIFAR-10': 300, \n", " 'CIFAR-100': 300, \n", " 'Living-17': 450,\n", " 'Nonliving-26': 450,\n", " 'Entity-13': 300,\n", " 'Entity-30': 300,\n", " 'ImageNet': 10,\n", " 'FMoW': 50,\n", " 'RxRx1': 90,\n", " 'Amazon': 3,\n", " 'CivilComments': 5\n", "}\n", "\n", "pretrained_dict = {\n", " 'CIFAR-10': False, \n", " 'CIFAR-100': False, \n", " 'Living-17': False,\n", " 'Nonliving-26': False,\n", " 'Entity-13': False,\n", " 'Entity-30': False,\n", " 'ImageNet': True,\n", " 'FMoW': True,\n", " 'RxRx1': True,\n", " 'Amazon': True,\n", " 'CivilComments': True\n", "}" ] }, { "cell_type": "code", "execution_count": 189, "id": "13c79cfb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "../results/Amazon/pretrained/distilbert-base-uncased_1-3/COTT-MC_-1/group-1.json 1\n", "../results/Amazon/pretrained/distilbert-base-uncased_1-3/COTT-MC_-1/group-2.json 1\n" ] } ], "source": [ "metric = 'COTT-MC'\n", "dataset = 'Amazon'\n", "arch = 'distilbert-base-uncased' # 'resnet50' # \n", "n_sample = n_sample_dict[dataset]\n", "seed = '1'\n", "model_ckpt = n_epoch_dict[dataset]\n", "pretrained = pretrained_dict[dataset]\n", "results = load_results(dataset, arch, pretrained, model_ckpt, metric, n_sample, seed)" ] }, { "cell_type": "code", "execution_count": 190, "id": "a1c8634e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 190, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(results)" ] }, { "cell_type": "code", "execution_count": 191, "id": "edb585d8", "metadata": {}, "outputs": [], "source": [ "import altair as alt\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 192, "id": "ef90531d", "metadata": {}, "outputs": [], "source": [ "def get_corr_chart(data, subpop):\n", " corr = alt.Chart(alt.Data(values=data), title=subpop).mark_point(size=50).encode(\n", " x=alt.X('metric:Q', title=metric),\n", " y=alt.X('error:Q', title='Test Error'),\n", " color=alt.Color('corruption:N', scale=alt.Scale(scheme='category20')),\n", " shape='corruption level:O'\n", " ).properties(\n", " width=200,\n", " height=200\n", " )\n", " return corr" ] }, { "cell_type": "code", "execution_count": 193, "id": "86df21ae", "metadata": {}, "outputs": [], "source": [ "same_pop_results = [i for i in results if i['subpopulation'] == 'same']\n", "natural_pop_results = [i for i in results if i['subpopulation'] == 'natural']\n", "novel_pop_results = [i for i in results if i['subpopulation'] == 'novel']" ] }, { "cell_type": "code", "execution_count": 194, "id": "fd94e84f", "metadata": {}, "outputs": [], "source": [ "line = pd.DataFrame({'metric': [0, 1], 'error': [0, 1]})\n", "line_plot = alt.Chart(line).mark_line(color='black', strokeDash=[5, 8]).encode(\n", " x='metric',\n", " y='error',\n", ")" ] }, { "cell_type": "code", "execution_count": 195, "id": "a2681d4c", "metadata": {}, "outputs": [], "source": [ "same_corr = get_corr_chart(same_pop_results, 'same')\n", "same_plt = same_corr + line_plot\n", "\n", "natural_corr = get_corr_chart(natural_pop_results, 'natural')\n", "natural_plt = natural_corr + line_plot\n", "\n", "novel_corr = get_corr_chart(novel_pop_results, 'novel')\n", "novel_plt = novel_corr + line_plot\n", "\n", "plt = same_plt | natural_plt | novel_plt" ] }, { "cell_type": "code", "execution_count": 196, "id": "3c3d3a92", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/workspace/lu35/anaconda3/envs/ood/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n", " for col_name, dtype in df.dtypes.iteritems():\n" ] }, { "data": { "text/html": [ "\n", "<div id=\"altair-viz-f38cb6baa1fe499dbce882e840642607\"></div>\n", "<script type=\"text/javascript\">\n", " var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n", " (function(spec, embedOpt){\n", " let outputDiv = document.currentScript.previousElementSibling;\n", " if (outputDiv.id !== \"altair-viz-f38cb6baa1fe499dbce882e840642607\") {\n", " outputDiv = document.getElementById(\"altair-viz-f38cb6baa1fe499dbce882e840642607\");\n", " }\n", " const paths = {\n", " \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n", " \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n", " \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n", " \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n", " };\n", "\n", " function maybeLoadScript(lib, version) {\n", " var key = `${lib.replace(\"-\", \"\")}_version`;\n", " return (VEGA_DEBUG[key] == version) ?\n", " Promise.resolve(paths[lib]) :\n", " new Promise(function(resolve, reject) {\n", " var s = document.createElement('script');\n", " document.getElementsByTagName(\"head\")[0].appendChild(s);\n", " s.async = true;\n", " s.onload = () => {\n", " VEGA_DEBUG[key] = version;\n", " return resolve(paths[lib]);\n", " };\n", " s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n", " s.src = paths[lib];\n", " });\n", " }\n", "\n", " function showError(err) {\n", " outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n", " throw err;\n", " }\n", "\n", " function displayChart(vegaEmbed) {\n", " vegaEmbed(outputDiv, spec, embedOpt)\n", " .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n", " }\n", "\n", " if(typeof define === \"function\" && define.amd) {\n", " requirejs.config({paths});\n", " require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n", " } else {\n", " maybeLoadScript(\"vega\", \"5\")\n", " .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n", " .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n", " .catch(showError)\n", " .then(() => displayChart(vegaEmbed));\n", " }\n", " })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}, \"axis\": {\"labelFontSize\": 14, \"titleFontSize\": 16}, \"legend\": {\"labelFontSize\": 16, \"titleFontSize\": 14}}, \"hconcat\": [{\"layer\": [{\"data\": {\"values\": []}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"same\", \"width\": 200}, {\"data\": {\"name\": \"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}, {\"layer\": [{\"data\": {\"values\": [{\"corruption\": \"group-1\", \"corruption level\": 0, \"metric\": 0.26476, \"ref\": \"COTT-MC\", \"acc\": 0.7209794521331787, \"error\": 0.2790205478668213, \"subpopulation\": \"natural\", \"pretrained\": true}, {\"corruption\": \"group-2\", \"corruption level\": 0, \"metric\": 0.26741000000000004, \"ref\": \"COTT-MC\", \"acc\": 0.7141029238700867, \"error\": 0.28589707612991333, \"subpopulation\": \"natural\", \"pretrained\": true}]}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"natural\", \"width\": 200}, {\"data\": {\"name\": \"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}, {\"layer\": [{\"data\": {\"values\": []}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"novel\", \"width\": 200}, {\"data\": {\"name\": \"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}], \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-371cce5920a7f84c9a50d4162766dae5\": [{\"metric\": 0, \"error\": 0}, {\"metric\": 1, \"error\": 1}]}}, {\"mode\": \"vega-lite\"});\n", "</script>" ], "text/plain": [ "alt.HConcatChart(...)" ] }, "execution_count": 196, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plt.configure_axis(\n", " labelFontSize=14,\n", " titleFontSize=16\n", ").configure_axis(\n", " labelFontSize=14,\n", " titleFontSize=16,\n", ").configure_legend(\n", " titleFontSize=14,\n", " labelFontSize=16\n", ") \n" ] }, { "cell_type": "code", "execution_count": 197, "id": "f6c96c4f", "metadata": {}, "outputs": [], "source": [ "def polyfit(x, y, degree=1):\n", " results = {}\n", "\n", " coeffs = np.polyfit(x, y, degree)\n", "\n", " results['polynomial'] = coeffs.tolist()\n", "\n", " p = np.poly1d(coeffs)\n", "\n", " yhat = p(x) \n", " ybar = np.sum(y)/len(y) \n", " ssreg = np.sum((yhat - ybar)**2) \n", " sstot = np.sum((y - ybar)**2) \n", " results['determination'] = ssreg / sstot\n", "\n", " return results" ] }, { "cell_type": "code", "execution_count": 198, "id": "73c59c8d", "metadata": {}, "outputs": [], "source": [ "from scipy.stats import spearmanr\n", "import math" ] }, { "cell_type": "code", "execution_count": 199, "id": "08834324", "metadata": {}, "outputs": [], "source": [ "def compute_corr_stats(results):\n", " if len(results) == 0:\n", " return\n", " \n", " d = [i['metric'] for i in results]\n", " e = [i['error'] for i in results]\n", " \n", " if len(results) > 1:\n", " print(f'{metric} r2:', polyfit(d, e)['determination'])\n", "\n", " coef, p = spearmanr(d, e)\n", " print(f'{metric} rho:', coef)\n", "\n", " print(f'{metric} slope:', polyfit(d, e)['polynomial'][0])\n", " print(f'{metric} bias:', polyfit(d, e)['polynomial'][1])\n", "\n", " yhat = np.array(d)\n", " y = np.array(e)\n", "\n", " print('MAE:', np.abs(yhat - y).mean() * 100)" ] }, { "cell_type": "code", "execution_count": 200, "id": "23828d93", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----- same pop results -----\n" ] } ], "source": [ "print(\"----- same pop results -----\")\n", "compute_corr_stats(same_pop_results)" ] }, { "cell_type": "code", "execution_count": 201, "id": "bf70a29c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----- natural pop results -----\n", "COTT-MC r2: 1.0\n", "COTT-MC rho: 0.9999999999999999\n", "COTT-MC slope: 2.5949163256950816\n", "COTT-MC bias: -0.40800949852420865\n", "MAE: 1.6373811998367294\n" ] } ], "source": [ "print(\"----- natural pop results -----\")\n", "compute_corr_stats(natural_pop_results)" ] }, { "cell_type": "code", "execution_count": 202, "id": "05c3ac30", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----- novel pop results -----\n" ] } ], "source": [ "print(\"----- novel pop results -----\")\n", "compute_corr_stats(novel_pop_results)" ] }, { "cell_type": "code", "execution_count": 203, "id": "1df2cb59", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----- novel pop results -----\n" ] } ], "source": [ "print(\"----- novel pop results -----\")\n", "compute_corr_stats([i for i in novel_pop_results if i['corruption level'] == 0 ])" ] }, { "cell_type": "code", "execution_count": null, "id": "f4599063", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "ood", "language": "python", "name": "ood" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }