From ff85bba9cd17962c53d9a38ee662f6643e5b5996 Mon Sep 17 00:00:00 2001 From: Mhrooz Date: Mon, 26 Aug 2024 10:50:54 +0200 Subject: [PATCH] change 201 acc to more info func --- correlation.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/correlation.py b/correlation.py index fbaa828..956d8b6 100644 --- a/correlation.py +++ b/correlation.py @@ -55,7 +55,8 @@ if __name__ == "__main__": results = [] - nasbench_len = 15625 + # nasbench_len = 15625 + nasbench_len = 15 # for index, i in arch_info.iterrows(): for i in range(nasbench_len): @@ -64,10 +65,12 @@ if __name__ == "__main__": config = api.get_net_config(i, 'cifar10') network = get_cell_based_tiny_net(config) - nas_results = api.query_by_index(i, 'cifar10') - acc = nas_results[111].get_eval('ori-test') + # nas_results = api.query_by_index(i, 'cifar10') + # acc = nas_results[111].get_eval('ori-test') + nas_results = api.get_more_info(i, 'cifar10', None, hp=200, is_random=False) + acc = nas_results['test-accuracy'] - print(type(network)) + # print(type(network)) start_time = time.time() # network = Network(3, 10, 1, eval(i.genotype)) @@ -96,10 +99,10 @@ if __name__ == "__main__": results.append([np.mean(swap_score), acc, i]) results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) + results.to_csv('swap_results.csv', float_format='%.4f', index=False) print() print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}') - results.to_csv('swap_results.csv', float_format='%.4f', index=False)