change 201 acc to more info func
This commit is contained in:
parent
aead4df707
commit
ff85bba9cd
@ -55,7 +55,8 @@ if __name__ == "__main__":
|
||||
|
||||
results = []
|
||||
|
||||
nasbench_len = 15625
|
||||
# nasbench_len = 15625
|
||||
nasbench_len = 15
|
||||
|
||||
# for index, i in arch_info.iterrows():
|
||||
for i in range(nasbench_len):
|
||||
@ -64,10 +65,12 @@ if __name__ == "__main__":
|
||||
|
||||
config = api.get_net_config(i, 'cifar10')
|
||||
network = get_cell_based_tiny_net(config)
|
||||
nas_results = api.query_by_index(i, 'cifar10')
|
||||
acc = nas_results[111].get_eval('ori-test')
|
||||
# nas_results = api.query_by_index(i, 'cifar10')
|
||||
# acc = nas_results[111].get_eval('ori-test')
|
||||
nas_results = api.get_more_info(i, 'cifar10', None, hp=200, is_random=False)
|
||||
acc = nas_results['test-accuracy']
|
||||
|
||||
print(type(network))
|
||||
# print(type(network))
|
||||
start_time = time.time()
|
||||
|
||||
# network = Network(3, 10, 1, eval(i.genotype))
|
||||
@ -96,10 +99,10 @@ if __name__ == "__main__":
|
||||
results.append([np.mean(swap_score), acc, i])
|
||||
|
||||
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
|
||||
results.to_csv('swap_results.csv', float_format='%.4f', index=False)
|
||||
|
||||
print()
|
||||
print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
|
||||
results.to_csv('swap_results.csv', float_format='%.4f', index=False)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user