Add int search space
This commit is contained in:
		| @@ -25,7 +25,11 @@ def check_unique_arch(meta_file): | ||||
|  | ||||
|     def get_unique_matrix(archs, consider_zero): | ||||
|         UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] | ||||
|         print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs))) | ||||
|         print( | ||||
|             "{:} create unique-string ({:}/{:}) done".format( | ||||
|                 time_string(), len(set(UniquStrs)), len(UniquStrs) | ||||
|             ) | ||||
|         ) | ||||
|         Unique2Index = dict() | ||||
|         for index, xstr in enumerate(UniquStrs): | ||||
|             if xstr not in Unique2Index: | ||||
| @@ -47,16 +51,32 @@ def check_unique_arch(meta_file): | ||||
|             unique_num += 1 | ||||
|         return sm_matrix, unique_ids, unique_num | ||||
|  | ||||
|     print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))) | ||||
|     print( | ||||
|         "There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)) | ||||
|     ) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) | ||||
|     print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num)) | ||||
|     print( | ||||
|         "{:} There are {:} unique architectures (considering nothing).".format( | ||||
|             time_string(), unique_num | ||||
|         ) | ||||
|     ) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) | ||||
|     print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num)) | ||||
|     print( | ||||
|         "{:} There are {:} unique architectures (not considering zero).".format( | ||||
|             time_string(), unique_num | ||||
|         ) | ||||
|     ) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True) | ||||
|     print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num)) | ||||
|     print( | ||||
|         "{:} There are {:} unique architectures (considering zero).".format( | ||||
|             time_string(), unique_num | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): | ||||
| def check_cor_for_bandit( | ||||
|     meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False | ||||
| ): | ||||
|     if isinstance(meta_file, API): | ||||
|         api = meta_file | ||||
|     else: | ||||
| @@ -69,7 +89,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | ||||
|     imagenet_test = [] | ||||
|     imagenet_valid = [] | ||||
|     for idx, arch in enumerate(api): | ||||
|         results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand) | ||||
|         results = api.get_more_info( | ||||
|             idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand | ||||
|         ) | ||||
|         cifar10_currs.append(results["valid-accuracy"]) | ||||
|         # --->>>>> | ||||
|         results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand) | ||||
| @@ -89,13 +111,23 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | ||||
|     cors = [] | ||||
|     for basestr, xlist in zip( | ||||
|         ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"], | ||||
|         [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test], | ||||
|         [ | ||||
|             cifar10_valid, | ||||
|             cifar10_test, | ||||
|             cifar100_valid, | ||||
|             cifar100_test, | ||||
|             imagenet_valid, | ||||
|             imagenet_test, | ||||
|         ], | ||||
|     ): | ||||
|         correlation = get_cor(cifar10_currs, xlist) | ||||
|         if need_print: | ||||
|             print( | ||||
|                 "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format( | ||||
|                     test_epoch, "012" if use_less_or_not else "200", basestr, correlation | ||||
|                     test_epoch, | ||||
|                     "012" if use_less_or_not else "200", | ||||
|                     basestr, | ||||
|                     correlation, | ||||
|                 ) | ||||
|             ) | ||||
|         cors.append(correlation) | ||||
| @@ -113,7 +145,11 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||
|     # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|     xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"] | ||||
|     correlations = np.array(corrs) | ||||
|     print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200")) | ||||
|     print( | ||||
|         "------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format( | ||||
|             test_epoch, "012" if use_less_or_not else "200" | ||||
|         ) | ||||
|     ) | ||||
|     for idx, xstr in enumerate(xstrs): | ||||
|         print( | ||||
|             "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format( | ||||
| @@ -135,7 +171,12 @@ if __name__ == "__main__": | ||||
|         default="./output/search-cell-nas-bench-201/visuals", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") | ||||
|     parser.add_argument( | ||||
|         "--api_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The path to the NAS-Bench-201 benchmark file.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     vis_save_dir = Path(args.save_dir) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user