{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Proxy measures vs. accuracy (nasbench1)\n",
    "\n",
    "Loads the pickled proxy records under `../results_release/nasbench1/proxies`, looks up each record's accuracy in `../data/nasbench1_accuracy.p` by its `hash`, and prints the absolute Spearman rank correlation between each reported proxy measure and accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from prettytable import PrettyTable\n",
    "from scipy import stats\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n"
     ]
    }
   ],
   "source": [
    "d = '../results_release/nasbench1/proxies'\n",
    "runs = []\n",
    "processed = set()\n",
    "\n",
    "# NOTE(security): pickle.load can execute arbitrary code contained in the file.\n",
    "# Only point this notebook at result files you produced or otherwise trust.\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sorting makes the choice of\n",
    "# which record survives for a duplicated hash reproducible between runs.\n",
    "for f in tqdm(sorted(os.listdir(d))):\n",
    "    path = os.path.join(d, f)\n",
    "    if not os.path.isfile(path):\n",
    "        # Skip sub-directories instead of failing inside open().\n",
    "        continue\n",
    "    # `with` closes the file even if unpickling fails part-way through.\n",
    "    with open(path, 'rb') as pf:\n",
    "        # A file may hold several records pickled back to back, so keep\n",
    "        # loading until the stream is exhausted.\n",
    "        while True:\n",
    "            try:\n",
    "                p = pickle.load(pf)\n",
    "            except EOFError:\n",
    "                break\n",
    "            # Keep only the first record seen for each hash.\n",
    "            if p['hash'] in processed:\n",
    "                continue\n",
    "            processed.add(p['hash'])\n",
    "            runs.append(p)\n",
    "\n",
    "with open('../data/nasbench1_accuracy.p', 'rb') as f:\n",
    "    all_accur = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "423624 423624\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: number of unique records loaded vs. entries in the accuracy table.\n",
    "print(len(runs),len(all_accur))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results_release/nasbench1/proxies 423624\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n",
      "| Dataset | grad_norm | snip | grasp | fisher | synflow | jacob_cov |\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n",
      "| CIFAR10 | 0.198 | 0.164 | 0.448 | 0.257 | 0.372 | 0.378 |\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n"
     ]
    }
   ],
   "source": [
    "# Table columns: a row label followed by the proxy measures to report.\n",
    "hl = ['Dataset', 'grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov']\n",
    "t = PrettyTable(hl)\n",
    "\n",
    "print(d, len(runs))\n",
    "if not runs:\n",
    "    raise ValueError(f'no proxy records were loaded from {d}')\n",
    "\n",
    "# The first record defines which measures are collected.\n",
    "measure_names = list(runs[0]['logmeasures'].keys())\n",
    "missing = [k for k in hl[1:] if k not in measure_names]\n",
    "if missing:\n",
    "    raise KeyError(f'measures not present in the loaded records: {missing}')\n",
    "\n",
    "# One list per measure, each kept the same length as `acc` so that entry i of\n",
    "# every list belongs to runs[i]. A measure absent from a record is stored as\n",
    "# NaN (dropped below by nan_policy='omit') instead of shifting the list.\n",
    "metrics = {k: [] for k in measure_names}\n",
    "acc = []\n",
    "hashes = []\n",
    "\n",
    "for r in runs:\n",
    "    measures = r['logmeasures']\n",
    "    for k in measure_names:\n",
    "        metrics[k].append(measures.get(k, np.nan))\n",
    "\n",
    "    # NOTE(review): element 0 of the accuracy entry is what gets correlated;\n",
    "    # confirm which accuracy value that position holds.\n",
    "    acc.append(all_accur[r['hash']][0])\n",
    "    hashes.append(r['hash'])\n",
    "\n",
    "res = []\n",
    "for k in hl[1:]:\n",
    "    # Only the magnitude of the rank correlation is reported, not its sign.\n",
    "    cr = abs(stats.spearmanr(acc, metrics[k], nan_policy='omit').correlation)\n",
    "    res.append(round(cr, 3))\n",
    "\n",
    "ds = 'CIFAR10'\n",
    "t.add_row([ds] + res)\n",
    "\n",
    "print(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}