upload
This commit is contained in:
		
							
								
								
									
										5
									
								
								zero-cost-nas/notebooks/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								zero-cost-nas/notebooks/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| ## Notebooks | ||||
|  | ||||
| To run these notebooks, you need to compute the zero-cost metrics for each dataset.     | ||||
| Alternatively, you can download precomputed results from [here](https://drive.google.com/drive/folders/1fUBaTd05OHrKIRs-x9Fx8Zsk5QqErks8?usp=sharing) and save to the root folder of this repo.     | ||||
| You will also need the [`data` directory](https://drive.google.com/drive/folders/18Eia6YuTE5tn5Lis_43h30HYpnF9Ynqf?usp=sharing), which should also be saved to the root folder of the repo. | ||||
							
								
								
									
										387
									
								
								zero-cost-nas/notebooks/nas_examples.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										387
									
								
								zero-cost-nas/notebooks/nas_examples.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										153
									
								
								zero-cost-nas/notebooks/nasbench101_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								zero-cost-nas/notebooks/nasbench101_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import os, pickle, sys\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "from scipy import stats\n", | ||||
|     "import numpy as np\n", | ||||
|     "import glob\n", | ||||
|     "from tqdm import tqdm\n", | ||||
|     "from prettytable import PrettyTable" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "d = '../results_release/nasbench1/proxies'\n", | ||||
|     "runs = []\n", | ||||
|     "processed = set()\n", | ||||
|     "\n", | ||||
|     "for f in tqdm(os.listdir(d)):\n", | ||||
|     "    pf = open(os.path.join(d,f),'rb')\n", | ||||
|     "    while 1:\n", | ||||
|     "        try:\n", | ||||
|     "            p = pickle.load(pf)\n", | ||||
|     "            if p['hash'] in processed:\n", | ||||
|     "                continue\n", | ||||
|     "            processed.add(p['hash'])\n", | ||||
|     "            runs.append(p)\n", | ||||
|     "        except EOFError:\n", | ||||
|     "            break\n", | ||||
|     "    pf.close()\n", | ||||
|     "with open('../data/nasbench1_accuracy.p','rb') as f:\n", | ||||
|     "    all_accur = pickle.load(f)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "423624 423624\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(len(runs),len(all_accur))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "../results_release/nasbench1/proxies 423624\n", | ||||
|       "+---------+-----------+-------+-------+--------+---------+-----------+\n", | ||||
|       "| Dataset | grad_norm |  snip | grasp | fisher | synflow | jacob_cov |\n", | ||||
|       "+---------+-----------+-------+-------+--------+---------+-----------+\n", | ||||
|       "| CIFAR10 |   0.198   | 0.164 | 0.448 | 0.257  |  0.372  |   0.378   |\n", | ||||
|       "+---------+-----------+-------+-------+--------+---------+-----------+\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "t=None\n", | ||||
|     "\n", | ||||
|     "print(d, len(runs))\n", | ||||
|     "metrics={}\n", | ||||
|     "for k in runs[0]['logmeasures'].keys():\n", | ||||
|     "    metrics[k] = []\n", | ||||
|     "acc = []\n", | ||||
|     "hashes = []\n", | ||||
|     "\n", | ||||
|     "if t is None:\n", | ||||
|     "    hl=['Dataset']\n", | ||||
|     "    hl.extend(['grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov'])\n", | ||||
|     "    t = PrettyTable(hl)\n", | ||||
|     "\n", | ||||
|     "for r in runs:\n", | ||||
|     "    for k,v in r['logmeasures'].items():\n", | ||||
|     "        metrics[k].append(v)\n", | ||||
|     "    \n", | ||||
|     "    acc.append(all_accur[r['hash']][0])\n", | ||||
|     "    hashes.append(r['hash'])\n", | ||||
|     "\n", | ||||
|     "res = []\n", | ||||
|     "for k in hl:\n", | ||||
|     "    if k=='Dataset':\n", | ||||
|     "        continue\n", | ||||
|     "    v = metrics[k]\n", | ||||
|     "    cr = abs(stats.spearmanr(acc,v,nan_policy='omit').correlation)\n", | ||||
|     "    #print(f'{k} = {cr}')\n", | ||||
|     "    res.append(round(cr,3))\n", | ||||
|     "\n", | ||||
|     "ds = 'CIFAR10'\n", | ||||
|     "t.add_row([ds]+res)\n", | ||||
|     "\n", | ||||
|     "print(t)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.7.6" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
							
								
								
									
										749
									
								
								zero-cost-nas/notebooks/nasbench201_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										749
									
								
								zero-cost-nas/notebooks/nasbench201_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										362
									
								
								zero-cost-nas/notebooks/ptcv_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										362
									
								
								zero-cost-nas/notebooks/ptcv_correlations.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user