Compare commits
	
		
			3 Commits
		
	
	
		
			aead4df707
			...
			aa4b38a0cc
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| aa4b38a0cc | |||
| f72990a675 | |||
| ff85bba9cd | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,2 +1,3 @@ | |||||||
| __pycache__/ | __pycache__/ | ||||||
| datasets/ | datasets/ | ||||||
|  | swap_results.csv | ||||||
| @@ -55,19 +55,22 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     results = [] |     results = [] | ||||||
|  |  | ||||||
|  |     # nasbench_len = 15625 | ||||||
|     nasbench_len = 15625 |     nasbench_len = 15625 | ||||||
|      |      | ||||||
|     # for index, i in arch_info.iterrows(): |     # for index, i in arch_info.iterrows(): | ||||||
|     for i in range(nasbench_len): |     for ind in range(nasbench_len): | ||||||
|         # print(f'Evaluating network: {index}') |         # print(f'Evaluating network: {index}') | ||||||
|         print(f'Evaluating network: {i}') |         print(f'Evaluating network: {ind}') | ||||||
|  |  | ||||||
|         config = api.get_net_config(i, 'cifar10') |         config = api.get_net_config(ind, 'cifar10') | ||||||
|         network = get_cell_based_tiny_net(config) |         network = get_cell_based_tiny_net(config) | ||||||
|         nas_results = api.query_by_index(i, 'cifar10') |         # nas_results = api.query_by_index(i, 'cifar10') | ||||||
|         acc = nas_results[111].get_eval('ori-test') |         # acc = nas_results[111].get_eval('ori-test') | ||||||
|  |         nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False) | ||||||
|  |         acc = nas_results['test-accuracy'] | ||||||
|  |  | ||||||
|         print(type(network)) |         # print(type(network)) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|         # network = Network(3, 10, 1, eval(i.genotype)) |         # network = Network(3, 10, 1, eval(i.genotype)) | ||||||
| @@ -93,13 +96,13 @@ if __name__ == "__main__": | |||||||
|         print(f'Average SWAP score: {np.mean(swap_score)}') |         print(f'Average SWAP score: {np.mean(swap_score)}') | ||||||
|         print(f'Elapsed time: {end_time - start_time:.2f} seconds') |         print(f'Elapsed time: {end_time - start_time:.2f} seconds') | ||||||
|  |  | ||||||
|         results.append([np.mean(swap_score), acc, i]) |         results.append([np.mean(swap_score), acc, ind]) | ||||||
|  |  | ||||||
|     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) |     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) | ||||||
|  |     results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) | ||||||
|  |  | ||||||
|     print()     |     print()     | ||||||
|     print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}') |     print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}') | ||||||
|     results.to_csv('swap_results.csv', float_format='%.4f', index=False) |  | ||||||
|  |  | ||||||
|      |      | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user