Add int search space
This commit is contained in:
		| @@ -30,15 +30,28 @@ from log_utils import time_string | ||||
| def get_valid_test_acc(api, arch, dataset): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     if dataset == "cifar10": | ||||
|         xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|         xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, | ||||
|             dataset="cifar10-valid", | ||||
|             hp=90 if is_size_space else 200, | ||||
|             is_random=False, | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|     else: | ||||
|         xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|     return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) | ||||
|     return ( | ||||
|         valid_acc, | ||||
|         test_acc, | ||||
|         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def compute_kendalltau(vectori, vectorj): | ||||
| @@ -61,9 +74,17 @@ if __name__ == "__main__": | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
| @@ -77,9 +98,17 @@ if __name__ == "__main__": | ||||
|         scores_1.append(valid_acc) | ||||
|         scores_2.append(test_acc) | ||||
|     correlation = compute_kendalltau(scores_1, scores_2) | ||||
|     print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation)) | ||||
|     print( | ||||
|         "The kendall tau correlation of {:} samples : {:}".format( | ||||
|             len(indexes), correlation | ||||
|         ) | ||||
|     ) | ||||
|     correlation = compute_spearmanr(scores_1, scores_2) | ||||
|     print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation)) | ||||
|     print( | ||||
|         "The spearmanr correlation of {:} samples : {:}".format( | ||||
|             len(indexes), correlation | ||||
|         ) | ||||
|     ) | ||||
|     # scores_1 = ['{:.2f}'.format(x) for x in scores_1] | ||||
|     # scores_2 = ['{:.2f}'.format(x) for x in scores_2] | ||||
|     # print(', '.join(scores_1)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user