Update visualization codes for NATS-Bench
This commit is contained in:
		| @@ -33,7 +33,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | ||||
|   alg2name['REA'] = 'R-EA-SS3' | ||||
|   alg2name['REINFORCE'] = 'REINFORCE-0.01' | ||||
|   alg2name['RANDOM'] = 'RANDOM' | ||||
|   # alg2name['BOHB'] = 'BOHB' | ||||
|   alg2name['BOHB'] = 'BOHB' | ||||
|   for alg, name in alg2name.items(): | ||||
|     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') | ||||
|     assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) | ||||
| @@ -59,7 +59,26 @@ def query_performance(api, data, dataset, ticket): | ||||
|     accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy'] | ||||
|     interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b | ||||
|     results.append(interplate) | ||||
|   return sum(results) / len(results) | ||||
|   # return sum(results) / len(results) | ||||
|   return np.mean(results), np.std(results) | ||||
|  | ||||
|  | ||||
| def show_valid_test(api, data, dataset): | ||||
|   valid_accs, test_accs, is_size_space = [], [], api.search_space_name == 'size' | ||||
|   for i, info in data.items(): | ||||
|     time, arch = info['time_w_arch'][-1] | ||||
|     if dataset == 'cifar10': | ||||
|       xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|       test_accs.append(xinfo['test-accuracy']) | ||||
|       xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False) | ||||
|       valid_accs.append(xinfo['valid-accuracy']) | ||||
|     else: | ||||
|       xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|       valid_accs.append(xinfo['valid-accuracy']) | ||||
|       test_accs.append(xinfo['test-accuracy']) | ||||
|   valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs)) | ||||
|   test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs)) | ||||
|   return valid_str, test_str | ||||
|  | ||||
|  | ||||
| y_min_s = {('cifar10', 'tss'): 90, | ||||
| @@ -69,11 +88,11 @@ y_min_s = {('cifar10', 'tss'): 90, | ||||
|            ('ImageNet16-120', 'tss'): 36, | ||||
|            ('ImageNet16-120', 'sss'): 40} | ||||
|  | ||||
| y_max_s = {('cifar10', 'tss'): 94.5, | ||||
| y_max_s = {('cifar10', 'tss'): 94.3, | ||||
|            ('cifar10', 'sss'): 93.3, | ||||
|            ('cifar100', 'tss'): 72, | ||||
|            ('cifar100', 'sss'): 70, | ||||
|            ('ImageNet16-120', 'tss'): 44, | ||||
|            ('cifar100', 'tss'): 72.5, | ||||
|            ('cifar100', 'sss'): 70.5, | ||||
|            ('ImageNet16-120', 'tss'): 46, | ||||
|            ('ImageNet16-120', 'sss'): 46} | ||||
|  | ||||
| x_axis_s = {('cifar10', 'tss'): 200, | ||||
| @@ -87,6 +106,7 @@ name2label = {'cifar10': 'CIFAR-10', | ||||
|               'cifar100': 'CIFAR-100', | ||||
|               'ImageNet16-120': 'ImageNet-16-120'} | ||||
|  | ||||
|  | ||||
| def visualize_curve(api, vis_save_dir, search_space): | ||||
|   vis_save_dir = vis_save_dir.resolve() | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
| @@ -106,11 +126,13 @@ def visualize_curve(api, vis_save_dir, search_space): | ||||
|     ax.set_ylim(y_min_s[(xdataset, search_space)], | ||||
|                 y_max_s[(xdataset, search_space)]) | ||||
|     for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|       print('{:} plot alg : {:}'.format(time_string(), alg)) | ||||
|       accuracies = [] | ||||
|       for ticket in time_tickets: | ||||
|         accuracy = query_performance(api, data, xdataset, ticket) | ||||
|         accuracy, accuracy_std = query_performance(api, data, xdataset, ticket) | ||||
|         accuracies.append(accuracy) | ||||
|       valid_str, test_str = show_valid_test(api, data, xdataset) | ||||
|       # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) | ||||
|       print('{:} plot alg : {:10s}  | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str)) | ||||
|       alg2accuracies[alg] = accuracies | ||||
|       ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg)) | ||||
|       ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user