{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample-size sensitivity of error-estimation metrics\n",
    "\n",
    "For each metric (`EMD`, `ATC`, `REMD_1.25`) and dataset (CIFAR-10, CIFAR-100, Tiny-ImageNet), load the saved `resnet50` results (seed `1`) at several sample sizes and plot the mean absolute difference between each record's `metric` and `error` fields.\n",
    "\n",
    "Results are read from `../results/{dataset}/{model}_{seed}/{metric}_{n_sample}/*.json`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "import json\n",
    "import glob\n",
    "from functools import reduce\n",
    "import altair as alt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mae(dataset, model, metric, n_sample, seed):\n",
    "    \"\"\"Mean of |metric - error| over every saved result record for one run.\n",
    "\n",
    "    Loads all `*.json` files in\n",
    "    `../results/{dataset}/{model}_{seed}/{metric}_{n_sample}`. Each file must\n",
    "    hold a list of dicts with numeric 'metric' and 'error' keys. Files are\n",
    "    read in sorted order so records are always combined in the same order\n",
    "    (glob's own ordering depends on the file system).\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: the directory contains no `*.json` files.\n",
    "        ValueError: the files exist but hold no records.\n",
    "    Raising avoids silently returning NaN (NumPy's mean of an empty array).\n",
    "    \"\"\"\n",
    "    result_dir = f'../results/{dataset}/{model}_{seed}/{metric}_{n_sample}'\n",
    "\n",
    "    result_fs = sorted(glob.glob(f'{result_dir}/*.json'))\n",
    "    if not result_fs:\n",
    "        raise FileNotFoundError(f'No result files found in {result_dir}')\n",
    "\n",
    "    results = []\n",
    "    for file in result_fs:\n",
    "        with open(file, 'r') as f:\n",
    "            results.extend(json.load(f))\n",
    "\n",
    "    if not results:\n",
    "        raise ValueError(f'Result files in {result_dir} contain no records')\n",
    "\n",
    "    maes = [abs(result['metric'] - result['error']) for result in results]\n",
    "    return np.array(maes).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEFAULT_N_SAMPLES = (10000, 5000, 2000, 1000, 500, 200, 100)\n",
    "\n",
    "\n",
    "def get_sens_results(metric, dataset, arch, n_samples=DEFAULT_N_SAMPLES, seed='1'):\n",
    "    \"\"\"Return one MAE record per sample size for `metric` on `dataset` / `arch`.\n",
    "\n",
    "    Each record is a dict with keys 'metric', 'dataset', 'n_sample', 'arch'\n",
    "    and 'mae' (computed by `get_mae`), in the order given by `n_samples`.\n",
    "    `seed` is passed through unchanged and only selects the results folder.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        {\n",
    "            'metric': metric,\n",
    "            'dataset': dataset,\n",
    "            'n_sample': n,\n",
    "            'arch': arch,\n",
    "            'mae': get_mae(dataset, arch, metric, n, seed),\n",
    "        }\n",
    "        for n in n_samples\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect MAE for every (metric, dataset, sample size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = ['EMD', 'ATC', 'REMD_1.25']\n",
    "datasets = ['CIFAR-10', 'CIFAR-100', 'Tiny-ImageNet']\n",
    "sens = []\n",
    "for metric in metrics:\n",
    "    for dataset in datasets:\n",
    "        sens += get_sens_results(metric, dataset, 'resnet50')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MAE vs. number of samples\n",
    "\n",
    "One column per dataset, one line per metric; the x axis is log-scaled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-4b59fbffcc37474daa32cb84404743a9\"></div>\n",
       "<script type=\"text/javascript\">\n",
       " var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? 
{} : VEGA_DEBUG;\n", " (function(spec, embedOpt){\n", " let outputDiv = document.currentScript.previousElementSibling;\n", " if (outputDiv.id !== \"altair-viz-4b59fbffcc37474daa32cb84404743a9\") {\n", " outputDiv = document.getElementById(\"altair-viz-4b59fbffcc37474daa32cb84404743a9\");\n", " }\n", " const paths = {\n", " \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n", " \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n", " \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n", " \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n", " };\n", "\n", " function maybeLoadScript(lib, version) {\n", " var key = `${lib.replace(\"-\", \"\")}_version`;\n", " return (VEGA_DEBUG[key] == version) ?\n", " Promise.resolve(paths[lib]) :\n", " new Promise(function(resolve, reject) {\n", " var s = document.createElement('script');\n", " document.getElementsByTagName(\"head\")[0].appendChild(s);\n", " s.async = true;\n", " s.onload = () => {\n", " VEGA_DEBUG[key] = version;\n", " return resolve(paths[lib]);\n", " };\n", " s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n", " s.src = paths[lib];\n", " });\n", " }\n", "\n", " function showError(err) {\n", " outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n", " throw err;\n", " }\n", "\n", " function displayChart(vegaEmbed) {\n", " vegaEmbed(outputDiv, spec, embedOpt)\n", " .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. 
See the javascript console for the full traceback.`));\n", " }\n", "\n", " if(typeof define === \"function\" && define.amd) {\n", " requirejs.config({paths});\n", " require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n", " } else {\n", " maybeLoadScript(\"vega\", \"5\")\n", " .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n", " .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n", " .catch(showError)\n", " .then(() => displayChart(vegaEmbed));\n", " }\n", " })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}, \"axis\": {\"labelFontSize\": 16, \"titleFontSize\": 18}, \"legend\": {\"labelFontSize\": 18, \"titleFontSize\": 20}}, \"data\": {\"values\": [{\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.014469379930368452}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.015004043102478862}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.016281259801561865}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.020989433447233103}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04018579669484187}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04615530701310061}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045474082633063746}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.01992340799863755}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.020585536769618707}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.024342450141671763}, 
{\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03338057958699007}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.0534508796307282}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.10063707337972706}, {\"metric\": \"EMD\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.19019687173354596}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.032434747111411456}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.029363516945500913}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.02454799185399754}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03127610869061543}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.04364479950047747}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.09441336930069869}, {\"metric\": \"EMD\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.20636375601731388}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0316802075425163}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.03201457563489675}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.031520850390195844}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.028937513527770835}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 
0.03106251671910286}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.03927082135031621}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.045416658793886504}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03524166695301732}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.03512082161357005}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.034406267752870916}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03490626962669194}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.03997918866078059}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.043802075817560154}, {\"metric\": \"ATC\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.047083325767889615}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.03485526348164207}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.034565782804316596}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.03525658435531352}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.03613158635089272}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.042947374909331926}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.04743420748922386}, {\"metric\": \"ATC\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, 
\"arch\": \"resnet50\", \"mae\": 0.05434210061634842}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.04440492593811391}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.04269366873211624}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.042810486328272036}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.04571517167147061}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.051640504498346416}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.05429068888781783}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-10\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.05624488137029399}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 0.0934960299973343}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.09297857507000007}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.08642741723291288}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.0810600697000193}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.07525555161545947}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08454419112086059}, {\"metric\": \"REMD_1.25\", \"dataset\": \"CIFAR-100\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12893080543225685}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 10000, \"arch\": \"resnet50\", \"mae\": 
0.14702108353799165}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 5000, \"arch\": \"resnet50\", \"mae\": 0.14431254023133747}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 2000, \"arch\": \"resnet50\", \"mae\": 0.1329430498002896}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 1000, \"arch\": \"resnet50\", \"mae\": 0.12127065566562693}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 500, \"arch\": \"resnet50\", \"mae\": 0.106019251174766}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 200, \"arch\": \"resnet50\", \"mae\": 0.08902080514337413}, {\"metric\": \"REMD_1.25\", \"dataset\": \"Tiny-ImageNet\", \"n_sample\": 100, \"arch\": \"resnet50\", \"mae\": 0.12326658876760059}]}, \"facet\": {\"column\": {\"field\": \"dataset\", \"header\": {\"labelFontSize\": 18}, \"type\": \"nominal\"}}, \"spec\": {\"mark\": {\"type\": \"line\", \"point\": {\"fill\": \"white\", \"filled\": false, \"size\": 50}}, \"encoding\": {\"color\": {\"field\": \"metric\", \"type\": \"nominal\"}, \"x\": {\"axis\": {\"values\": [100, 1000, 10000]}, \"field\": \"n_sample\", \"scale\": {\"reverse\": false, \"type\": \"log\"}, \"title\": \"Number of Samples\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"mae\", \"scale\": {\"domain\": [0, 0.22]}, \"title\": \"Mean Absolute Error\", \"type\": \"quantitative\"}}, \"height\": 150, \"width\": 150}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\"}, {\"mode\": \"vega-lite\"});\n", "</script>" ], "text/plain": [ "alt.FacetChart(...)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.Chart(alt.Data(values=sens)).mark_line(\n", " point=alt.OverlayMarkDef(size=50, filled=False, fill='white')\n", ").encode(\n", " x=alt.X('n_sample:Q', title='Number of Samples', scale=alt.Scale(type=\"log\", reverse=False), axis=alt.Axis(values=[100, 1000, 
10000])),\n",
    "    y=alt.Y('mae:Q', scale=alt.Scale(domain=[0, 0.22]), title='Mean Absolute Error'),  # NOTE: fixed y-range; plotted MAEs currently peak at ~0.206, widen if new results exceed 0.22\n",
    "    color=alt.Color('metric:N')\n",
    ").properties(\n",
    "    width=150,\n",
    "    height=150\n",
    ").facet(\n",
    "    column=alt.Column('dataset:N', header=alt.Header(labelFontSize=18))\n",
    ").configure_axis(\n",
    "    labelFontSize=16,\n",
    "    titleFontSize=18,\n",
    ").configure_legend(\n",
    "    titleFontSize=20,\n",
    "    labelFontSize=18\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3f1da77dfecb8ea08abe23a06d7ff88ed6156499ac883f4692cab9e51a9d5766"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}