adjust threshold for cifar100
This commit is contained in:
		| @@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth') | |||||||
| parser = argparse.ArgumentParser(description='Process some integers.') | parser = argparse.ArgumentParser(description='Process some integers.') | ||||||
|  |  | ||||||
| parser.add_argument('--file_path', type=str, default='211035.txt',) | parser.add_argument('--file_path', type=str, default='211035.txt',) | ||||||
|  | parser.add_argument('--datasets', type=str, default='cifar10',) | ||||||
| args = parser.parse_args() | args = parser.parse_args() | ||||||
|  |  | ||||||
| def process_graph_data(text): | def process_graph_data(text): | ||||||
| @@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes): | |||||||
|     return arch_str |     return arch_str | ||||||
|  |  | ||||||
| filename = args.file_path | filename = args.file_path | ||||||
|  | datasets_name = args.datasets | ||||||
|  |  | ||||||
| with open('./output_graphs/' + filename, 'r') as f: | with open('./output_graphs/' + filename, 'r') as f: | ||||||
|     texts = f.read() |     texts = f.read() | ||||||
| @@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f: | |||||||
|     valid = 0 |     valid = 0 | ||||||
|     not_valid = 0 |     not_valid = 0 | ||||||
|     scores = [] |     scores = [] | ||||||
|     dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0} |  | ||||||
|  |     # Map each dataset name to its accuracy-threshold buckets for the distribution dict | ||||||
|  |     thresholds = { | ||||||
|  |         'cifar10': [90, 91, 92, 93, 94], | ||||||
|  |         'cifar100': [68,69,70, 71, 72, 73] | ||||||
|  |     } | ||||||
|  |     dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} | ||||||
|  |     dist[f'>{thresholds[datasets_name][-1]}'] = 0 | ||||||
|  |  | ||||||
|     for i in range(len(df)): |     for i in range(len(df)): | ||||||
|         nodes = df['nodes'][i] |         nodes = df['nodes'][i] | ||||||
|         edges = df['edges'][i] |         edges = df['edges'][i] | ||||||
| @@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f: | |||||||
|             valid += 1 |             valid += 1 | ||||||
|             arch_str = nodes_to_arch_str(nodes) |             arch_str = nodes_to_arch_str(nodes) | ||||||
|             index = api.query_index_by_arch(arch_str) |             index = api.query_index_by_arch(arch_str) | ||||||
|             # results = api.query_by_index(index, 'cifar10', hp='200') |             res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) | ||||||
|             # print(results) |  | ||||||
|             # result = results[888].get_eval('ori-test') |  | ||||||
|             res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False) |  | ||||||
|             acc = res['test-accuracy'] |             acc = res['test-accuracy'] | ||||||
|             scores.append((index, acc)) |             scores.append((index, acc)) | ||||||
|             if acc < 90: |  | ||||||
|                 dist['<90'] += 1 |             # Update the distribution counts according to the thresholds | ||||||
|             elif acc < 91 and acc >= 90: |             updated = False | ||||||
|                 dist['<91'] += 1 |             for threshold in thresholds[datasets_name]: | ||||||
|             elif acc < 92 and acc >= 91: |                 if acc < threshold: | ||||||
|                 dist['<92'] += 1 |                     dist[f'<{threshold}'] += 1 | ||||||
|             elif acc < 93 and acc >= 92:  |                     updated = True | ||||||
|                 dist['<93'] += 1 |                     break | ||||||
|             elif acc < 94 and acc >= 93: |             if not updated: | ||||||
|                 dist['<94'] += 1 |                 dist[f'>{thresholds[datasets_name][-1]}'] += 1 | ||||||
|             else:     |  | ||||||
|                 dist['>94'] += 1 |  | ||||||
|         else: |         else: | ||||||
|             not_valid += 1 |             not_valid += 1 | ||||||
|     with open('./output_graphs/' + filename + '.json', 'w') as f: |  | ||||||
|  |     with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: | ||||||
|         json.dump(scores, f) |         json.dump(scores, f) | ||||||
|  |  | ||||||
|     print(scores) |     print(scores) | ||||||
|     print(valid, not_valid) |     print(valid, not_valid) | ||||||
|     print(dist) |     print(dist) | ||||||
|     print("mean: ", np.mean([x[1] for x in scores])) |     print("mean: ", np.mean([x[1] for x in scores])) | ||||||
|     print("max: ", np.max([x[1] for x in scores])) |     print("max: ", np.max([x[1] for x in scores])) | ||||||
|     print("min: ", np.min([x[1] for x in scores])) |     print("min: ", np.min([x[1] for x in scores])) | ||||||
|  |  | ||||||
|          |          | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user