Add int search space
This commit is contained in:
		| @@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
|     for alg, path in alg2path.items(): | ||||
|         data = torch.load(path) | ||||
|         for index, info in data.items(): | ||||
|             info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] | ||||
|             info["time_w_arch"] = [ | ||||
|                 (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) | ||||
|             ] | ||||
|             for j, arch in enumerate(info["all_archs"]): | ||||
|                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( | ||||
|                     alg, search_space, dataset, index, j | ||||
| @@ -54,15 +56,28 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
| def get_valid_test_acc(api, arch, dataset): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     if dataset == "cifar10": | ||||
|         xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|         xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, | ||||
|             dataset="cifar10-valid", | ||||
|             hp=90 if is_size_space else 200, | ||||
|             is_random=False, | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|     else: | ||||
|         xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|     return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) | ||||
|     return ( | ||||
|         valid_acc, | ||||
|         test_acc, | ||||
|         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def show_valid_test(api, arch): | ||||
| @@ -84,8 +99,16 @@ def find_best_valid(api, dataset): | ||||
|     best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0] | ||||
|  | ||||
|     print("-" * 50 + "{:10s}".format(dataset) + "-" * 50) | ||||
|     print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index])) | ||||
|     print("Best ({:}) architecture on       test: {:}".format(best_test_index, api[best_test_index])) | ||||
|     print( | ||||
|         "Best ({:}) architecture on validation: {:}".format( | ||||
|             best_valid_index, api[best_valid_index] | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "Best ({:}) architecture on       test: {:}".format( | ||||
|             best_test_index, api[best_test_index] | ||||
|         ) | ||||
|     ) | ||||
|     _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset) | ||||
|     print("using validation ::: {:}".format(perf_str)) | ||||
|     _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset) | ||||
| @@ -130,10 +153,14 @@ def show_multi_trial(search_space): | ||||
|                 v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) | ||||
|                 valid_accs.append(v_acc) | ||||
|                 test_accs.append(t_acc) | ||||
|             valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) | ||||
|             valid_str = "{:.2f}$\pm${:.2f}".format( | ||||
|                 np.mean(valid_accs), np.std(valid_accs) | ||||
|             ) | ||||
|             test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) | ||||
|             print( | ||||
|                 "{:} plot alg : {:10s}  | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str) | ||||
|                 "{:} plot alg : {:10s}  | validation = {:} | test = {:}".format( | ||||
|                     time_string(), alg, valid_str, test_str | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     if search_space == "tss": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user