Final check
README.md
							| @@ -10,22 +10,25 @@ To reproduce our results: | |||||||
| conda env create -f environment.yml | conda env create -f environment.yml | ||||||
|  |  | ||||||
| conda activate nas-wot | conda activate nas-wot | ||||||
| ./reproduce.sh | ./reproduce.sh 3 # average accuracy over 3 runs | ||||||
|  | ./reproduce.sh 500 # average accuracy over 500 runs (this will take longer) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| For a quick run you can set `--n_runs 3` to get results after 3 runs: | Each command will finish by calling `process_results.py`, which will print a table. `./reproduce.sh 3` should print the following table: | ||||||
|  |  | ||||||
| | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||||
| |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||||
| | Ours (N=10)  |           1.73435 | 88.99 $\pm$ 0.24 | 92.42 $\pm$ 0.33  | 67.86 $\pm$ 0.49  | 67.54 $\pm$ 0.75   | 41.16 $\pm$ 2.31       | 40.98 $\pm$ 2.72        | | | Ours (N=10)  |           1.73435 | 88.99 +- 0.24    | 92.42 +- 0.33     | 67.86 +- 0.49     | 67.54 +- 0.75      | 41.16 +- 2.31          | 40.98 +- 2.72           | | ||||||
| | Ours (N=100) |          17.4139  | 89.18 $\pm$ 0.29 | 91.76 $\pm$ 1.28  | 67.17 $\pm$ 2.79  | 67.27 $\pm$ 2.68   | 40.84 $\pm$ 5.36       | 41.33 $\pm$ 5.74 | | Ours (N=100) |          17.4139  | 89.18 +- 0.29    | 91.76 +- 1.28     | 67.17 +- 2.79     | 67.27 +- 2.68      | 40.84 +- 5.36          | 41.33 +- 5.74 | ||||||
|  |  | ||||||
| The size of `N` is set with `--n_samples 10`. To produce the results in the paper, set `--n_runs 500`: | `./reproduce.sh 500` will produce the following table (which is the same as what we report in the paper): | ||||||
|  |  | ||||||
| | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||||
| |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||||
| | Ours (N=10)  |           1.73435 | 89.25 $\pm$ 0.08 | 92.21 $\pm$ 0.11  | 68.53 $\pm$ 0.17  | 68.40 $\pm$ 0.14   | 40.42 $\pm$ 1.15       | 40.66 $\pm$ 0.97        | | | Ours (N=10)  |           1.73435 | 89.25 +- 0.08    | 92.21 +- 0.11     | 68.53 +- 0.17     | 68.40 +- 0.14      | 40.42 +- 1.15          | 40.66 +- 0.97           | | ||||||
| | Ours (N=100) |          17.4139  | 88.45 $\pm$ 1.46 | 91.61 $\pm$ 1.71  | 66.42 $\pm$ 3.27  | 66.56 $\pm$ 3.28   | 36.56 $\pm$ 6.70       | 36.37 $\pm$ 6.97 | | Ours (N=100) |          17.4139  | 88.45 +- 1.46    | 91.61 +- 1.71     | 66.42 +- 3.27     | 66.56 +- 3.28      | 36.56 +- 6.70          | 36.37 +- 6.97 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes on line 51 of `process_results.py`. | ||||||
|  |  | ||||||
| The code is licensed under the MIT licence. | The code is licensed under the MIT licence. | ||||||
|   | |||||||
| @@ -51,5 +51,4 @@ dependencies: | |||||||
|   - pip: |   - pip: | ||||||
|     - argparse==1.4.0 |     - argparse==1.4.0 | ||||||
|     - nas-bench-201==1.3 |     - nas-bench-201==1.3 | ||||||
| prefix: /home/jturner/miniconda3/envs/nas-wot |     - tabulate==0.8.7 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -63,25 +63,25 @@ for n_samples in [10, 100]: | |||||||
|         full_scores = torch.load(filename) |         full_scores = torch.load(filename) | ||||||
|         if dataset == 'CIFAR-10 (test)': |         if dataset == 'CIFAR-10 (test)': | ||||||
|             time = median(full_scores['times']) |             time = median(full_scores['times']) | ||||||
|             dataset_top1s['Search time (s)'] = time |             time = f"{time:.2f}" | ||||||
|         accs = [] |         accs = [] | ||||||
|         for n in range(args.n_runs): |         for n in range(args.n_runs): | ||||||
|             acc = full_scores[acc_type][n] |             acc = full_scores[acc_type][n] | ||||||
|             accs.append(acc) |             accs.append(acc) | ||||||
|         dataset_top1s[dataset] = accs |         dataset_top1s[dataset] = accs | ||||||
|  |  | ||||||
|     cifar10_val  = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-10 (val)']):.2f}" |     cifar10_val  = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} +- {std(dataset_top1s['CIFAR-10 (val)']):.2f}" | ||||||
|     cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-10 (test)']):.2f}" |     cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} +- {std(dataset_top1s['CIFAR-10 (test)']):.2f}" | ||||||
|  |  | ||||||
|     cifar100_val  = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-100 (val)']):.2f}" |     cifar100_val  = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} +- {std(dataset_top1s['CIFAR-100 (val)']):.2f}" | ||||||
|     cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-100 (test)']):.2f}" |     cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} +- {std(dataset_top1s['CIFAR-100 (test)']):.2f}" | ||||||
|  |  | ||||||
|     imagenet_val  = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} $\pm$ {std(dataset_top1s['ImageNet16-120 (val)']):.2f}" |     imagenet_val  = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (val)']):.2f}" | ||||||
|     imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} $\pm$ {std(dataset_top1s['ImageNet16-120 (test)']):.2f}" |     imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (test)']):.2f}" | ||||||
|  |  | ||||||
|     df.append([method, time, cifar10_val, cifar10_test, cifar100_val, cifar100_test, imagenet_val, imagenet_test]) |     df.append([method, time, cifar10_val, cifar10_test, cifar100_val, cifar100_test, imagenet_val, imagenet_test]) | ||||||
|  |  | ||||||
|  |  | ||||||
| df = pd.DataFrame(df, columns=['Method','Search time (s)','CIFAR-10 (val)','CIFAR-10 (test)','CIFAR-100 (val)','CIFAR-100 (test)','ImageNet16-120 (val)','ImageNet16-120 (test)' ]) | df = pd.DataFrame(df, columns=['Method','Search time (s)','CIFAR-10 (val)','CIFAR-10 (test)','CIFAR-100 (val)','CIFAR-100 (test)','ImageNet16-120 (val)','ImageNet16-120 (test)' ]) | ||||||
| df.round(2) |  | ||||||
| print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe")) | print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe")) | ||||||
|   | |||||||
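The six near-identical f-strings in the hunk above could be collapsed into a single helper. The following is a minimal sketch, not code from this commit: `format_mean_std` is a hypothetical name, and it assumes the `mean` and `std` used in the diff behave like `statistics.mean` and `statistics.pstdev` (the hunk does not show where they are imported from, so the standard-deviation convention may differ).

```python
from statistics import mean, pstdev


def format_mean_std(values, digits=2):
    """Format a list of accuracies as 'mean +- std'.

    Hypothetical helper; assumes a population standard deviation,
    which may not match the `std` used in process_results.py.
    """
    return f"{mean(values):.{digits}f} +- {pstdev(values):.{digits}f}"


# Illustrative accuracies only, not results from the paper.
example_top1s = {"CIFAR-10 (val)": [88.0, 89.0, 90.0]}
row = [format_mean_std(accs) for accs in example_top1s.values()]
print(row)
```

With a helper like this, each table cell becomes one call per dataset key, so changing the separator (as this commit does, from `$\pm$` to `+-`) would touch a single line.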
							
								
								
									
reproduce.sh
							| @@ -1,11 +1,13 @@ | |||||||
| #python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 10 | #!/bin/bash | ||||||
| #python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 10 |  | ||||||
| #python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3 --n_samples 10 |  | ||||||
| #python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3 --n_samples 10 |  | ||||||
|  |  | ||||||
| python search.py --dataset cifar10 --data_loc '../datasets/cifar10'            --n_runs 3 --n_samples 100 | python search.py --dataset cifar10 --data_loc '../datasets/cifar10'            --n_runs $1 --n_samples 10 | ||||||
| python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 100 | python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10 | ||||||
| python search.py --dataset cifar100 --data_loc '../datasets/cifar100'          --n_runs 3 --n_samples 100 | python search.py --dataset cifar100 --data_loc '../datasets/cifar100'          --n_runs $1 --n_samples 10 | ||||||
| python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16'  --n_runs 3 --n_samples 100 | python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16'  --n_runs $1 --n_samples 10 | ||||||
|  |  | ||||||
| python process_results.py --n_runs 3 | python search.py --dataset cifar10 --data_loc '../datasets/cifar10'            --n_runs $1 --n_samples 100 | ||||||
|  | python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100 | ||||||
|  | python search.py --dataset cifar100 --data_loc '../datasets/cifar100'          --n_runs $1 --n_samples 100 | ||||||
|  | python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16'  --n_runs $1 --n_samples 100 | ||||||
|  |  | ||||||
|  | python process_results.py --n_runs $1 | ||||||
|   | |||||||
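The new `reproduce.sh` repeats the same four `search.py` calls for each sample size. As a rough illustration of that structure, here is a sketch that builds the same command lines in Python. `build_commands` and `DATASET_ARGS` are hypothetical names, not part of the repository; the flags are copied from the new side of the diff, and the script only prints the commands rather than running them.

```python
import shlex

# Dataset flags copied from the new side of the reproduce.sh diff.
DATASET_ARGS = [
    ["--dataset", "cifar10", "--data_loc", "../datasets/cifar10"],
    ["--dataset", "cifar10", "--trainval", "--data_loc", "../datasets/cifar10"],
    ["--dataset", "cifar100", "--data_loc", "../datasets/cifar100"],
    ["--dataset", "ImageNet16-120", "--data_loc", "../datasets/ImageNet16"],
]


def build_commands(n_runs, sample_sizes=(10, 100)):
    """Return the argv lists that reproduce.sh would run for a given n_runs."""
    if n_runs < 1:
        raise ValueError("n_runs must be a positive integer")
    commands = []
    for n_samples in sample_sizes:
        for dataset_args in DATASET_ARGS:
            commands.append(
                ["python", "search.py", *dataset_args,
                 "--n_runs", str(n_runs), "--n_samples", str(n_samples)]
            )
    # The results table is produced once, after all searches have finished.
    commands.append(["python", "process_results.py", "--n_runs", str(n_runs)])
    return commands


if __name__ == "__main__":
    for argv in build_commands(3):
        # Print only; pass argv to subprocess.run to actually execute.
        print(shlex.join(argv))
```

Trying other sample sizes in this sketch means changing `sample_sizes`; as the README notes, the list of sample sizes in `process_results.py` would need the same update.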