750 lines
166 KiB
Plaintext
750 lines
166 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os, pickle, sys\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"from scipy import stats\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import glob\n",
|
|||
|
"from prettytable import PrettyTable\n",
|
|||
|
"from tqdm import tqdm"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Table 1: Spearman ρ of zero-cost proxies on NAS-Bench-201."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"../results_release/nasbench2/nb2_cf10_seed42_dlrandom_dlinfo1_initwnone_initbnone.p 15625\n",
|
|||
|
"../results_release/nasbench2/nb2_cf100_seed42_dlrandom_dlinfo1_initwnone_initbnone.p 15625\n",
|
|||
|
"../results_release/nasbench2/nb2_im120_seed42_dlrandom_dlinfo1_initwnone_initbnone.p 15625\n",
|
|||
|
"+----------------+-----------+-------+-------+--------+---------+-----------+\n",
|
|||
|
"| Dataset | grad_norm | snip | grasp | fisher | synflow | jacob_cov |\n",
|
|||
|
"+----------------+-----------+-------+-------+--------+---------+-----------+\n",
|
|||
|
"| CIFAR10 | 0.594 | 0.596 | 0.514 | 0.36 | 0.737 | 0.731 |\n",
|
|||
|
"| CIFAR100 | 0.637 | 0.637 | 0.547 | 0.385 | 0.763 | 0.704 |\n",
|
|||
|
"| ImageNet16-120 | 0.578 | 0.578 | 0.549 | 0.327 | 0.751 | 0.701 |\n",
|
|||
|
"+----------------+-----------+-------+-------+--------+---------+-----------+\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"t=None\n",
|
|||
|
"all_ds = {}\n",
|
|||
|
"all_acc = {}\n",
|
|||
|
"allc = {}\n",
|
|||
|
"all_metrics = {}\n",
|
|||
|
"all_runs = {}\n",
|
|||
|
"metric_names = ['grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov']\n",
|
|||
|
"for fname,rname in [('../results_release/nasbench2/nb2_cf10_seed42_dlrandom_dlinfo1_initwnone_initbnone.p','CIFAR10'),\n",
|
|||
|
" ('../results_release/nasbench2/nb2_cf100_seed42_dlrandom_dlinfo1_initwnone_initbnone.p','CIFAR100'),\n",
|
|||
|
" ('../results_release/nasbench2/nb2_im120_seed42_dlrandom_dlinfo1_initwnone_initbnone.p','ImageNet16-120')]:\n",
|
|||
|
" runs=[]\n",
|
|||
|
" f = open(fname,'rb')\n",
|
|||
|
" while(1):\n",
|
|||
|
" try:\n",
|
|||
|
" runs.append(pickle.load(f))\n",
|
|||
|
" except EOFError:\n",
|
|||
|
" break\n",
|
|||
|
" f.close()\n",
|
|||
|
" print(fname, len(runs))\n",
|
|||
|
" \n",
|
|||
|
" all_runs[fname]=runs\n",
|
|||
|
" all_ds[fname] = {}\n",
|
|||
|
" metrics={}\n",
|
|||
|
" for k in metric_names:\n",
|
|||
|
" metrics[k] = []\n",
|
|||
|
" acc = []\n",
|
|||
|
" \n",
|
|||
|
" if t is None:\n",
|
|||
|
" hl=['Dataset']\n",
|
|||
|
" hl.extend(metric_names)\n",
|
|||
|
" t = PrettyTable(hl)\n",
|
|||
|
" \n",
|
|||
|
" for r in runs:\n",
|
|||
|
" for k,v in r['logmeasures'].items():\n",
|
|||
|
" if k in metrics:\n",
|
|||
|
" metrics[k].append(v)\n",
|
|||
|
" acc.append(r['testacc'])\n",
|
|||
|
" \n",
|
|||
|
" all_ds[fname]['metrics'] = metrics\n",
|
|||
|
" all_ds[fname]['acc'] = acc\n",
|
|||
|
" \n",
|
|||
|
" res = []\n",
|
|||
|
" crs = {}\n",
|
|||
|
" for k in hl:\n",
|
|||
|
" if k=='Dataset':\n",
|
|||
|
" continue\n",
|
|||
|
" v = metrics[k]\n",
|
|||
|
" cr = abs(stats.spearmanr(acc,v,nan_policy='omit').correlation)\n",
|
|||
|
" #print(f'{k} = {cr}')\n",
|
|||
|
" res.append(round(cr,3))\n",
|
|||
|
" crs[k]=cr\n",
|
|||
|
" \n",
|
|||
|
" ds = rname\n",
|
|||
|
" all_acc[ds]=acc\n",
|
|||
|
" allc[ds]=crs\n",
|
|||
|
" t.add_row([ds]+res)\n",
|
|||
|
" \n",
|
|||
|
" all_metrics[ds] = metrics\n",
|
|||
|
"print(t)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Voting among 3 metrics (synflow, jacob_cov, snip) could improve rank correlation"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"100%|██████████| 15625/15625 [08:13<00:00, 31.65it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [08:17<00:00, 31.41it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [08:20<00:00, 31.20it/s]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"votes correlation: {'cifar10': 0.8170822831897641, 'cifar100': 0.8323757385510576, 'ImageNet16-120': 0.8131110314104887}\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from tqdm import tqdm\n",
|
|||
|
"votes = {}\n",
|
|||
|
"def vote(mets, gt):\n",
|
|||
|
" numpos = 0\n",
|
|||
|
" for m in mets:\n",
|
|||
|
" numpos += 1 if m > 0 else 0\n",
|
|||
|
" if numpos >= len(mets)/2:\n",
|
|||
|
" sign = +1\n",
|
|||
|
" else:\n",
|
|||
|
" sign = -1\n",
|
|||
|
" return sign*gt\n",
|
|||
|
"\n",
|
|||
|
"for ds in all_acc.keys():\n",
|
|||
|
" num_pts = 15625\n",
|
|||
|
" #num_pts = 1000\n",
|
|||
|
" tot=0\n",
|
|||
|
" right=0\n",
|
|||
|
" for i in tqdm(range(num_pts)):\n",
|
|||
|
" for j in range(num_pts):\n",
|
|||
|
" if i!=j:\n",
|
|||
|
" diff = all_acc[ds][i] - all_acc[ds][j]\n",
|
|||
|
" if diff == 0:\n",
|
|||
|
" continue\n",
|
|||
|
" diffsyn = []\n",
|
|||
|
" for m in ['synflow', 'jacob_cov', 'snip']:\n",
|
|||
|
" diffsyn.append(all_metrics[ds][m][i] - all_metrics[ds][m][j])\n",
|
|||
|
" same_sign = vote(diffsyn, diff)\n",
|
|||
|
" right += 1 if same_sign > 0 else 0\n",
|
|||
|
" tot += 1\n",
|
|||
|
" votes[ds.lower() if 'CIFAR' in ds else ds] = right/tot\n",
|
|||
|
"print('votes correlation: ', votes)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Figure 1: Evaluation of different EcoNAS proxies on NAS-Bench-201 CIFAR-10"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"try to create the NAS-Bench-201 api from ../data/NAS-Bench-201-v1_0-e61699.pth\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from nas_201_api import NASBench201API as API\n",
|
|||
|
"api = API('../data/NAS-Bench-201-v1_0-e61699.pth')\n",
|
|||
|
"api.verbose = False"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {
|
|||
|
"scrolled": true
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3694.77it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3677.47it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3666.86it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3668.30it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3647.79it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3645.00it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3633.39it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3639.16it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3630.43it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3618.66it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3606.21it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:04<00:00, 3591.64it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6387.22it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6375.77it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6371.41it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6363.89it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6352.84it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6370.86it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6365.73it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6330.58it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6358.99it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6357.03it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6369.05it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6348.56it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6362.67it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6343.42it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6341.09it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6350.24it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6364.69it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6329.40it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6323.95it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6292.45it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6297.17it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6354.38it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6375.51it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6360.75it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6350.71it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6312.46it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6349.09it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6367.49it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 6360.85it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5456.53it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5444.93it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5434.87it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5437.96it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5429.18it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5397.78it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5414.16it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5470.91it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5439.19it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5432.42it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5428.60it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5416.43it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5435.99it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5432.32it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5435.25it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5429.95it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5430.57it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5434.24it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5515.30it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5447.61it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5494.53it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5402.25it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5421.46it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5409.47it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5428.66it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5416.57it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5416.47it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5421.63it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5408.41it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5441.42it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5421.06it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5415.65it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5414.25it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5407.01it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5408.28it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5380.57it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5371.12it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5381.57it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5413.72it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5430.10it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5438.85it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5464.67it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5454.57it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5454.08it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5439.15it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5452.53it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5438.12it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5437.44it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5412.62it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5432.36it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5406.65it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5410.13it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5404.87it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5453.51it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5434.85it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5441.20it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5439.81it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5449.05it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5451.43it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5445.37it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5443.61it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5425.72it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5437.65it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5430.78it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5434.60it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5433.80it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5435.14it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5430.05it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5404.00it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5428.38it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5413.04it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5427.40it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5421.51it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5435.61it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5427.46it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5442.03it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5430.49it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5435.78it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5393.37it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5399.36it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5413.35it/s]\n",
|
|||
|
"100%|██████████| 15625/15625 [00:02<00:00, 5410.74it/s]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"dallb={}\n",
|
|||
|
"dallb_10={}\n",
|
|||
|
"dallb_10f={}\n",
|
|||
|
"for ds in ['cifar10', 'cifar100','ImageNet16-120']:\n",
|
|||
|
" allb = {}\n",
|
|||
|
" allb_10 = {}\n",
|
|||
|
" allb_10f = {}\n",
|
|||
|
"\n",
|
|||
|
" for k in range(0,41):\n",
|
|||
|
"\n",
|
|||
|
" b=[]\n",
|
|||
|
" b_10 = []\n",
|
|||
|
" b_10f = []\n",
|
|||
|
"\n",
|
|||
|
" for i in tqdm(range(len(api))):\n",
|
|||
|
" info = api.get_more_info(i, 'cifar10-valid' if ds=='cifar10' else ds, iepoch=None, hp='200', is_random=False)\n",
|
|||
|
" info_10 = api.get_more_info(i, 'cifar10-valid' if ds=='cifar10' else ds, iepoch=k, hp='200', is_random=False)\n",
|
|||
|
"\n",
|
|||
|
" try:\n",
|
|||
|
" info_10_fast = api.get_more_info(i, 'cifar10-valid' if ds=='cifar10' else ds, iepoch=k, hp='12', is_random=False)\n",
|
|||
|
" testacc_10_fast = info_10_fast['valid-accuracy' if ds=='cifar10' else 'valtest-accuracy']\n",
|
|||
|
" except Exception:\n",
|
|||
|
" pass\n",
|
|||
|
" \n",
|
|||
|
" testacc = info['test-accuracy']\n",
|
|||
|
" testacc_10 = info_10['valid-accuracy' if ds=='cifar10' else 'valtest-accuracy']\n",
|
|||
|
"\n",
|
|||
|
" b.append(testacc)\n",
|
|||
|
" b_10.append(testacc_10) \n",
|
|||
|
" b_10f.append(testacc_10_fast)\n",
|
|||
|
"\n",
|
|||
|
" allb[k] = b\n",
|
|||
|
" allb_10[k] = b_10\n",
|
|||
|
" allb_10f[k] = b_10f\n",
|
|||
|
" dallb[ds]=allb\n",
|
|||
|
" dallb_10[ds]=allb_10\n",
|
|||
|
" dallb_10f[ds]=allb_10f"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {
|
|||
|
"scrolled": true
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/SERILOCAL/mohamed1.a/anaconda3/envs/snip-torch/lib/python3.7/site-packages/numpy/lib/function_base.py:2534: RuntimeWarning: invalid value encountered in true_divide\n",
|
|||
|
" c /= stddev[:, None]\n",
|
|||
|
"/home/SERILOCAL/mohamed1.a/anaconda3/envs/snip-torch/lib/python3.7/site-packages/numpy/lib/function_base.py:2535: RuntimeWarning: invalid value encountered in true_divide\n",
|
|||
|
" c /= stddev[None, :]\n",
|
|||
|
"/home/SERILOCAL/mohamed1.a/anaconda3/envs/snip-torch/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater\n",
|
|||
|
" return (a < x) & (x < b)\n",
|
|||
|
"/home/SERILOCAL/mohamed1.a/anaconda3/envs/snip-torch/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less\n",
|
|||
|
" return (a < x) & (x < b)\n",
|
|||
|
"/home/SERILOCAL/mohamed1.a/anaconda3/envs/snip-torch/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal\n",
|
|||
|
" cond2 = cond0 & (x <= _a)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"dslow = {}\n",
|
|||
|
"dfast = {}\n",
|
|||
|
"for ds,allb in dallb.items():\n",
|
|||
|
" dslow[ds] = []\n",
|
|||
|
" dfast[ds] = []\n",
|
|||
|
" t = PrettyTable(['Epoch', 'Normal Training (200 epochs)', 'Fast Training (12 Epochs)'])\n",
|
|||
|
" for k,b in allb.items():\n",
|
|||
|
" r = [k]\n",
|
|||
|
" for v in [dallb_10[ds][k], dallb_10f[ds][k]]:\n",
|
|||
|
" cr = abs(stats.spearmanr(b,v,nan_policy='omit').correlation)\n",
|
|||
|
" r.append(round(cr,3))\n",
|
|||
|
" t.add_row(r)\n",
|
|||
|
" dslow[ds].append(r[1])\n",
|
|||
|
" dfast[ds].append(r[2])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvsAAAD8CAYAAADpG2vfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOzdd3hUVfrA8e+dmknvPZCEJBBawABSpQpIVWwIooiK3bWs3VVsu7quv1WsoOvaQMQOCK6ChKJ0Qi8BSSGN9DJJJplyfn8EQgJJSEKSmYTzeZ48ydx77r3vzGTuvPfcUxQhBJIkSZIkSZIkdT4qewcgSZIkSZIkSVLbkMm+JEmSJEmSJHVSMtmXJEmSJEmSpE5KJvuSJEmSJEmS1EnJZF+SJEmSJEmSOimZ7EuSJEmSJElSJ9Wuyb6iKBMVRTmqKMpxRVGerGd9F0VR1iuKkqgoyj5FUSadXh6uKEqFoih7Tv980J5xS5IkSZIkSVJHpLTXOPuKoqiBJOBKIB3YAdwkhDhUq8xiIFEI8b6iKD2B1UKIcEVRwoFVQoje7RKsJEmSJEmSJHUC7VmzPwg4LoQ4IYSoApYB088pIwD30397AJntGJ8kSZIkSZIkdSrtmeyHACdrPU4/vay2BcDNiqKkA6uBB2qtizjdvGeDoigj2jRSSZIkSZIkSeoENPYO4Bw3AZ8IId5QFGUI8LmiKL2BLKCLECJfUZR44AdFUXoJIUpqb6woynxgPoDBYIgPCwtr9GA2mw2VyvH7KMs4W19HifVCcSYlJeUJIfyaur/mfkbakqO+B44aFzhubI4a18V8PpycnOK7dOnSZrG1hCO+zjKmpnHEmKD5nxGpgxJCtMsPMAT4X63HTwFPnVPmIBBW6/EJwL+efSUAAxo7Xnx8vLiQ9evXX7CMI5Bxtr6OEuuF4gR2ihZ+JpvyGWlLjvoeOGpcQjhubI4a18V8PmJiYuwRcqMc8XWWMTWNI8YkxMV9RuRPx/lpz8vMHUC0oigRiqLogJnAinPKpAFjARRFiQWcgFxFUfxOd/BFUZRIIJrqCwFJkiRJkiRJkhrQbs14hBAWRVHuB/4HqIGPhRAHFUV5keoryxXAo8CHiqI8THVn3blCCKEoyhXAi4qimAEbcLcQoqC9Ypc6ryqThZOHC6gst2A2WakyWQjv64tfmFu95SuMVWz++hi9hgcTHO1VbxlTmZnEX9KI7O9HQLh7vWUqy80c3FTd/1xRFBQVhPX0xifYtXWemCRJknTJMFdayUgqpKrCgrnSiqXKRlisN97BLvWWLy+pYtNXSe0cpWQv7dpmXwixmuqOt7WXPVfr70PAsHq2+xb4ts0DlDqV8pIqNn55lO5Dgojo61tvmaoKCz8vOlBnmYuHvsFk32q2kf1nMRF9G27iWGWysGdtGt7BLo0k+xa2fP9nnWVjbomVyb7UKoyFlbh66e0dhiSd583wN3ko5SF7h+EwhE1QXlqFVq9G51R/SlZWVMkP/05k0NQIogcE1FvGVGbmp3f31Vk2ek6PBpN9m1WQl268uOClDsPROuhKUpNUVVjY+FUSkf38iOxXf+KtdVKTn1mGsDY8l4Szu44bnhmI3lmDzkmD1kmNWt1w6zZXLyfmvDy00djcfQzc8+7oRsu4eTsxf+FIENX9ZhCg1jpe5y2p49n6w5/s+jmVMbfEEjs0yN7hSHa074t9bFu4DWuVlZDLQ5j83mRO/HqCdU+vQ1gFzr7O3LLuFioKKvhx3o8UnihE66xl6uKpBPQNIGFBAsVpxRSeKKQ4rZjBDw3m8gcvB2DZ1csoOVlCcX4xbk+7ET8/HpvVxorbV5C5MxNFUeg3rx9DHh5i51fBvqxWG1azrcH1pnIznzzxOyNujKbv6PoHTNDoVPiGueLkqm1wP84eOq57cgA6JzVavRqNTo3WSd1geVcvPbNfGMzNLzb9uUgdl0z2pQ5Jq1eTk1
LSYM05gFanZvYLgxvdj0qtarAWvy0pKgWtruETsSS1xJ+JOez6OZVu/f3o2tvH3uFIdpR7OJeDXx1k3u/zUGvV/HTvT+z7Yh/rn13P3I1z8YrwoqKgAoD1z68nsH8gM3+YSfJvyXx/y/fcveduAPKO5HHr+lupKq3ine7vMOCeAai1aqZ/PB2Dt4F1/1vHtke3EXttLEUpRZRmlHLvgXsBMBWZ7Pb87UXYBBuXJVGSX0FpvonivAoGXBUO9Vewo3PSMHJWd4K6eTS4T72zlgl3ND6nqFqtavT7ULq0yWRfchjCJsg9WUr60UIyjhbiFeDC8Bui6y2rqBRmLWg8kZekS0nuyVJ++eggARHujL2tp7yYvMQlr0smc1cmHw78EABLhYWMbRl0vaIrXhHV/Y0M3gYATm4+yQ3f3gBAxJgIKvIrqCypBCB6cjQavQaNXoOLvwtlp8pwD3Vn28JtHPn+CEajEUuehYJjBfh096HwRCGrH1hNzOQYuo3vBsDGVzZy6OtDAJRmlvJBvw8ACBsWxuR3J7ffi9IKzJVWKsstDTaTU1QKJ48UoNWr8QpyIbyvLyExXiRlpNZbXq1R0fuKc6cckqTWJZN9yWGsemcvaYeq+117BTo3WtMhSVJdbt5O9BwWzOXTImWiLyGEIO7WOMb9Y1zNsqMrj3Jw2cFm7UejP5smKGoFm8VGSkIKJ9ae4PYtt/P79t9JWZCCxWTB4GXg7r13c/x/x9n5wU4OLj/I9I+nc8UzV3DFM1cA1W32z9w16IiWvbQNvy7uTJzfcE37zS+e33QpKaMto5KkxslkX3IYscOCiR4UQFisNy4esnOhJDWHk4uWkbO62zsMyUFEjo1k2fRlDHl4CC7+LlQUVBDQN4DV966mMLmwphmPwdtAlxFd2LdkHyP/NpKUhBScfZ3Ruzd8DjYVmzB4GdA6aylPKyd9azoA5XnlqHVqel7bE9/uvnx383ft9XRbjbAJFJXS4PrBV3fD2U3XjhFJ0sWTyb7ULqxmGyePFBDawwuNtv5ax6h4/3aOSpIkqXPy6+nH6JdH8/n4zxE2gVqrZtK7k5iyeArLZyxH2AQu/i7M+XUOoxaM4sd5P/J+3/fROmu5+tOrG9131MQodn2wi3dj30X4CEIHhwJQklHCj7f9iLBVD4ow9h9j2/x5tqbtq5I5ujWLm18agqLUn/A3NBqOJDkymexL7SLtcAGr39vHtAf7EdbT297hNEuJycz+9GKGRdU/fKckSW3PZhNsPp7HiGjfBhMxqa7eN/am943nNzeJvqpuXyiDt4GZP8w8r9yoBaPqPD7T8RZg9prZACQkJDBq1Nlyd+2+q9GYHHnYTZ8QFyL6+WGzCtQa+T8mdR5yrD+pXYTFejHl/jiCoz3tHcp58o2VmMzWBte/89txZn+0jY1Jue0YlSRJZ1htgrfWHeOWj7fz455Me4cjdTCmMjP71qez7OXtlBY0PEJQt/7+DL8uGrVGpkZS5yL/o6V2odGq6drbx+HGkt98LI/hr63n6e/317veZhOs3FudXDy/4iCVloYvCiRJahsv/3SIt9YdA+BEXpmdo5E6irLiSn7970E+efJ3Nn2VhEqlYDbJc7h06XGszEvqsJK2Z5O6wVbTVrMj+O3IKeZ9uoMqq41V+7IoKq86r8yOlAKyik1cFx9Kcl4ZizecaJNYhBAs2ZbKl9vT+ON4HumF5Vg70Gsptb/KcjNlRZX2DqNdLNt+subvSrMVm02wbHsaxeVmO0YlOTq9QUPW8WJihwZxw9MDueHpgQ3OKCtJnZlssy+1CptNYLNW3y41dICRCnZkW1j0yy56BrvzxMQezP5oGyv2ZnLLkPA65VbszcRJq+KFab2oqLLyzvrjXN0/hDBv53r3W2mxkl5YQVpBOVF+rg2WO9fGY3k88/2BOst0ahUjQlTUag4rSTX2rjvJ7l/SuOWVoTi7O/5n7mJU1Gpml1ZQzuoDWTz53X5eWX2YPc+NR93I6C
mXKnOFmSUTl3DLb7egOj0reGVJJe/2fJceV/dg0juTztsmYUECuz/cjbNf9Xlr7N/HEj2p/rlOzijYXsA7d72DzWr
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 770.4x259.2 with 3 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"with open('../results_release/nasbench2/nb2_fast_train.p','rb') as f:\n",
|
|||
|
" truns = pickle.load(f)\n",
|
|||
|
" \n",
|
|||
|
"per_epochs=[]\n",
|
|||
|
"for expname in ['nb2_fast_train', 'nb2_fast_train_ch8', 'nb2_fast_train_im16', 'nb2_fast_train_im16_ch8']:\n",
|
|||
|
"\n",
|
|||
|
" with open(f'../results_release/nasbench2/{expname}.p','rb') as f:\n",
|
|||
|
" truns = pickle.load(f)\n",
|
|||
|
"\n",
|
|||
|
" #form array per epoch for fast training\n",
|
|||
|
" per_epoch = {}\n",
|
|||
|
" for r in truns:\n",
|
|||
|
" for i,e in enumerate(r['logmeasures']):\n",
|
|||
|
" if i not in per_epoch:\n",
|
|||
|
" per_epoch[i] = []\n",
|
|||
|
" per_epoch[i].append(e['val_acc'])\n",
|
|||
|
" per_epochs.append(per_epoch)\n",
|
|||
|
" \n",
|
|||
|
"ds = 'cifar10'\n",
|
|||
|
"econas = []\n",
|
|||
|
"acc = dallb[ds][0]\n",
|
|||
|
"for exp in per_epochs:\n",
|
|||
|
" l = []\n",
|
|||
|
" t = PrettyTable(['Epoch', 'Correlation of Proxy Training'])\n",
|
|||
|
" for k,b in exp.items():\n",
|
|||
|
" r = [k]\n",
|
|||
|
" cr = abs(stats.spearmanr(b,acc,nan_policy='omit').correlation)\n",
|
|||
|
" r.append(round(cr,3))\n",
|
|||
|
" t.add_row(r)\n",
|
|||
|
" l.append(cr)\n",
|
|||
|
" econas.append(l)\n",
|
|||
|
"\n",
|
|||
|
"#WE ONLY IMPLEMENT ECONAS FOR CIFAR10\n",
|
|||
|
"\n",
|
|||
|
"ls = ['solid','dotted','dashed','dashdot',(0, (3, 5, 1, 5, 1, 5))]\n",
|
|||
|
"\n",
|
|||
|
"ds = 'cifar10'\n",
|
|||
|
"slow=dslow[ds]\n",
|
|||
|
"fig, axs = plt.subplots(1,3,figsize=(10.7,3.6), sharey=True)\n",
|
|||
|
"\n",
|
|||
|
"ax=axs[0]\n",
|
|||
|
"\n",
|
|||
|
"epx=41\n",
|
|||
|
"\n",
|
|||
|
"#regular training\n",
|
|||
|
"x=range(0,epx,1)\n",
|
|||
|
"ax.plot(x,slow[0:epx], label= 'baseline $r_{32}c_{16}$', linestyle=ls[0])\n",
|
|||
|
"\n",
|
|||
|
"#econas\n",
|
|||
|
"x2=range(0,epx,1)\n",
|
|||
|
"ax.plot(x2,econas[0][0:epx], label='econas $r_8c_4$', linestyle=ls[1])\n",
|
|||
|
"ax.plot(x2,econas[1][0:epx], label='econas $r_8c_8$', linestyle=ls[2])\n",
|
|||
|
"ax.plot(x2,econas[2][0:epx], label='econas $r_{16}c_4$', linestyle=ls[3])\n",
|
|||
|
"ax.plot(x2,econas[3][0:epx], label='econas $r_{16}c_8$', linestyle=ls[4])\n",
|
|||
|
"ax.grid()\n",
|
|||
|
"ax.set_ylim(0.5,0.85)\n",
|
|||
|
"ax.set_xlabel(\"Epoch\")\n",
|
|||
|
"ax.set_ylabel('Spearman $\\\\rho$')\n",
|
|||
|
"\n",
|
|||
|
"#--------------------------------------------------------------------------------\n",
|
|||
|
"ax=axs[1]\n",
|
|||
|
"\n",
|
|||
|
"#regular training\n",
|
|||
|
"x=[ff for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x,slow[0:epx], label= 'baseline $r_{32}c_{16}$', linestyle=ls[0])\n",
|
|||
|
"\n",
|
|||
|
"#econas\n",
|
|||
|
"d = 230\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[0][0:epx], label='econas $r_8c_4$', linestyle=ls[1])\n",
|
|||
|
"d = 59.5\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[1][0:epx], label='econas $r_8c_8$', linestyle=ls[2])\n",
|
|||
|
"ax.plot(x2,econas[2][0:epx], label='econas $r_{16}c_4$', linestyle=ls[3])\n",
|
|||
|
"d = 15.4\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[3][0:epx], label='econas $r_{16}c_8$', linestyle=ls[4])\n",
|
|||
|
"\n",
|
|||
|
"ax.grid()\n",
|
|||
|
"ax.set_ylim(0.5,0.85)\n",
|
|||
|
"#plt.xlim(0,0.01)\n",
|
|||
|
"ax.set_xlabel(\"Normalized FLOPS\")\n",
|
|||
|
"ax.set_xscale('log', basex=2)\n",
|
|||
|
"\n",
|
|||
|
"from fractions import Fraction\n",
|
|||
|
"labels = (\"1\", \"1\", '1/64', \"1/8\", \"1\", \"8\")\n",
|
|||
|
"ax.set_xticklabels(labels)\n",
|
|||
|
"\n",
|
|||
|
"#--------------------------------------------------------------------------------\n",
|
|||
|
"ax=axs[2]\n",
|
|||
|
"\n",
|
|||
|
"#regular training\n",
|
|||
|
"x=[ff for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x,slow[0:epx], label= 'baseline $r_{32}c_{16}$', linestyle=ls[0])\n",
|
|||
|
"\n",
|
|||
|
"#econas\n",
|
|||
|
"d = 4\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[0][0:epx], label='econas $r_8c_4$', linestyle=ls[1])\n",
|
|||
|
"d = 4\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[1][0:epx], label='econas $r_8c_8$', linestyle=ls[2])\n",
|
|||
|
"ax.plot(x2,econas[2][0:epx], label='econas $r_{16}c_4$', linestyle=ls[3])\n",
|
|||
|
"d = 3.3\n",
|
|||
|
"x2=[ff/d for ff in range(0,epx)]\n",
|
|||
|
"ax.plot(x2,econas[3][0:epx], label='econas $r_{16}c_8$', linestyle=ls[4])\n",
|
|||
|
"\n",
|
|||
|
"p=15\n",
|
|||
|
"ax.scatter(p/d, econas[3][p], marker='o', color='purple')\n",
|
|||
|
"ax.annotate(f'({round(p/d,1)}, {round(econas[3][p],2)})',(p/d-2, econas[3][p]+0.01), horizontalalignment='left', color='purple')\n",
|
|||
|
"ax.annotate(f'econas+',(p/d-2, econas[3][p]+0.03), horizontalalignment='left', color='purple')\n",
|
|||
|
"\n",
|
|||
|
"p=20\n",
|
|||
|
"ax.scatter(p/4, econas[0][p], marker='p', color='orange')\n",
|
|||
|
"ax.annotate(f'({round(p/4,1)}, {round(econas[0][p],2)})',(p/d-1.3, econas[0][p]-0.03), horizontalalignment='left', color='chocolate')\n",
|
|||
|
"ax.annotate(f'econas',(p/d-1.3, econas[0][p]-0.05), horizontalalignment='left', color='chocolate')\n",
|
|||
|
"\n",
|
|||
|
"ax.grid()\n",
|
|||
|
"ax.set_ylim(0.5,0.85)\n",
|
|||
|
"ax.set_xlim(0,10)\n",
|
|||
|
"ax.set_xlabel(\"Normalized GPU Runtime\")\n",
|
|||
|
"#ax.set_xscale('log', basex=10)\n",
|
|||
|
"\n",
|
|||
|
"fig.tight_layout(pad=0.3)\n",
|
|||
|
"plt.legend(bbox_to_anchor=(1,0.7))\n",
|
|||
|
"plt.tight_layout()\n",
|
|||
|
"plt.savefig('econas.pdf')\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Figure 2: Correlation of validation accuracy with final test accuracy during the first 12 epochs of training for three datasets on the NAS-Bench-201 search space. Zero-cost and EcoNAS proxies are also labeled for comparison.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {
|
|||
|
"scrolled": false
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"cifar10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAEYCAYAAACHjumMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOydd1hU19aH3w0MVVAQJWLFGguIYo293KtGozGJQZMYTbsxX3rxanrvvXhTNLYUNUWjMcbEhpJYQVGxN1SsYKW3Wd8fZyAoKDMDw1D2+zzzcM4+e++zzmFY7Lp+SkTQaDQaR+DibAM0Gk3VRTsYjUbjMLSD0Wg0DkM7GI1G4zC0g9FoNA5DOxiNRuMwtIPRlAqlVJ5SKq7QZ3IZ1t1EKRVfVvVpyh83ZxugqfRkiEi4s43QVEx0C0bjEJRSCUqpd5RS25VSG5VSzS3pTZRSK5VS25RSK5RSjSzpQUqpBUqprZbPdZaqXJVSU5VSO5RSfyqlvCz5H1FK7bTUM9dJj6kpAe1gNKXF67IuUmShaxdEJBT4DPjIkvYpMEtEwoDvgE8s6Z8Aq0WkPdAR2GFJbwFMEZG2wHngZkv6ZKCDpZ4Jjno4TelQequApjQopVJFpEYx6QlAfxE5qJQyASdFpLZSKhmoJyI5lvQTIhKolEoCGohIVqE6mgDLRKSF5XwSYBKR15RSS4FU4BfgFxFJdfCjauxAt2A0jkSucGwLWYWO8/hn3HAoMAWjtbNJKaXHEysg2sFoHElkoZ/rLMdrgdGW49uBaMvxCuABAKWUq1Kq5pUqVUq5AA1FZBUwCagJFGlFaZyP9vqa0uKllIordL5URPKnqv2VUtswWiFjLGkPAzOUUhOBJOAuS/qjwFdKqXswWioPACeucE9X4FuLE1LAJyJyvsyeSFNm6DEYjUOwjMF0EpFkZ9uicR66i6TRaByGbsFoNBqHoVswGo3GYVRJB6OUGqyU2qOU2l/avTFKqelKqdNlsSdGKdVQKbXKsgJ1h1Lq0VLU5WlZIbvVUtfLZWCfq1Jqi1JqcRnUlWBZxRunlIopZV21lFI/KaV2K6V2KaW621lPq8sWBV5USj1WStset7z/eKXUHKWUZynqetRSzw577Cruu6qUClBKLVNK7bP89C9FXaMstpmVUp2sMkpEqtQHY4bhANAUcAe2Am1KUV9vjLUW8WVgWz2go+XYF9hrr20Ysyc1LMcmYAPQrZT2PQF8Dywug2dNAALL6Hc6C7jXcuwO1Cqj78lJoHEp6qgPHAK8LOc/AOPtrKsdEA94Y8zuLgea21hHke8q8A4w2XI8GXi7FHW1BloBURgD+CXWUxVbMF2A/SJyUESygbnACHsrE5E1wNmyMExETojIZstxCrAL40tqT10i/6xeNVk+dg+oKaUaYCxem2ZvHY7AMhXdG/gaQESypWympAcAB0TkcCnrccOYqnfDcA7H7aynNbBBRNJFJBdYDdxkSwVX+K6OwHDQWH7eaG9dIrJLRPbYYlNVdDD1gaOFzhOx84/YkViWwXfAaHnYW4erZQ3KaYwl9XbXhbFX6L+AuRR1FEaAP5VSsUqp/5SinhCM9TIzLN23aUopnzKwbzQwpzQViMgx4D3gCMaanQsi8qed1cUDvZRStZVS3sD1QMPS2GchSETy1xOdBILKoE6rqYoOpsKjlKoB/Aw8JiIX7a1HRPLECJXQAOiilGpnpz3DgNMiEmuvLcXQU0Q6AkOAB5VSve2sxw2jqf65iHQA0jCa+najlHIHhgM/lrIef4wWQggQDPgope6wpy4R2QW8DfwJLAXiMBYclhli9HPKddq4KjqYY1zq+RtY0ioElg1+PwPficj8sqjT0mVYBQy2s4oewHDL4ri5QH+l1LeltOmY5edpYAFG19UeEoHEQq2znzAcTmkYAmwWkVOlrGcgcEhEkkQkB5gPXFdCmSsiIl+LSISI9AbOYYzRlZ
ZTSql6AJafp8ugTqupig5mE9BCKRVi+U81GljkZJsAUEopjLGEXSLyQSnrqqOUqmU59gL+Bey2py4ReVpEGohIE4z3tVJE7PpPbLHHRynlm38M/BujC2CPbSeBo0qpVpakAcBOe22zMIZSdo8sHAG6KaW8Lb/bARjjanahlKpr+dkIY/zl+zKwcREwznI8DlhYBnVaj70j6BX5g9F/3Ysxm/RsKeuag9G/zsH4b3pPKerqidFE3YbRBI4DrrezrjBgi6WueOCFMnp3fSnlLBLGDN5Wy2dHGfwOwoEYy7P+AviXoi4f4AxQs4ze18sYjj0e+AbwKEVd0RjOcyswwI7yRb6rQG2MjaT7MGamAkpR10jLcRZwCvijpHr0Sl6NRuMwqmIXSaPRVBC0g9FoNA5DOxiNRuMwtIPRaDQOo1wdTEmbEJVSjSybAbcoQ47iekt6E6VURqFNal9Yca/SrB51aH0Vta6yrq+62Kaf8yqUxVSdldNeJW5CBL4CHrActwESLMdNsHGzIRBTxvaXWX0VtS5tm/Prqsi22VNXebZgrNmEKICf5bgm9m8c02g0FYDyDPpd3CbErpfleQljg9zDGAuiBha6FqKU2gJcBJ4TkejLyuY34fKbcRGenp5ltsjHzc2NsqqvotZV1vVVF9uqy3Nizz6msmzaldC8ugWYVuh8LPDZZXmeAJ60HHfHWNXoAngAtS3pERiOyu9q9/P29hZrWLVqlVX5KiLadudQmW0Xsd9+IE0qcBfJmk2I92AE7UFE1gGeGEGLskTkjCU9FmMsp6XDLdZoKjliFsx5ZRWBw3bK08FYswnxCMaGMZRSrTEcTJJlY5+rJb0phl7xwXKzXKMpY8x5ZnLSc4qkn44/TcLqhCLpB5Yd4IdbfuDbQd+y7sN1Ra7HTo3ljyf/KJIe80UMfzxeNH3T/zbxmudrvFXzLZZPXl7k+pYZW1j53Eorn+bKlNsYjIjkKqUeAv7AmFGaLiI7lFKvYIxOLwKeBKYqpR7H6O+NFxGxxBJ5RSmVgxEQaYKIlEmUOY3GERxaeYiNn20kNzOXkAEhXPfkpVEc4mbGcWrrKYZ8MuSS9IyzGaQcTylSX63GtWgb2RYPXw/8mxUNq9v+zvaIuegQSef/61ysfZ0mdCL8rnDysvJwcSvazmhzSxskr/RDN+Wq7CgiS4All6W9UOh4J0ZsksvL/YwRQ0WjcQrZadmknU7DP+TSP+6DKw6y77d9DPpg0CXptUJqEXZHGG5ebtRqXKtIfR3vKT6kTePejYtNr92yNrVb1r6ifW4etv0pKxeFycuEyctU7HUPXw+b6rsSWjpWoylEVkoWKcdSCLw28JL04zHHiZ8Tz7Avhl2SHhwRTECzgCL1+If4F3FG1RHtYDQa4Nyhc8wdMZdzB87RakQrbv7+5kuuN+nThCZ9mhQp51nLE89adiuVVHn0XqQyICEhge+/L4vgYxpHk52aza/3/1ok3TfYlxHTRzDp3KQizkVjP9rBlAHawVQ8RITlTy8vMkVr8jHRuHfj/HVXBbh5uBHcKRhXd9fyNLPKU20dzKfjRvF+5DBOHdzP+5HD+HTcqEuuT548mSlTphScv/TSS7z77rtMnDiRdu3aERoayrx58wryRkdHEx4ezocffkheXh4TJ06kc+fOhIWF8eWXX5brs1U3ilvroZTCr74feVl5RdLDbg/DCKGrcTTV1sFkZ2Zc9TwyMpIffvih4PyHH36gbt26xMXFsXXrVpYvX87EiRM5ceIEb731Fr169SIuLo7HH3+cr7/+mpo1a7Jp0yY2bdrE1KlTOXToULk8V3UiYXUCv/7nV94Pfp8j0UeKXO/yUBdM3sXPkmjKBz3IewU6dOjA6dOnOX78OElJSfj7+xMXF8eYMWNwdXUlKCiIPn36sGnTJvz8/C4p++eff7Jt2zZ++uknAC5cuMC+ffsICQlxxqNUWZJ3JVO7ZW3u/utuApoXncnROJ9q5WA+HTeqoKWiXF
wR8z/NZ+XiyvuRxhSku6cXD8/6kVGjRvHTTz9x8uRJIiMjrW6FiAiffvopgwYNKjmzpkR2/7Kb5D3J9JzU85L0ThO
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 288x288 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"cifar100\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAEYCAYAAACHjumMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO2dd3hU1daH35UCgdCrSBFQggiBQOi9qKAi2DAgKtiuWNCr96Lea0f0qqj3syAqCBaq2OXaaIEgNdFQpQQITUpogYT0rO+PcxKGEMjMJJOZSfb7PPPMOfucvc/vnJlZs+taoqoYDAaDJwjwtgCDwVB2MQbGYDB4DGNgDAaDxzAGxmAweAxjYAwGg8cwBsZgMHgMY2AMxUJEckQk3uH1VAmW3VRENpZUeYbSJ8jbAgx+T5qqRnhbhME3MTUYg0cQkUQReV1ENojIGhG5zE5vKiKLRWS9iCwSkSZ2en0R+UZE1tmv7nZRgSIyRUQ2icivIlLJPv8REdlslzPHS7dpKAJjYAzFpVKBJlKUw7FkVQ0H3gP+z057F/hUVdsCM4F37PR3gKWq2g7oAGyy01sAk1S1NXACuNlOfwpob5czxlM3ZygeYpYKGIqDiKSoapVC0hOB/qq6U0SCgYOqWltEjgANVDXLTj+gqnVEJAlopKoZDmU0BRaoagt7/0kgWFUniMjPQArwLfCtqqZ4+FYNbmBqMAZPoufZdoUMh+0czvQbXgdMwqrtrBUR05/ogxgDY/AkUQ7vK+3tFcBwe3skEGNvLwIeABCRQBGpfr5CRSQAaKyqS4AngerAObUog/cxVt9QXCqJSLzD/s+qmjdUXVNE1mPVQkbYaWOB6SIyDkgC7rLTHwU+EpF7sGoqDwAHznPNQGCGbYQEeEdVT5TYHRlKDNMHY/AIdh9MR1U94m0tBu9hmkgGg8FjmBqMwWDwGKYGYzAYPEaZNDAiMkhEtopIQnHXxojINBE5XBJrYkSksYgssWegbhKRR4tRVog9Q3adXdaLJaAvUET+EJH5JVBWoj2LN15EYotZVg0R+VJEtojInyLSzc1yWhaYFHhSRP5eTG2P2c9/o4jMFpGQYpT1qF3OJnd0FfZdFZFaIrJARLbb7zWLUdYwW1uuiHR0SpSqlqkX1gjDDqA5UAFYB1xRjPJ6Y8212FgC2hoAHeztqsA2d7VhjZ5UsbeDgdVA12LqexyYBcwvgXtNBOqU0Gf6KXCvvV0BqFFC35ODwCXFKKMhsAuoZO9/AYx2s6w2wEagMtbo7kLgMhfLOOe7CrwOPGVvPwW8VoyyWgEtgWisDvwiyymLNZjOQIKq7lTVTGAOMNTdwlR1GXCsJISp6gFV/d3ePgX8ifUldacs1TOzV4Ptl9sdaiLSCGvy2lR3y/AE9lB0b+BjAFXN1JIZkh4A7FDV3cUsJwhrqD4Iyzj85WY5rYDVqnpaVbOBpcBNrhRwnu/qUCwDjf1+g7tlqeqfqrrVFU1l0cA0BPY67O/DzR+xJ7GnwbfHqnm4W0agPQflMNaUerfLwlor9ASQW4wyHFHgVxGJE5G/FaOcZljzZabbzbepIhJaAvqGA7OLU4Cq7gfeAPZgzdlJVtVf3SxuI9BLRGqLSGXgWqBxcfTZ1FfVvPlEB4H6JVCm05RFA+PziEgV4Cvg76p60t1yVDVHLVcJjYDOItLGTT2DgcOqGueulkLoqaodgGuAh0Skt5vlBGFV1SeransgFauq7zYiUgEYAswrZjk1sWoIzYCLgVARud2dslT1T+A14FfgZyAea8JhiaFWO6dUh43LooHZz9mWv5Gd5hPYC/y+Amaq6tclUabdZFgCDHKziB7AEHty3Bygv4jMKKam/fb7YeAbrKarO+wD9jnUzr7EMjjF4Rrgd1U9VMxyrgR2qWqSqmYBXwPdi8hzXlT1Y1WNVNXewHGsPrrickhEGgDY74dLoEynKYsGZi3QQkSa2f
9Uw4HvvawJABERrL6EP1X1rWKWVVdEatjblYCrgC3ulKWq/1LVRqraFOt5LVZVt/6JbT2hIlI1bxu4GqsJ4I62g8BeEWlpJw0ANrurzWYExWwe2ewBuopIZfuzHYDVr+YWIlLPfm+C1f8yqwQ0fg+MsrdHAd+VQJnO424Pui+/sNqv27BGk54uZlmzsdrXWVj/pvcUo6yeWFXU9VhV4HjgWjfLagv8YZe1EXiuhJ5dX4o5ioQ1grfOfm0qgc8gAoi17/VboGYxygoFjgLVS+h5vYhl2DcCnwMVi1FWDJbxXAcMcCP/Od9VoDbWQtLtWCNTtYpR1o32dgZwCPilqHLMTF6DweAxymITyWAw+AjGwBgMBo9hDIzBYPAYxsAYDAaPUaoGpqhFiCLSxF4M+IdY4SiutdObikiawyK1D5y4VnFmj3q0PF8tq6TLKy/azH1egJIYqnNy2KvIRYjAR8AD9vYVQKK93RQXFxsCsSWsv8TK89WyjDbvl+XL2twpqzRrMM4sQlSgmr1dHfcXjhkMBh+gNJ1+F7YIsUuBc17AWiA3FmtC1JUOx5qJyB/ASeAZVY0pkDevCpdXjYsMCQkpsUk+QUFBlFR5vlpWSZdXXrSVl/vEnXVMJVm1K6J6dQsw1WH/DuC9Auc8DvzD3u6GNasxAKgI1LbTI7EMVbULXa9y5crqDEuWLHHqvAuxa9cunTlzZrHLcZWS0O4tjHbv4a5+IFV9uInkzCLEe7Cc9qCqK4EQLKdFGap61E6Pw+rLCfO4YidJTExk1qySWDZiMJQtStPAOLMIcQ/WgjFEpBWWgUmyF/YF2unNseIV7yyOmHdHDePNqMEc2pnAm1GDeXfUsLOOP/XUU0yaNCl//4UXXmDixImMGzeONm3aEB4ezty5c/PPjYmJISIigv/+97/k5OQwbtw4OnXqRNu2bfnwww+LI9Vg8FtKzcCo5aXrYeAXrBWnX6jqJhEZLyJD7NP+AdwnIuuwFluNtqtmvYH1tnOlL4ExqlosL3OZ6WkX3I+KiuKLL77I3//iiy+oV68e8fHxrFu3joULFzJu3DgOHDjAq6++Sq9evYiPj+exxx7j448/pnr16qxdu5a1a9cyZcoUdu3aVRy5BoNfUqqRHVX1R+DHAmnPOWxvxvJNUjDfV1g+VEqN9u3bc/jwYf766y+SkpKoWbMm8fHxjBgxgsDAQOrXr0+fPn1Yu3Yt1apVOyvvr7/+yvr16/nyyy8BSE5OZvv27TRr1qw0b8Fg8DrlKnTsu6OG5ddUJCAQzT3jMEwCAnkzajAAFUIqMfbTeQwbNowvv/ySgwcPEhUV5XQtRFV59913GThwYMnfhMHgR5SrpQKOzSBH41JwP++8qKgo5syZw5dffsmwYcPo1asXc+fOJScnh6SkJJYtW0bnzp2pWrUqp06dys8/cOBAJk+eTFZWFgDbtm0jNTXVk7dmMPgk5aoG4yqtW7fm1KlTNGzYkAYNGnDjjTeycuVK2rVrh4jw+uuvc9FFF1G7dm0CAwNp164do0eP5tFHHyUxMZEOHTqgqtStW5dvv/3W27djMJQ6xsAUwYYNG/K3RYSJEycyceLEs84JDg5m8eLFZ6W98sorvPLKK6Wi0WDwVcpVE6lCSKX8bQkIPOuY477jeQaDwX3KVQ1m7KdnolTkdejmobk5/GNusSOmGgwGB8pVDcZgMJQu5dbAFGwGmWaRwVDylKsmkiN5zaXo6GiiTNPIYPAI5bYGYzAYPI8xMAaDwWMYA2MwGDyGMTAGg8FjGANjMBg8hl+ELbGP/cvOt1VEzDJlg8EPKLVhatsj3STgKiyH32tF5HvbB0wez2A5oposIldg+Y5pam8PB1oDFwMLRSRMVc9eEm0wGHwKfwlbMhSYY/vm3QUk2OUZDAYfxl/CljQEVhXI27DgBRzDlgQFBREdHV2kqJSUFKfO80WMdu/gz9qhdPX72kzeEcAnqvqmiHQDPheRNs5mVtWPsKJDEhoaqn
379i0yT3R0NM6c54sY7d7Bn7VD6eovTQPjbNiSQWCFLRGREKCOk3kNBoOP4RdhS+zzhotIRRFphhW2ZE2pKTcYDG5
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 288x288 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"ImageNet16-120\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAEYCAYAAACHjumMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO2dd3hVVdaH35UOoYUqRYQondCCiDQBC2CBQcWAijC2UdHxs8KMDbGMijqOiqggqFhAsCFWEAIBRAgQeoeIIL0EkpB2s74/zkm4CQnJvbktZL/Pc557zj67/M7JvSu7L1FVDAaDwRsE+VuAwWA4dzEGxmAweA1jYAwGg9cwBsZgMHgNY2AMBoPXMAbGYDB4DWNgDGVCRBwikuR0jPFg3k1EZL2n8jP4nhB/CzCUe06pagd/izAEJqYGY/AKIpIsIq+IyDoRWS4iF9nhTURkvoisFZFfRaSxHV5PRL4WkTX20c3OKlhEJonIBhH5RUQq2fH/KSIb7Xym++kxDSVgDIyhrFQq1ESKc7qXoqoxwNvAG3bYW8BHqtoO+BR40w5/E1ioqu2BTsAGO7wZMEFV2wDHgRvs8DFARzufe7z1cIayIWapgKEsiEiqqlYpIjwZ6KuqO0UkFNivqrVE5DBQX1Wz7fB9qlpbRA4BjVQ10ymPJsBcVW1mX48GQlX1eRH5CUgFvgG+UdVULz+qwQ1MDcbgTbSYc1fIdDp3cLrf8BpgAlZtZ4WImP7EAMQYGIM3iXP6/M0+XwoMtc9vARLs81+BewFEJFhEqheXqYgEAeer6gJgNFAdOKMWZfA/xuobykolEUlyuv5JVfOGqqNEZC1WLWSYHfYAMFVEHgMOAX+3wx8E3heRO7BqKvcC+4opMxj4xDZCArypqsc99kQGj2H6YAxewe6D6ayqh/2txeA/TBPJYDB4DVODMRgMXsPUYAwGg9c4Jw2MiPQXkS0isr2sa2NEZIqIHPTEmhgROV9EFtgzUDeIyINlyCvCniG7xs7rWQ/oCxaR1SIyxwN5JduzeJNEJLGMedUQkVkisllENonIpW7m06LQpMATIvJ/ZdT2kP3+14vI5yISUYa8HrTz2eCOrqK+qyJSU0Tmisg2+zOqDHkNsbXlikjnUolS1XPqwBph2AFEA2HAGqB1GfLrhTXXYr0HtNUHOtnnVYGt7mrDGj2pYp+HAr8DXcuo72HgM2COB541Gajtob/pR8Cd9nkYUMND35P9wAVlyKMhsAuoZF9/AYx0M6+2wHqgMtbo7jzgIhfzOOO7CrwCjLHPxwAvlyGvVkALIB6rA7/EfM7FGkwXYLuq7lTVLGA6MMjdzFR1EXDUE8JUdZ+qrrLPTwKbsL6k7uSlenr2aqh9uN2hJiKNsCavTXY3D29gD0X3Aj4AUNUs9cyQ9OXADlX9o4z5hGAN1YdgGYe/3MynFfC7qqarag6wELjelQyK+a4OwjLQ2J9/czcvVd2kqltc0XQuGpiGwJ9O13tw80fsTexp8B2xah7u5hFsz0E5iDWl3u28sNYKPQ7kliEPZxT4RURWisjdZcinKdZ8mal2822yiER6QN9Q4POyZKCqe4FXgd1Yc3ZSVPUXN7NbD/QUkVoiUhm4Gji/LPps6qlq3nyi/UA9D+RZas5FAxPwiEgV4Evg/1T1hLv5qKpDra0SGgFdRKStm3quBQ6q6kp3tRRBD1XtBAwARolILzfzCcGqqk9U1Y5AGlZV321EJAwYCMwsYz5RWDWEpkADIFJEbnUnL1XdBLwM/AL8BCRhTTj0GGq1c3w6bHwuGpi9FLT8jeywgMBe4Pcl8KmqfuWJPO0mwwKgv5tZdAcG2pPjpgN9ReSTMmraa38eBL7Garq6wx5gj1PtbBaWwSkLA4BVqnqgjPlcAexS1UOqmg18BXQrIU2xqOoHqhqrqr2AY1h9dGXlgIjUB7A/D3ogz1JzLhqYFUAzEWlq/6caCs
z2syYARESw+hI2qerrZcyrjojUsM8rAVcCm93JS1X/paqNVLUJ1vuar6pu/Se29USKSNW8c+AqrCaAO9r2A3+KSAs76HJgo7vabIZRxuaRzW6gq4hUtv+2l2P1q7mFiNS1Pxtj9b985gGNs4ER9vkI4FsP5Fl63O1BD+QDq/26FWs06Yky5vU5Vvs6G+u/6R1lyKsHVhV1LVYVOAm42s282gGr7bzWA0976N31poyjSFgjeGvsY4MH/gYdgET7Wb8BosqQVyRwBKjuoff1LJZhXw9MA8LLkFcClvFcA1zuRvozvqtALayFpNuwRqZqliGvwfZ5JnAA+LmkfMxMXoPB4DXOxSaSwWAIEIyBMRgMXsMYGIPB4DWMgTEYDF7DpwampEWIItLYXgy4Wix3FFfb4U1E5JTTIrV3S1FWWWaPejW/QM3L0/lVFG3mOc+CJ4bqSjnsVeIiROB94F77vDWQbJ83wcXFhkCih/V7LL9Azcto839egazNnbx8WYMpzSJEBarZ59Vxf+GYwWAIAHy56XdRixAvKRRnLNYCuQewJkRd4XSvqYisBk4AT6pqQqG0eVW4vGpcbEREhMcm+YSEhOCp/AI1L0/nV1G0VZTnxJ11TJ6s2pVQvboRmOx0PRx4u1Cch4FH7PNLsWY1BgHhQC07PBbLUFU7W3mVK1fW0rBgwYJSxQtEjHb/UJ61q7qvH0jTAG4ilWYR4h1Ym/agqr8BEVibFmWq6hE7fCVWX05zrys2GAxlwpcGpjSLEHdjLRhDRFphGZhD9sK+YDs8Gstf8U6fKTcYDG7hsz4YVc0RkfuBn7FGlKao6gYRGYfVOz0beASYJCIPYbX3Rqqq2nuJjBORbKwNke5RVY/sMmcwGLyHTz07quoPwA+Fwp52Ot+ItTdJ4XRfYu2hYjAYyhFmJq/BYPAaxsAYDAavYQyMwWDwGsbAGAwGr2EMjAdITk7ms888sX2qwXBuYQyMBzAGxmAomgprYN4aMYTX4q7lwM7tvBZ3LW+NGFLg/pgxY5gwYUL+9dixYxk/fjyPPfYYbdu2JSYmhhkzZuTHTUhIoEOHDvz3v//F4XDw2GOPcfHFF9OuXTvee+89nz6bwRAoVFgDk5Vx6qzXcXFxfPHFF/nXX3zxBXXr1iUpKYk1a9Ywb948HnvsMfbt28dLL71Ez549SUpK4qGHHuKDDz6gevXqrFixghUrVjBp0iR27drlk+cyGAIJn060K0907NiRgwcP8tdff3Ho0CGioqJISkpi2LBhBAcHU69ePS677DJWrFhBtWrVCqT95ZdfWLt2LbNmzQIgJSWFbdu20bRpU388isHgNyqUgXlrxJD8mooEBaO5pz1zSlAwr8VdC0BYRCUe+GgmQ4YMYdasWezfv5+4uLhS10JUlbfeeot+/fp5/iEMhnJEhWoiOTeDnI1L4eu8eHFxcUyfPp1Zs2YxZMgQevbsyYwZM3A4HBw6dIhFixbRpUsXqlatysmTJ/PT9+vXj4kTJ5KdnQ3A1q1bSUtL8+ajGQwBSYWqwbhKmzZtOHnyJA0bNqR+/foMHjyY3377jfbt2yMivPLKK5x33nnUqlWL4OBg2rdvz8iRI3nwwQdJTk6mU6dOqCp16tThm2++8ffjGAw+xxiYEli3bl3+uYgwfvx4xo8fXyBOaGgo8+fPLxD24osv8uKLL/pEo8EQqFSoJlJYRKX8cwkKLnDP+do5nsFgcB+f1mBEpD/wP6z9YCar6kuF7jcGPgJq2HHG2Fs8ICL/wtrxzgH8U1V/drX8Bz6amX+e16Gbh+Y6eGTGHFezNBgMZ8FnBsbekW4CcCXWht8rRGS2vQdMHk8CX6jqRBFpjbV3TBP7fCjQBmgAzBOR5qpasKfWYDAEFOXFbckgYLq9N+8uYLudn9sUbgaZZpHB4HnE2izcBwWJ3Aj0V9U77evhwCWqer9TnPrAL0AUttsSVV0pIm8Dy1T1EzveB8CPqjqrUBn5bktCQkJi586dW6
Ku1NRUqlSp4olH9DlGu38oz9rBff19+vRJV9VIV9IE2ijSMOBDVX1NRC4FpolI29ImVtX3sbxDEhkZqb179y4xTXx
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 288x288 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import itertools\n",
|
|||
|
"\n",
|
|||
|
"# NOTE(review): this cell reads dslow, allc, votes, econas, epx, ls and d, none of which\n",
|
|||
|
"# are defined here -- earlier cells must have been run first.\n",
|
|||
|
"\n",
|
|||
|
"# The plot treats 196 minibatches as one epoch (bottom axis: minibatches, top axis: epochs).\n",
|
|||
|
"MINIBATCHES_PER_EPOCH = 196\n",
|
|||
|
"\n",
|
|||
|
"# Per-dataset (dx, dy) offsets that place each proxy's text label next to its marker.\n",
|
|||
|
"sh={\n",
|
|||
|
"    'cifar10':{\n",
|
|||
|
"        'jacob_cov':(40,-.02),\n",
|
|||
|
"        'synflow':(40,.015),\n",
|
|||
|
"        'snip':(10,.02),\n",
|
|||
|
"        'grad_norm':(40,-.025),\n",
|
|||
|
"        'grasp':(40,.0),\n",
|
|||
|
"        'vote':(60,.0)\n",
|
|||
|
"    },\n",
|
|||
|
"    'cifar100':{\n",
|
|||
|
"        'jacob_cov':(40,.0),\n",
|
|||
|
"        'synflow':(40,.0),\n",
|
|||
|
"        'snip':(35,-.02),\n",
|
|||
|
"        'grad_norm':(35,.02),\n",
|
|||
|
"        'grasp':(35,.0),\n",
|
|||
|
"        'vote':(60,.0)\n",
|
|||
|
"    },\n",
|
|||
|
"    'ImageNet16-120':{\n",
|
|||
|
"        'jacob_cov':(40,.0),\n",
|
|||
|
"        'synflow':(40,.0),\n",
|
|||
|
"        'snip':(40,.02),\n",
|
|||
|
"        'grad_norm':(40,-.01),\n",
|
|||
|
"        'grasp':(40,-.03),\n",
|
|||
|
"        'vote':(60,.0)\n",
|
|||
|
"    },\n",
|
|||
|
"}\n",
|
|||
|
"# Marker glyph used for each proxy (and for the combined 'vote' point).\n",
|
|||
|
"markers = {'synflow':'*','jacob_cov':'x','snip':'o','grad_norm':'s','fisher':'+','grasp':'d', 'vote':'P'}\n",
|
|||
|
"\n",
|
|||
|
"# One figure per dataset: the `slow` curve plus one marker per proxy.\n",
|
|||
|
"for ds,slow in dslow.items():\n",
|
|||
|
"    plt.figure(figsize=(4,4))\n",
|
|||
|
"    # 41 x positions, one epoch apart; plt.plot needs `slow` to have the same length.\n",
|
|||
|
"    x=range(0,MINIBATCHES_PER_EPOCH*41,MINIBATCHES_PER_EPOCH)\n",
|
|||
|
"    plt.plot(x,slow)\n",
|
|||
|
"    # allc is indexed with the upper-cased name for the cifar datasets.\n",
|
|||
|
"    for k,v in allc[ds.upper() if 'cifar' in ds else ds].items():\n",
|
|||
|
"        # Skip proxies that would fall below the y-range shown (ylim starts at 0.4).\n",
|
|||
|
"        if v < 0.4:\n",
|
|||
|
"            continue\n",
|
|||
|
"        # sh has no entry for every marker key (e.g. 'fisher'); fall back to a default\n",
|
|||
|
"        # offset instead of raising KeyError if such a proxy ever reaches 0.4.\n",
|
|||
|
"        dx,dy = sh[ds].get(k,(40,.0))\n",
|
|||
|
"        plt.scatter(1,v, marker=markers[k], s=100)\n",
|
|||
|
"        plt.text(1+dx,v+dy,k,horizontalalignment='left')\n",
|
|||
|
"\n",
|
|||
|
"    k='vote'\n",
|
|||
|
"    v = votes[ds]\n",
|
|||
|
"    # NOTE(review): the marker is drawn at x=3 while its label offset is measured from x=1\n",
|
|||
|
"    # -- confirm this is intended.\n",
|
|||
|
"    plt.scatter(3,v, marker=markers['vote'], s=100)\n",
|
|||
|
"    plt.text(1+sh[ds][k][0],v+sh[ds][k][1],k,horizontalalignment='left')\n",
|
|||
|
"\n",
|
|||
|
"    # The econas reference curve and its two highlighted points are only drawn for cifar10.\n",
|
|||
|
"    if ds == 'cifar10':\n",
|
|||
|
"        x2 = [c/3.3 for c in x]\n",
|
|||
|
"        plt.plot(x2,econas[3][0:epx], label='econas $r_{16}c_8$', linestyle=ls[4], color='purple', linewidth=1)\n",
|
|||
|
"        p=15\n",
|
|||
|
"        # NOTE(review): `d` comes from an earlier cell; the curve above is scaled by 3.3 -- confirm they agree.\n",
|
|||
|
"        plt.scatter(p*MINIBATCHES_PER_EPOCH/d, econas[3][p], marker='o', color='purple', s=60)\n",
|
|||
|
"        plt.annotate('econas+',(p*MINIBATCHES_PER_EPOCH/d+30, econas[3][p]-0.02), horizontalalignment='left', color='purple')\n",
|
|||
|
"        p=20\n",
|
|||
|
"        plt.scatter(p*MINIBATCHES_PER_EPOCH/4, econas[0][p], marker='p', color='orange', s=80)\n",
|
|||
|
"        plt.annotate('econas',(p*MINIBATCHES_PER_EPOCH/4+50, econas[0][p]-0.005), horizontalalignment='left', color='chocolate')\n",
|
|||
|
"\n",
|
|||
|
"    #plt.legend()\n",
|
|||
|
"    plt.grid()\n",
|
|||
|
"\n",
|
|||
|
"    ax1 = plt.gca()\n",
|
|||
|
"    ax1.set_xlim(-100,MINIBATCHES_PER_EPOCH*11)\n",
|
|||
|
"    ax1.set_ylim(0.4,0.85)\n",
|
|||
|
"    ax2 = ax1.twiny()\n",
|
|||
|
"    ax2.set_xticks(range(0,12))\n",
|
|||
|
"    # The twin axis has its own x-range: derive it from ax1 so that epoch e lines up\n",
|
|||
|
"    # exactly with e*MINIBATCHES_PER_EPOCH on the minibatch axis.\n",
|
|||
|
"    left,right = ax1.get_xlim()\n",
|
|||
|
"    ax2.set_xlim(left/MINIBATCHES_PER_EPOCH,right/MINIBATCHES_PER_EPOCH)\n",
|
|||
|
"    ax1.set_xlabel(\"Evaluation Cost (number of minibatches)\")\n",
|
|||
|
"    ax2.set_xlabel(\"Epochs\")\n",
|
|||
|
"\n",
|
|||
|
"    ax1.set_ylabel('Spearman $\\\\rho$')\n",
|
|||
|
"    #plt.xscale('log')\n",
|
|||
|
"    print(ds)\n",
|
|||
|
"    plt.tight_layout()\n",
|
|||
|
"    # Output file: nb2<dataset>.pdf for the cifar datasets, nb2im16.pdf otherwise.\n",
|
|||
|
"    plt.savefig('nb2'+(ds if 'cifar' in ds else 'im16')+'.pdf')\n",
|
|||
|
"    plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.7.6"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|