    "import pickle\n",
    "import torch\n",
    "from collections import Counter"
    "def get_conf_scores(emb_dir):\n",
    "    \"\"\"Load cached model outputs and report confidence statistics.\n",
    "\n",
    "    Despite its name, ``emb_dir`` is the path of a single pickle file holding\n",
    "    a dict with keys 'act', 'pred' and 'tar'. Prints the mean max-softmax\n",
    "    confidence, the accuracy of ``pred`` against ``tar``, and the fraction of\n",
    "    predictions assigned to each of the ``act.shape[1]`` classes.\n",
    "\n",
    "    Returns:\n",
    "        list[float]: per-sample maximum softmax probability.\n",
    "    \"\"\"\n",
    "    # NOTE: pickle can execute arbitrary code; only load cache files you produced.\n",
    "    with open(emb_dir, 'rb') as f:  # close the handle instead of leaking it\n",
    "        res = pickle.load(f)\n",
    "    act, pred, tar = res['act'], res['pred'], res['tar']\n",
    "    # act is presumably (N, num_classes) logits -- TODO confirm; softmax over dim 1\n",
    "    sm = torch.nn.functional.softmax(act, dim=1)\n",
    "    conf = sm.max(1)[0]  # highest class probability per sample, computed once\n",
    "    n = len(pred)\n",
    "    acc = (pred == tar).sum() / n\n",
    "    print('averaged conf:', conf.mean().item())\n",
    "    print('acc:', acc.item())\n",
    "    pred_counts = Counter(pred.tolist())\n",
    "    # normalise by the same sample count used for the accuracy\n",
    "    print('label dist:', [pred_counts[i] / n for i in range(act.shape[1])])\n",
    "    print()\n",
    "    return conf.tolist()"
    "# Cache files for the three evaluation splits; all live in the same run directory.\n",
    "iid_emb, novel_emb, corr_emb = (\n",
    "    f'../cache/{ds}/resnet50_1-{epoch}/scratch/id_m1-{epoch}_d1.pkl',\n",
    "    f'../cache/{ds}/resnet50_1-{epoch}/scratch/od_pnovel_m1-{epoch}_cclean-0_n{n_test_sample}.pkl',\n",
    "    f'../cache/{ds}/resnet50_1-{epoch}/scratch/od_psame_m1-{epoch}_cfrost-5_n{n_test_sample}.pkl',\n",
    ")"
      "averaged conf: 0.7038018107414246\n",
      "acc: 0.6948999762535095\n",
      "label dist: [0.0361, 0.041, 0.0383, 0.0329, 0.0369, 0.0335, 0.0382, 0.0317, 0.051, 0.0409, 0.0375, 0.0463, 0.0371, 0.0366, 0.0431, 0.043, 0.0284, 0.0373, 0.0354, 0.0398, 0.0422, 0.0392, 0.0292, 0.0365, 0.0371, 0.0508]\n",
      "averaged conf: 0.5417004227638245\n",
      "acc: 0.3407692313194275\n",
      "label dist: [0.03230769230769231, 0.038461538461538464, 0.05, 0.023846153846153847, 0.02576923076923077, 0.03769230769230769, 0.04730769230769231, 0.022692307692307692, 0.03576923076923077, 0.03923076923076923, 0.028076923076923076, 0.05423076923076923, 0.035, 0.036153846153846154, 0.03923076923076923, 0.04153846153846154, 0.03769230769230769, 0.04692307692307692, 0.03769230769230769, 0.041923076923076924, 0.04423076923076923, 0.038461538461538464, 0.028846153846153848, 0.03807692307692308, 0.03807692307692308, 0.06076923076923077]\n",
      "averaged conf: 0.4134865999221802\n",
      "acc: 0.21230769157409668\n",
      "label dist: [0.02346153846153846, 0.008846153846153846, 0.07923076923076923, 0.023846153846153847, 0.03576923076923077, 0.008461538461538461, 0.010384615384615384, 0.001153846153846154, 0.013076923076923076, 0.011538461538461539, 0.025, 0.18807692307692309, 0.07769230769230769, 0.005, 0.004230769230769231, 0.04230769230769231, 0.002307692307692308, 0.01730769230769231, 0.047692307692307694, 0.051153846153846154, 0.2803846153846154, 0.0015384615384615385, 0.018461538461538463, 0.01, 0.0, 0.013076923076923076]\n",
   "source": [
    "# Stats are printed in this order: iid, novel, corrupted.\n",
    "iid_sm, novel_sm, corr_sm = (\n",
    "    get_conf_scores(path) for path in (iid_emb, novel_emb, corr_emb)\n",
    ")"
    "import ot"
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "# alt.data_transformers.enable('default', max_rows=None)"
    "# Ten equal-width confidence bins over [0, 1]; numpy closes the last bin, so 1.0 is counted.\n",
    "iid_hist, corr_hist, novel_hist = (\n",
    "    np.histogram(scores, bins=[i/10 for i in range(11)])[0]\n",
    "    for scores in (iid_sm, corr_sm, novel_sm)\n",
    ")"
    "# Normalise each histogram so its bins sum to 1 (fraction of samples per bin).\n",
    "iid_dens, corr_dens, novel_dens = (\n",
    "    counts / counts.sum() for counts in (iid_hist, corr_hist, novel_hist)\n",
    ")"
    "def transform_data(dens, group):\n",
    "    \"\"\"Turn per-bin densities into a list of plotting records.\n",
    "\n",
    "    Element i of ``dens`` becomes {'percent': value, 'conf': centre,\n",
    "    'group': group}, where centre is the midpoint of the i-th 0.1-wide\n",
    "    bin (0.05, 0.15, ...).\n",
    "    \"\"\"\n",
    "    return [\n",
    "        dict(percent=share, conf=idx/10+0.05, group=group)\n",
    "        for idx, share in enumerate(dens)\n",
    "    ]"
    "# Long-format records (one per bin per group), iid first, then corr, then novel.\n",
    "dens = [\n",
    "    record\n",
    "    for group_dens, label in ((iid_dens, 'iid'), (corr_dens, 'corr'), (novel_dens, 'novel'))\n",
    "    for record in transform_data(group_dens, label)\n",
    "]"
      "/usr/workspace/lu35/anaconda3/envs/ood/lib/python3.9/site-packages/altair/utils/ FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n",
      "  for col_name, dtype in df.dtypes.iteritems():\n"
    "    x=alt.X('conf:Q', bin=True),\n",
    "    y=alt.Y('percent:Q'),\n",
    "    column='group:N',\n",
    "    color='group:N'\n",
    "    width=200,\n",
    "    height=200\n",
