change 201 acc to more info func
This commit is contained in:
		| @@ -55,7 +55,8 @@ if __name__ == "__main__": | ||||
|  | ||||
|     results = [] | ||||
|  | ||||
|     nasbench_len = 15625 | ||||
|     # nasbench_len = 15625 | ||||
|     nasbench_len = 15 | ||||
|      | ||||
|     # for index, i in arch_info.iterrows(): | ||||
|     for i in range(nasbench_len): | ||||
| @@ -64,10 +65,12 @@ if __name__ == "__main__": | ||||
|  | ||||
|         config = api.get_net_config(i, 'cifar10') | ||||
|         network = get_cell_based_tiny_net(config) | ||||
|         nas_results = api.query_by_index(i, 'cifar10') | ||||
|         acc = nas_results[111].get_eval('ori-test') | ||||
|         # nas_results = api.query_by_index(i, 'cifar10') | ||||
|         # acc = nas_results[111].get_eval('ori-test') | ||||
|         nas_results = api.get_more_info(i, 'cifar10', None, hp=200, is_random=False) | ||||
|         acc = nas_results['test-accuracy'] | ||||
|  | ||||
|         print(type(network)) | ||||
|         # print(type(network)) | ||||
|         start_time = time.time() | ||||
|  | ||||
|         # network = Network(3, 10, 1, eval(i.genotype)) | ||||
| @@ -96,10 +99,10 @@ if __name__ == "__main__": | ||||
|         results.append([np.mean(swap_score), acc, i]) | ||||
|  | ||||
|     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) | ||||
|     results.to_csv('swap_results.csv', float_format='%.4f', index=False) | ||||
|  | ||||
|     print()     | ||||
|     print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}') | ||||
|     results.to_csv('swap_results.csv', float_format='%.4f', index=False) | ||||
|  | ||||
|      | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user