With table generator
This commit is contained in:
		| @@ -13,4 +13,10 @@ conda activate nas-wot | ||||
| ./reproduce.sh | ||||
| ``` | ||||
|  | ||||
| Will produce the following table: | ||||
|  | ||||
| | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||
| |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||
| | Ours (N=100) |             18.35 | 89.18 +- 0.29    | 91.76 +- 1.28     | 67.17 +- 2.79     | 67.27 +- 2.68      | 40.84 +- 5.36          | 41.33 +- 5.74           | | ||||
|  | ||||
| The code is licensed under the MIT licence. | ||||
|   | ||||
| @@ -30,10 +30,13 @@ dependencies: | ||||
|   - numpy-base=1.18.1=py38hde5b4d6_1 | ||||
|   - olefile=0.46=py_0 | ||||
|   - openssl=1.1.1g=h7b6447c_0 | ||||
|   - pandas=1.0.3=py38h0573a6f_0 | ||||
|   - pillow=7.1.2=py38hb39fc2d_0 | ||||
|   - pip=20.0.2=py38_3 | ||||
|   - python=3.8.3=hcff3b4d_0 | ||||
|   - python-dateutil=2.8.1=py_0 | ||||
|   - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0 | ||||
|   - pytz=2020.1=py_0 | ||||
|   - readline=8.0=h7b6447c_0 | ||||
|   - setuptools=46.4.0=py38_0 | ||||
|   - six=1.14.0=py38_0 | ||||
| @@ -48,3 +51,5 @@ dependencies: | ||||
|   - pip: | ||||
|     - argparse==1.4.0 | ||||
|     - nas-bench-201==1.3 | ||||
| prefix: /home/jturner/miniconda3/envs/nas-wot | ||||
|  | ||||
|   | ||||
							
								
								
									
										86
									
								
								process_results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								process_results.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| import numpy as np | ||||
| import argparse | ||||
| import os | ||||
| import random | ||||
| import pandas as pd | ||||
| from collections import OrderedDict | ||||
|  | ||||
| import tabulate | ||||
| parser = argparse.ArgumentParser(description='Produce tables') | ||||
| parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder') | ||||
| parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') | ||||
|  | ||||
| parser.add_argument('--batch_size', default=256, type=int) | ||||
| parser.add_argument('--GPU', default='0', type=str) | ||||
|  | ||||
| parser.add_argument('--seed', default=1, type=int) | ||||
| parser.add_argument('--trainval', action='store_true') | ||||
|  | ||||
| parser.add_argument('--n_samples', default=100, type=int, help='how many samples to take') | ||||
| parser.add_argument('--n_runs', default=500, type=int) | ||||
|  | ||||
| args = parser.parse_args() | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU | ||||
|  | ||||
| from statistics import mean, median, stdev as std | ||||
|  | ||||
| import torch | ||||
|  | ||||
| torch.backends.cudnn.deterministic = True | ||||
| torch.backends.cudnn.benchmark = False | ||||
| random.seed(args.seed) | ||||
| np.random.seed(args.seed) | ||||
| torch.manual_seed(args.seed) | ||||
|  | ||||
| df = [] | ||||
|  | ||||
| datasets = OrderedDict() | ||||
|  | ||||
| datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True) | ||||
| datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False) | ||||
|  | ||||
| ### CIFAR-100 | ||||
| datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False) | ||||
| datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False) | ||||
|  | ||||
| datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False) | ||||
| datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False) | ||||
|  | ||||
|  | ||||
| dataset_top1s = OrderedDict() | ||||
| dataset_top1s['Method'] = f"Ours (N={args.n_samples})" | ||||
| dataset_top1s['Search time (s)'] = np.nan | ||||
|  | ||||
| time = 0. | ||||
|  | ||||
| for dataset, params in datasets.items(): | ||||
|     top1s = [] | ||||
|  | ||||
|     dset =  params[0] | ||||
|     acc_type = 'accs' if 'test' in params[1] else 'val_accs' | ||||
|     filename = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7" | ||||
|  | ||||
|     full_scores = torch.load(filename) | ||||
|     if dataset == 'CIFAR-10 (test)': | ||||
|         time = median(full_scores['times']) | ||||
|         dataset_top1s['Search time (s)'] = time | ||||
|     accs = [] | ||||
|     for n in range(args.n_runs): | ||||
|         acc = full_scores[acc_type][n] | ||||
|         accs.append(acc) | ||||
|     dataset_top1s[dataset] = accs | ||||
|  | ||||
| df = pd.DataFrame(dataset_top1s) | ||||
|  | ||||
| df['CIFAR-10 (val)']  = f"{mean(df['CIFAR-10 (val)']):.2f} +- {std(df['CIFAR-10 (val)']):.2f}" | ||||
| df['CIFAR-10 (test)']  = f"{mean(df['CIFAR-10 (test)']):.2f} +- {std(df['CIFAR-10 (test)']):.2f}" | ||||
|  | ||||
| df['CIFAR-100 (val)'] = f"{mean(df['CIFAR-100 (val)']):.2f} +- {std(df['CIFAR-100 (val)']):.2f}" | ||||
| df['CIFAR-100 (test)'] = f"{mean(df['CIFAR-100 (test)']):.2f} +- {std(df['CIFAR-100 (test)']):.2f}" | ||||
|  | ||||
| df['ImageNet16-120 (val)']  = f"{mean(df['ImageNet16-120 (val)']):.2f} +- {std(df['ImageNet16-120 (val)']):.2f}" | ||||
| df['ImageNet16-120 (test)'] = f"{mean(df['ImageNet16-120 (test)']):.2f} +- {std(df['ImageNet16-120 (test)']):.2f}" | ||||
|  | ||||
| df = df.round(2) | ||||
| df = df.iloc[:1] | ||||
| print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe")) | ||||
							
								
								
									
										10
									
								
								reproduce.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										10
									
								
								reproduce.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							| @@ -1,4 +1,6 @@ | ||||
| python search.py --dataset cifar10 --data_loc '../datasets/cifar10' | ||||
| python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' | ||||
| python search.py --dataset cifar100 --data_loc '../datasets/cifar100' | ||||
| python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' | ||||
| python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3 | ||||
| python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 | ||||
| python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3 | ||||
| python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3 | ||||
|  | ||||
| python process_results.py --n_runs 3 | ||||
|   | ||||
| @@ -153,5 +153,5 @@ state = {'accs': acc, | ||||
|          } | ||||
|  | ||||
| dset = args.dataset if not args.trainval else 'cifar10-valid' | ||||
| fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.mc_samples}_{args.alpha}_{args.seed}.t7" | ||||
| fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7" | ||||
| torch.save(state, fname) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user