FOT-OOD / notebooks / visualize_correlation_results.ipynb
visualize_correlation_results.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "c542d51c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "1592981c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_results(dataset, model, pretrained, ckpt, metric, n_sample, seed):\n",
    "    \"\"\"Load and concatenate every JSON result file for one experiment configuration.\n",
    "\n",
    "    Files are read from\n",
    "    ``../results/{dataset}/{pretrained|scratch}/{model}_{seed}-{ckpt}/{metric}_{n_sample}/*.json``;\n",
    "    each file's path and the ``len()`` of its decoded content are printed as it is read.\n",
    "\n",
    "    Returns:\n",
    "        list: the decoded contents of all matched files concatenated with ``list.extend``\n",
    "        (presumably each file holds a JSON list of result records -- TODO confirm), in\n",
    "        sorted file-path order so reruns are deterministic. Empty if nothing matches.\n",
    "    \"\"\"\n",
    "    training = 'pretrained' if pretrained else 'scratch'\n",
    "    result_dir = f'../results/{dataset}/{training}/{model}_{seed}-{ckpt}/{metric}_{n_sample}'\n",
    "\n",
    "    results = []\n",
    "    # glob order depends on the filesystem; sort so results and plots are reproducible\n",
    "    for file in sorted(glob.glob(f'{result_dir}/*.json')):\n",
    "        with open(file, 'r') as f:\n",
    "            data = json.load(f)\n",
    "        print(file, len(data))\n",
    "        results.extend(data)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "54321e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset settings consumed by the parameter cell that calls load_results.\n",
    "# n_sample_dict: passed to load_results as n_sample, where it becomes the `_{n_sample}`\n",
    "# suffix of the results sub-directory (-1 presumably means 'all samples' -- TODO confirm).\n",
    "n_sample_dict = {\n",
    "    'CIFAR-10': -1, \n",
    "    'CIFAR-100': -1, \n",
    "    'Living-17': -1,\n",
    "    'Nonliving-26': -1,\n",
    "    'Entity-13': -1,\n",
    "    'Entity-30': -1,\n",
    "    'ImageNet': -1,\n",
    "    'RxRx1': -1,\n",
    "    'FMoW': -1,\n",
    "    'Amazon': -1,\n",
    "    'CivilComments': -1\n",
    "}\n",
    "\n",
    "# n_epoch_dict: used as the model checkpoint id (the `-{ckpt}` part of the\n",
    "# `{model}_{seed}-{ckpt}` results directory); presumably the number of training epochs.\n",
    "n_epoch_dict = {\n",
    "    'CIFAR-10': 300, \n",
    "    'CIFAR-100': 300, \n",
    "    'Living-17': 450,\n",
    "    'Nonliving-26': 450,\n",
    "    'Entity-13': 300,\n",
    "    'Entity-30': 300,\n",
    "    'ImageNet': 10,\n",
    "    'FMoW': 50,\n",
    "    'RxRx1': 90,\n",
    "    'Amazon': 3,\n",
    "    'CivilComments': 5\n",
    "}\n",
    "\n",
    "# pretrained_dict: True selects the `pretrained/` results sub-directory, False `scratch/`.\n",
    "pretrained_dict = {\n",
    "    'CIFAR-10': False, \n",
    "    'CIFAR-100': False, \n",
    "    'Living-17': False,\n",
    "    'Nonliving-26': False,\n",
    "    'Entity-13': False,\n",
    "    'Entity-30': False,\n",
    "    'ImageNet': True,\n",
    "    'FMoW': True,\n",
    "    'RxRx1': True,\n",
    "    'Amazon': True,\n",
    "    'CivilComments': True\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "id": "13c79cfb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results/Amazon/pretrained/distilbert-base-uncased_1-3/COTT-MC_-1/group-1.json 1\n",
      "../results/Amazon/pretrained/distilbert-base-uncased_1-3/COTT-MC_-1/group-2.json 1\n"
     ]
    }
   ],
   "source": [
    "# Which experiment to visualize\n",
    "dataset = 'Amazon'\n",
    "metric = 'COTT-MC'\n",
    "arch = 'distilbert-base-uncased'  # alternatively 'resnet50'\n",
    "seed = '1'\n",
    "\n",
    "# Per-dataset settings looked up from the configuration dicts\n",
    "n_sample, model_ckpt, pretrained = (\n",
    "    n_sample_dict[dataset],\n",
    "    n_epoch_dict[dataset],\n",
    "    pretrained_dict[dataset],\n",
    ")\n",
    "\n",
    "results = load_results(dataset, arch, pretrained, model_ckpt, metric, n_sample, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "id": "a1c8634e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 190,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of records returned by load_results across all matched files\n",
    "len(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "id": "edb585d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import altair as alt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "ef90531d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_corr_chart(data, subpop, metric_name=None):\n",
    "    \"\"\"Scatter chart of metric value vs. test error for one subpopulation.\n",
    "\n",
    "    Args:\n",
    "        data: list of result records (dicts) providing 'metric', 'error',\n",
    "            'corruption' and 'corruption level' fields.\n",
    "        subpop: chart title (the subpopulation name).\n",
    "        metric_name: x-axis title; defaults to the notebook-level ``metric``\n",
    "            variable so existing calls render exactly as before.\n",
    "\n",
    "    Returns:\n",
    "        alt.Chart: 200x200 point chart colored by corruption, shaped by corruption level.\n",
    "    \"\"\"\n",
    "    if metric_name is None:\n",
    "        metric_name = metric  # fall back to the global chosen in the parameter cell\n",
    "    corr = alt.Chart(alt.Data(values=data), title=subpop).mark_point(size=50).encode(\n",
    "        x=alt.X('metric:Q', title=metric_name),\n",
    "        # the y channel takes alt.Y (alt.X only rendered because of lenient conversion)\n",
    "        y=alt.Y('error:Q', title='Test Error'),\n",
    "        color=alt.Color('corruption:N', scale=alt.Scale(scheme='category20')),\n",
    "        shape='corruption level:O'\n",
    "    ).properties(\n",
    "        width=200,\n",
    "        height=200\n",
    "    )\n",
    "    return corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "id": "86df21ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the loaded records by their 'subpopulation' field\n",
    "same_pop_results, natural_pop_results, novel_pop_results = (\n",
    "    [rec for rec in results if rec['subpopulation'] == subpop]\n",
    "    for subpop in ('same', 'natural', 'novel')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "id": "fd94e84f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dashed reference diagonal from (0, 0) to (1, 1): points on it have metric == error\n",
    "line = pd.DataFrame({'metric': [0, 1], 'error': [0, 1]})\n",
    "line_plot = (\n",
    "    alt.Chart(line)\n",
    "    .mark_line(color='black', strokeDash=[5, 8])\n",
    "    .encode(x='metric', y='error')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "id": "a2681d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One scatter panel per subpopulation, each overlaid with the reference diagonal\n",
    "same_corr, natural_corr, novel_corr = (\n",
    "    get_corr_chart(pop_results, name)\n",
    "    for pop_results, name in [\n",
    "        (same_pop_results, 'same'),\n",
    "        (natural_pop_results, 'natural'),\n",
    "        (novel_pop_results, 'novel'),\n",
    "    ]\n",
    ")\n",
    "same_plt, natural_plt, novel_plt = (\n",
    "    panel + line_plot for panel in (same_corr, natural_corr, novel_corr)\n",
    ")\n",
    "\n",
    "# NOTE: `plt` here is the combined Altair chart, not matplotlib.pyplot\n",
    "plt = same_plt | natural_plt | novel_plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "id": "3c3d3a92",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/workspace/lu35/anaconda3/envs/ood/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n",
      "  for col_name, dtype in df.dtypes.iteritems():\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-f38cb6baa1fe499dbce882e840642607\"></div>\n",
       "<script type=\"text/javascript\">\n",
       "  var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
       "  (function(spec, embedOpt){\n",
       "    let outputDiv = document.currentScript.previousElementSibling;\n",
       "    if (outputDiv.id !== \"altair-viz-f38cb6baa1fe499dbce882e840642607\") {\n",
       "      outputDiv = document.getElementById(\"altair-viz-f38cb6baa1fe499dbce882e840642607\");\n",
       "    }\n",
       "    const paths = {\n",
       "      \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
       "      \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
       "      \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
       "      \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
       "    };\n",
       "\n",
       "    function maybeLoadScript(lib, version) {\n",
       "      var key = `${lib.replace(\"-\", \"\")}_version`;\n",
       "      return (VEGA_DEBUG[key] == version) ?\n",
       "        Promise.resolve(paths[lib]) :\n",
       "        new Promise(function(resolve, reject) {\n",
       "          var s = document.createElement('script');\n",
       "          document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "          s.async = true;\n",
       "          s.onload = () => {\n",
       "            VEGA_DEBUG[key] = version;\n",
       "            return resolve(paths[lib]);\n",
       "          };\n",
       "          s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
       "          s.src = paths[lib];\n",
       "        });\n",
       "    }\n",
       "\n",
       "    function showError(err) {\n",
       "      outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
       "      throw err;\n",
       "    }\n",
       "\n",
       "    function displayChart(vegaEmbed) {\n",
       "      vegaEmbed(outputDiv, spec, embedOpt)\n",
       "        .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
       "    }\n",
       "\n",
       "    if(typeof define === \"function\" && define.amd) {\n",
       "      requirejs.config({paths});\n",
       "      require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
       "    } else {\n",
       "      maybeLoadScript(\"vega\", \"5\")\n",
       "        .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
       "        .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
       "        .catch(showError)\n",
       "        .then(() => displayChart(vegaEmbed));\n",
       "    }\n",
       "  })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}, \"axis\": {\"labelFontSize\": 14, \"titleFontSize\": 16}, \"legend\": {\"labelFontSize\": 16, \"titleFontSize\": 14}}, \"hconcat\": [{\"layer\": [{\"data\": {\"values\": []}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"same\", \"width\": 200}, {\"data\": {\"name\": \"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}, {\"layer\": [{\"data\": {\"values\": [{\"corruption\": \"group-1\", \"corruption level\": 0, \"metric\": 0.26476, \"ref\": \"COTT-MC\", \"acc\": 0.7209794521331787, \"error\": 0.2790205478668213, \"subpopulation\": \"natural\", \"pretrained\": true}, {\"corruption\": \"group-2\", \"corruption level\": 0, \"metric\": 0.26741000000000004, \"ref\": \"COTT-MC\", \"acc\": 0.7141029238700867, \"error\": 0.28589707612991333, \"subpopulation\": \"natural\", \"pretrained\": true}]}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"natural\", \"width\": 200}, {\"data\": {\"name\": 
\"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}, {\"layer\": [{\"data\": {\"values\": []}, \"mark\": {\"type\": \"point\", \"size\": 50}, \"encoding\": {\"color\": {\"field\": \"corruption\", \"scale\": {\"scheme\": \"category20\"}, \"type\": \"nominal\"}, \"shape\": {\"field\": \"corruption level\", \"type\": \"ordinal\"}, \"x\": {\"field\": \"metric\", \"title\": \"COTT-MC\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"title\": \"Test Error\", \"type\": \"quantitative\"}}, \"height\": 200, \"title\": \"novel\", \"width\": 200}, {\"data\": {\"name\": \"data-371cce5920a7f84c9a50d4162766dae5\"}, \"mark\": {\"type\": \"line\", \"color\": \"black\", \"strokeDash\": [5, 8]}, \"encoding\": {\"x\": {\"field\": \"metric\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"error\", \"type\": \"quantitative\"}}}]}], \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-371cce5920a7f84c9a50d4162766dae5\": [{\"metric\": 0, \"error\": 0}, {\"metric\": 1, \"error\": 1}]}}, {\"mode\": \"vega-lite\"});\n",
       "</script>"
      ],
      "text/plain": [
       "alt.HConcatChart(...)"
      ]
     },
     "execution_count": 196,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Font sizes for the axes and legend of the combined chart (applied once)\n",
    "plt.configure_axis(\n",
    "    labelFontSize=14,\n",
    "    titleFontSize=16\n",
    ").configure_legend(\n",
    "    titleFontSize=14,\n",
    "    labelFontSize=16\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "id": "f6c96c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def polyfit(x, y, degree=1):\n",
    "    \"\"\"Least-squares polynomial fit of y on x, plus its coefficient of determination.\n",
    "\n",
    "    Args:\n",
    "        x, y: equal-length 1-D sequences of numbers.\n",
    "        degree: polynomial degree passed to ``np.polyfit`` (default 1, a straight line).\n",
    "\n",
    "    Returns:\n",
    "        dict with\n",
    "          'polynomial': coefficients, highest power first (``np.polyfit`` order),\n",
    "                        so for degree 1 this is ``[slope, intercept]``;\n",
    "          'determination': R^2 = SS_reg / SS_tot, or NaN when y is constant\n",
    "                           (SS_tot == 0) and R^2 is undefined.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    results = {}\n",
    "\n",
    "    coeffs = np.polyfit(x, y, degree)\n",
    "    results['polynomial'] = coeffs.tolist()\n",
    "\n",
    "    yhat = np.poly1d(coeffs)(x)\n",
    "    ybar = y.mean()\n",
    "    ssreg = np.sum((yhat - ybar) ** 2)\n",
    "    sstot = np.sum((y - ybar) ** 2)\n",
    "    # constant y: avoid a 0/0 RuntimeWarning (or a spurious inf from round-off)\n",
    "    results['determination'] = float('nan') if sstot == 0 else ssreg / sstot\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "id": "73c59c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "id": "08834324",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_corr_stats(results, metric_name=None):\n",
    "    \"\"\"Print how well the metric tracks test error over a list of result records.\n",
    "\n",
    "    With two or more records this prints R^2, Spearman's rho, and the slope/bias of a\n",
    "    linear fit of error on metric. For any non-empty list it also prints the mean\n",
    "    absolute difference between metric and error, multiplied by 100 (percentage\n",
    "    points if both are fractions). Does nothing for an empty list; returns None.\n",
    "\n",
    "    Args:\n",
    "        results: list of dicts with numeric 'metric' and 'error' entries.\n",
    "        metric_name: label used in the printed lines; defaults to the notebook-level\n",
    "            ``metric`` variable so existing calls print exactly as before.\n",
    "    \"\"\"\n",
    "    if len(results) == 0:\n",
    "        return\n",
    "    if metric_name is None:\n",
    "        metric_name = metric\n",
    "\n",
    "    metric_vals = np.array([r['metric'] for r in results])\n",
    "    errors = np.array([r['error'] for r in results])\n",
    "\n",
    "    if len(results) > 1:\n",
    "        # a single linear fit supplies r2, slope and bias\n",
    "        fit = polyfit(metric_vals, errors)\n",
    "        slope, bias = fit['polynomial']\n",
    "        rho, _ = spearmanr(metric_vals, errors)\n",
    "\n",
    "        print(f'{metric_name} r2:', fit['determination'])\n",
    "        print(f'{metric_name} rho:', rho)\n",
    "        print(f'{metric_name} slope:', slope)\n",
    "        print(f'{metric_name} bias:', bias)\n",
    "\n",
    "    # the metric value is used directly as the predicted error\n",
    "    print('MAE:', np.abs(metric_vals - errors).mean() * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "id": "23828d93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- same pop results -----\n"
     ]
    }
   ],
   "source": [
    "print('-' * 5, 'same pop results', '-' * 5)\n",
    "compute_corr_stats(same_pop_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "id": "bf70a29c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- natural pop results -----\n",
      "COTT-MC r2: 1.0\n",
      "COTT-MC rho: 0.9999999999999999\n",
      "COTT-MC slope: 2.5949163256950816\n",
      "COTT-MC bias: -0.40800949852420865\n",
      "MAE: 1.6373811998367294\n"
     ]
    }
   ],
   "source": [
    "print('-' * 5, 'natural pop results', '-' * 5)\n",
    "compute_corr_stats(natural_pop_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "id": "05c3ac30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- novel pop results -----\n"
     ]
    }
   ],
   "source": [
    "print('-' * 5, 'novel pop results', '-' * 5)\n",
    "compute_corr_stats(novel_pop_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "id": "1df2cb59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----- novel pop results -----\n"
     ]
    }
   ],
   "source": [
    "# Same stats restricted to uncorrupted (corruption level 0) novel-subpopulation records\n",
    "print('----- novel pop results (corruption level 0 only) -----')\n",
    "compute_corr_stats([rec for rec in novel_pop_results if rec['corruption level'] == 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4599063",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ood",
   "language": "python",
   "name": "ood"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}