adjust threshhold for cifar100
This commit is contained in:
		| @@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth') | ||||
| parser = argparse.ArgumentParser(description='Process some integers.') | ||||
|  | ||||
| parser.add_argument('--file_path', type=str, default='211035.txt',) | ||||
| parser.add_argument('--datasets', type=str, default='cifar10',) | ||||
| args = parser.parse_args() | ||||
|  | ||||
| def process_graph_data(text): | ||||
| @@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes): | ||||
|     return arch_str | ||||
|  | ||||
| filename = args.file_path | ||||
| datasets_name = args.datasets | ||||
|  | ||||
| with open('./output_graphs/' + filename, 'r') as f: | ||||
|     texts = f.read() | ||||
| @@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f: | ||||
|     valid = 0 | ||||
|     not_valid = 0 | ||||
|     scores = [] | ||||
|     dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0} | ||||
|  | ||||
|     # 定义分类标准和分布字典的映射 | ||||
|     thresholds = { | ||||
|         'cifar10': [90, 91, 92, 93, 94], | ||||
|         'cifar100': [68,69,70, 71, 72, 73] | ||||
|     } | ||||
|     dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} | ||||
|     dist[f'>{thresholds[datasets_name][-1]}'] = 0 | ||||
|  | ||||
|     for i in range(len(df)): | ||||
|         nodes = df['nodes'][i] | ||||
|         edges = df['edges'][i] | ||||
| @@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f: | ||||
|             valid += 1 | ||||
|             arch_str = nodes_to_arch_str(nodes) | ||||
|             index = api.query_index_by_arch(arch_str) | ||||
|             # results = api.query_by_index(index, 'cifar10', hp='200') | ||||
|             # print(results) | ||||
|             # result = results[888].get_eval('ori-test') | ||||
|             res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False) | ||||
|             res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) | ||||
|             acc = res['test-accuracy'] | ||||
|             scores.append((index, acc)) | ||||
|             if acc < 90: | ||||
|                 dist['<90'] += 1 | ||||
|             elif acc < 91 and acc >= 90: | ||||
|                 dist['<91'] += 1 | ||||
|             elif acc < 92 and acc >= 91: | ||||
|                 dist['<92'] += 1 | ||||
|             elif acc < 93 and acc >= 92:  | ||||
|                 dist['<93'] += 1 | ||||
|             elif acc < 94 and acc >= 93: | ||||
|                 dist['<94'] += 1 | ||||
|             else:     | ||||
|                 dist['>94'] += 1 | ||||
|  | ||||
|             # 根据阈值更新分布 | ||||
|             updated = False | ||||
|             for threshold in thresholds[datasets_name]: | ||||
|                 if acc < threshold: | ||||
|                     dist[f'<{threshold}'] += 1 | ||||
|                     updated = True | ||||
|                     break | ||||
|             if not updated: | ||||
|                 dist[f'>{thresholds[datasets_name][-1]}'] += 1 | ||||
|         else: | ||||
|             not_valid += 1 | ||||
|     with open('./output_graphs/' + filename + '.json', 'w') as f: | ||||
|  | ||||
|     with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: | ||||
|         json.dump(scores, f) | ||||
|  | ||||
|     print(scores) | ||||
|     print(valid, not_valid) | ||||
|     print(dist) | ||||
|     print("mean: ", np.mean([x[1] for x in scores])) | ||||
|     print("max: ", np.max([x[1] for x in scores])) | ||||
|     print("min: ", np.min([x[1] for x in scores])) | ||||
|  | ||||
|          | ||||
|   | ||||
		Reference in New Issue
	
	Block a user