Remove call to .eval(), update results in README, ignore .t7 files
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,2 +1,3 @@ | ||||
| *.pth | ||||
| __pycache__ | ||||
| *.t7 | ||||
|   | ||||
							
								
								
									
										11
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								README.md
									
									
									
									
									
								
							| @@ -24,15 +24,16 @@ Each command will finish by calling `process_results.py`, which will print a tab | ||||
|  | ||||
| | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||
| |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||
| | Ours (N=10)  |           1.73435 | 89.25 +- 0.08    | 92.21 +- 0.11     | 68.53 +- 0.17     | 68.40 +- 0.14      | 40.42 +- 1.15          | 40.66 +- 0.97           |        | ||||
| | Ours (N=100) |          17.4139  | 89.18 +- 0.29    | 91.76 +- 1.28     | 67.17 +- 2.79     | 67.27 +- 2.68      | 40.84 +- 5.36          | 41.33 +- 5.74 | ||||
| | Ours (N=10)  |              1.75 | 89.50 +- 0.51    | 92.98 +- 0.82     | 69.80 +- 2.46     | 69.86 +- 2.21      | 42.35 +- 1.19          | 42.38 +- 1.37           | | ||||
| | Ours (N=100) |             17.76 | 87.44 +- 1.45    | 92.27 +- 1.53     | 70.26 +- 1.09     | 69.86 +- 0.60      | 43.30 +- 1.62          | 43.51 +- 1.40        | ||||
|  | ||||
| `./reproduce 500` will produce the following table (which is the same as what we report in the paper): | ||||
| `./reproduce 500` will produce the following table: | ||||
|  | ||||
| | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||
| |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||
| | Ours (N=10) |            1.73435 | 88.47 +- 1.33    | 91.53 +- 1.62     | 66.49 +- 3.08     | 66.63 +- 3.14      | 38.33 +- 4.98          | 38.33 +- 5.22           | | ||||
| | Ours (N=100) |          17.4139  | 88.45 +- 1.46    | 91.61 +- 1.71     | 66.42 +- 3.27     | 66.56 +- 3.28      | 36.56 +- 6.70          | 36.37 +- 6.97 | ||||
| | Ours (N=10)  |              1.67 | 88.61 +- 1.58    | 91.58 +- 1.70     | 67.03 +- 3.01     | 67.15 +- 3.08      | 39.74 +- 4.17          | 39.76 +- 4.39           | | ||||
| | Ours (N=100) |             17.12 | 88.43 +- 1.67    | 91.24 +- 1.70     | 67.04 +- 2.91     | 67.12 +- 2.98      | 40.68 +- 3.41          | 40.67 +- 3.55           | | ||||
|  | ||||
|  | ||||
|  | ||||
| To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes [this line](https://github.com/BayesWatch/nas-without-training/blob/master/process_results.py#L51) of `process_results.py`. | ||||
|   | ||||
| @@ -8,7 +8,7 @@ from statistics import mean | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | ||||
| parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder') | ||||
| parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth', | ||||
| parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth', | ||||
|                     type=str, help='path to API') | ||||
| parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') | ||||
| parser.add_argument('--batch_size', default=256, type=int) | ||||
| @@ -116,7 +116,6 @@ for N in runs: | ||||
|  | ||||
|         network = get_cell_based_tiny_net(config)  # create the network from configuration | ||||
|         network = network.to(device) | ||||
|         network.eval() | ||||
|  | ||||
|         jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args) | ||||
|         jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user