{
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import math\n",
"import json\n",
"import glob\n",
"from functools import reduce\n",
"import altair as alt"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def get_mae(dataset, model, metric, n_sample, seed):\n",
" result_dir = f'../results/{dataset}/{model}_{seed}/{metric}_{n_sample}'\n",
" \n",
" result_fs = glob.glob(f'{result_dir}/*.json')\n",
" results = []\n",
" for file in result_fs:\n",
" with open(file, 'r') as f:\n",
" data = json.load(f)\n",
" results.extend(data)\n",
" \n",
" maes = [abs(result['metric'] - result['error']) for result in results]\n",
" \n",
" return np.array(maes).mean()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def get_sens_results(metric, dataset, arch):\n",
" n_samples = [10000, 5000, 2000, 1000, 500, 200, 100]\n",
" seed = '1'\n",
" sens = []\n",
" for n in n_samples: \n",
" mae = get_mae(dataset, arch, metric, n, seed)\n",
" sens.append({\n",
" 'metric': metric,\n",
" 'dataset': dataset,\n",
" 'n_sample': n,\n",
" 'arch': arch,\n",
" 'mae': mae\n",
" })\n",
" return sens"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"metrics = ['EMD', 'ATC', 'REMD_1.25']\n",
"datasets = ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']\n",
"sens = []\n",
"for metric in metrics:\n",
" for dataset in datasets:\n",
" sens += get_sens_results(metric, dataset, 'resnet50')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<div id=\"altair-viz-4b59fbffcc37474daa32cb84404743a9\"></div>\n",
"<script type=\"text/javascript\">\n",
" var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
" (function(spec, embedOpt){\n",
" let outputDiv = document.currentScript.previousElementSibling;\n",
" if (outputDiv.id !== \"altair-viz-4b59fbffcc37474daa32cb84404743a9\") {\n",
" outputDiv = document.getElementById(\"altair-viz-4b59fbffcc37474daa32cb84404743a9\");\n",
" }\n",
" const paths = {\n",
" \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
" \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
" \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
" \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
" };\n",
"\n",
" function maybeLoadScript(lib, version) {\n",
" var key = `${lib.replace(\"-\", \"\")}_version`;\n",
" return (VEGA_DEBUG[key] == version) ?\n",
" Promise.resolve(paths[lib]) :\n",
" new Promise(function(resolve, reject) {\n",
" var s = document.createElement('script');\n",
" document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
" s.async = true;\n",
" s.onload = () => {\n",
" VEGA_DEBUG[key] = version;\n",
" return resolve(paths[lib]);\n",
" };\n",
" s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
" s.src = paths[lib];\n",
" });\n",
" }\n",
"\n",
" function showError(err) {\n",
" outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
" throw err;\n",
" }\n",
"\n",
" function displayChart(vegaEmbed) {\n",
" vegaEmbed(outputDiv, spec, embedOpt)\n",
" .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
" }\n",
"\n",
" if(typeof define === \"function\" && define.amd) {\n",
" requirejs.config({paths});\n",
" require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
" } else {\n",
" maybeLoadScript(\"vega\", \"5\")\n",
" .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
" .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
" .catch(showError)\n",
" .then(() => displayChart(vegaEmbed));\n",
" }\n",
" })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}, \"axis\": {\"labelFontSize\": 16, \"titleFontSize\": 18}, \"legend\": {\"labelFontSize\": 18, \"titleFontSize\": 20}}, \"data\": {\"values\": [{\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.014469379930368452}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.015004043102478862}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.016281259801561865}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.020989433447233103}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04018579669484187}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04615530701310061}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045474082633063746}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.01992340799863755}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.020585536769618707}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.024342450141671763}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03338057958699007}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.0534508796307282}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.10063707337972706}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.19019687173354596}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.032434747111411456}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.029363516945500913}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.02454799185399754}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03127610869061543}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04364479950047747}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.09441336930069869}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.20636375601731388}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0316802075425163}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.03201457563489675}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.031520850390195844}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.028937513527770835}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.03106251671910286}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.03927082135031621}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045416658793886504}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03524166695301732}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.03512082161357005}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.034406267752870916}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03490626962669194}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.03997918866078059}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.043802075817560154}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.047083325767889615}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03485526348164207}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.034565782804316596}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.03525658435531352}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03613158635089272}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.042947374909331926}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04743420748922386}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.05434210061634842}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.04440492593811391}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.04269366873211624}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.042810486328272036}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.04571517167147061}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.051640504498346416}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.05429068888781783}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.05624488137029399}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0934960299973343}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.09297857507000007}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.08642741723291288}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.0810600697000193}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.07525555161545947}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08454419112086059}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12893080543225685}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.14702108353799165}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.14431254023133747}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.1329430498002896}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.12127065566562693}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.106019251174766}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08902080514337413}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12326658876760059}]}, \"facet\": {\"column\": {\"field\": \"dataset\", \"header\": {\"labelFontSize\": 18}, \"type\": \"nominal\"}}, \"spec\": {\"mark\": {\"type\": \"line\", \"point\": {\"fill\": \"white\", \"filled\": false, \"size\": 50}}, \"encoding\": {\"color\": {\"field\": \"metric\", \"type\": \"nominal\"}, \"x\": {\"axis\": {\"values\": [100, 1000, 10000]}, \"field\": \"n_sample\", \"scale\": {\"reverse\": false, \"type\": \"log\"}, \"title\": \"Number of Samples\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"mae\", \"scale\": {\"domain\": [0, 0.22]}, \"title\": \"Mean Absolute Error\", \"type\": \"quantitative\"}}, \"height\": 150, \"width\": 150}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\"}, {\"mode\": \"vega-lite\"});\n",
"</script>"
],
"text/plain": [
"alt.FacetChart(...)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alt.Chart(alt.Data(values=sens)).mark_line(\n",
" point=alt.OverlayMarkDef(size=50, filled=False, fill='white')\n",
").encode(\n",
" x=alt.X('n_sample:Q', title='Number of Samples', scale=alt.Scale(type=\"log\", reverse=False), axis=alt.Axis(values=[100, 1000, 10000])), \n",
" y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, 0.22]), title='Mean Absolute Error'),\n",
" color=alt.Color('metric:N')\n",
").properties(\n",
" width=150,\n",
" height=150\n",
").facet(\n",
" column=alt.Column('dataset:N', header=alt.Header(labelFontSize=18))\n",
").configure_axis(\n",
" labelFontSize=16,\n",
" titleFontSize=18,\n",
").configure_legend(\n",
" titleFontSize=20,\n",
" labelFontSize=18\n",
") "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "3f1da77dfecb8ea08abe23a06d7ff88ed6156499ac883f4692cab9e51a9d5766"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}