Reformulate via black
This commit is contained in:
		| @@ -9,123 +9,151 @@ from copy import deepcopy | ||||
| from tqdm import tqdm | ||||
| import torch | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils    import time_string | ||||
| from models       import CellStructure | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils import time_string | ||||
| from models import CellStructure | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def check_unique_arch(meta_file): | ||||
|   api = API(str(meta_file)) | ||||
|   arch_strs = deepcopy(api.meta_archs) | ||||
|   xarchs = [CellStructure.str2structure(x) for x in arch_strs] | ||||
|   def get_unique_matrix(archs, consider_zero): | ||||
|     UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] | ||||
|     print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs))) | ||||
|     Unique2Index = dict() | ||||
|     for index, xstr in enumerate(UniquStrs): | ||||
|       if xstr not in Unique2Index: Unique2Index[xstr] = list() | ||||
|       Unique2Index[xstr].append( index ) | ||||
|     sm_matrix = torch.eye(len(archs)).bool() | ||||
|     for _, xlist in Unique2Index.items(): | ||||
|       for i in xlist: | ||||
|         for j in xlist: | ||||
|           sm_matrix[i,j] = True | ||||
|     unique_ids, unique_num = [-1 for _ in archs], 0 | ||||
|     for i in range(len(unique_ids)): | ||||
|       if unique_ids[i] > -1: continue | ||||
|       neighbours = sm_matrix[i].nonzero().view(-1).tolist() | ||||
|       for nghb in neighbours: | ||||
|         assert unique_ids[nghb] == -1, 'impossible' | ||||
|         unique_ids[nghb] = unique_num | ||||
|       unique_num += 1 | ||||
|     return sm_matrix, unique_ids, unique_num | ||||
|     api = API(str(meta_file)) | ||||
|     arch_strs = deepcopy(api.meta_archs) | ||||
|     xarchs = [CellStructure.str2structure(x) for x in arch_strs] | ||||
|  | ||||
|   print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) )) | ||||
|   sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) | ||||
|   print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num)) | ||||
|   sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) | ||||
|   print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num)) | ||||
|   sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs,  True) | ||||
|   print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) | ||||
|     def get_unique_matrix(archs, consider_zero): | ||||
|         UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] | ||||
|         print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs))) | ||||
|         Unique2Index = dict() | ||||
|         for index, xstr in enumerate(UniquStrs): | ||||
|             if xstr not in Unique2Index: | ||||
|                 Unique2Index[xstr] = list() | ||||
|             Unique2Index[xstr].append(index) | ||||
|         sm_matrix = torch.eye(len(archs)).bool() | ||||
|         for _, xlist in Unique2Index.items(): | ||||
|             for i in xlist: | ||||
|                 for j in xlist: | ||||
|                     sm_matrix[i, j] = True | ||||
|         unique_ids, unique_num = [-1 for _ in archs], 0 | ||||
|         for i in range(len(unique_ids)): | ||||
|             if unique_ids[i] > -1: | ||||
|                 continue | ||||
|             neighbours = sm_matrix[i].nonzero().view(-1).tolist() | ||||
|             for nghb in neighbours: | ||||
|                 assert unique_ids[nghb] == -1, "impossible" | ||||
|                 unique_ids[nghb] = unique_num | ||||
|             unique_num += 1 | ||||
|         return sm_matrix, unique_ids, unique_num | ||||
|  | ||||
|     print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) | ||||
|     print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num)) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) | ||||
|     print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num)) | ||||
|     sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True) | ||||
|     print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num)) | ||||
|  | ||||
|  | ||||
| def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): | ||||
|   if isinstance(meta_file, API): | ||||
|     api = meta_file | ||||
|   else: | ||||
|     api = API(str(meta_file)) | ||||
|   cifar10_currs     = [] | ||||
|   cifar10_valid     = [] | ||||
|   cifar10_test      = [] | ||||
|   cifar100_valid    = [] | ||||
|   cifar100_test     = [] | ||||
|   imagenet_test     = [] | ||||
|   imagenet_valid    = [] | ||||
|   for idx, arch in enumerate(api): | ||||
|     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) | ||||
|     cifar10_currs.append( results['valid-accuracy'] ) | ||||
|     # --->>>>> | ||||
|     results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand) | ||||
|     cifar10_valid.append( results['valid-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand) | ||||
|     cifar10_test.append( results['test-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand) | ||||
|     cifar100_test.append( results['test-accuracy'] ) | ||||
|     cifar100_valid.append( results['valid-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) | ||||
|     imagenet_test.append( results['test-accuracy'] ) | ||||
|     imagenet_valid.append( results['valid-accuracy'] ) | ||||
|   def get_cor(A, B): | ||||
|     return float(np.corrcoef(A, B)[0,1]) | ||||
|   cors = [] | ||||
|   for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||
|     correlation = get_cor(cifar10_currs, xlist) | ||||
|     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||
|     cors.append( correlation ) | ||||
|     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||
|     #print('-'*200) | ||||
|   #print('*'*230) | ||||
|   return cors | ||||
|     if isinstance(meta_file, API): | ||||
|         api = meta_file | ||||
|     else: | ||||
|         api = API(str(meta_file)) | ||||
|     cifar10_currs = [] | ||||
|     cifar10_valid = [] | ||||
|     cifar10_test = [] | ||||
|     cifar100_valid = [] | ||||
|     cifar100_test = [] | ||||
|     imagenet_test = [] | ||||
|     imagenet_valid = [] | ||||
|     for idx, arch in enumerate(api): | ||||
|         results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand) | ||||
|         cifar10_currs.append(results["valid-accuracy"]) | ||||
|         # --->>>>> | ||||
|         results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand) | ||||
|         cifar10_valid.append(results["valid-accuracy"]) | ||||
|         results = api.get_more_info(idx, "cifar10", None, False, is_rand) | ||||
|         cifar10_test.append(results["test-accuracy"]) | ||||
|         results = api.get_more_info(idx, "cifar100", None, False, is_rand) | ||||
|         cifar100_test.append(results["test-accuracy"]) | ||||
|         cifar100_valid.append(results["valid-accuracy"]) | ||||
|         results = api.get_more_info(idx, "ImageNet16-120", None, False, is_rand) | ||||
|         imagenet_test.append(results["test-accuracy"]) | ||||
|         imagenet_valid.append(results["valid-accuracy"]) | ||||
|  | ||||
|     def get_cor(A, B): | ||||
|         return float(np.corrcoef(A, B)[0, 1]) | ||||
|  | ||||
|     cors = [] | ||||
|     for basestr, xlist in zip( | ||||
|         ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"], | ||||
|         [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test], | ||||
|     ): | ||||
|         correlation = get_cor(cifar10_currs, xlist) | ||||
|         if need_print: | ||||
|             print( | ||||
|                 "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format( | ||||
|                     test_epoch, "012" if use_less_or_not else "200", basestr, correlation | ||||
|                 ) | ||||
|             ) | ||||
|         cors.append(correlation) | ||||
|         # print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||
|         # print('-'*200) | ||||
|     # print('*'*230) | ||||
|     return cors | ||||
|  | ||||
|  | ||||
| def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||
|   corrs = [] | ||||
|   for i in tqdm(range(100)): | ||||
|     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||
|     corrs.append( x ) | ||||
|   #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   correlations = np.array(corrs) | ||||
|   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) | ||||
|   for idx, xstr in enumerate(xstrs): | ||||
|     print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std())) | ||||
|   print('') | ||||
|     corrs = [] | ||||
|     for i in tqdm(range(100)): | ||||
|         x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||
|         corrs.append(x) | ||||
|     # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|     xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"] | ||||
|     correlations = np.array(corrs) | ||||
|     print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200")) | ||||
|     for idx, xstr in enumerate(xstrs): | ||||
|         print( | ||||
|             "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format( | ||||
|                 xstr, | ||||
|                 correlations[:, idx].mean(), | ||||
|                 correlations[:, idx].std(), | ||||
|                 correlations[:, idx].mean(), | ||||
|                 correlations[:, idx].std(), | ||||
|             ) | ||||
|         ) | ||||
|     print("") | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.') | ||||
|   args = parser.parse_args() | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search-cell-nas-bench-201/visuals", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|   vis_save_dir = Path(args.save_dir) | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   meta_file = Path(args.api_path) | ||||
|   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||
|     vis_save_dir = Path(args.save_dir) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     meta_file = Path(args.api_path) | ||||
|     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) | ||||
|  | ||||
|   #check_unique_arch(meta_file) | ||||
|   api = API(str(meta_file)) | ||||
|   #for iepoch in [11, 25, 50, 100, 150, 175, 200]: | ||||
|   #  check_cor_for_bandit(api,  6, iepoch) | ||||
|   #  check_cor_for_bandit(api, 12, iepoch) | ||||
|   check_cor_for_bandit_v2(api,   6,  True, True) | ||||
|   check_cor_for_bandit_v2(api,  12,  True, True) | ||||
|   check_cor_for_bandit_v2(api,  12, False, True) | ||||
|   check_cor_for_bandit_v2(api,  24, False, True) | ||||
|   check_cor_for_bandit_v2(api, 100, False, True) | ||||
|   check_cor_for_bandit_v2(api, 150, False, True) | ||||
|   check_cor_for_bandit_v2(api, 175, False, True) | ||||
|   check_cor_for_bandit_v2(api, 200, False, True) | ||||
|   print('----') | ||||
|     # check_unique_arch(meta_file) | ||||
|     api = API(str(meta_file)) | ||||
|     # for iepoch in [11, 25, 50, 100, 150, 175, 200]: | ||||
|     #  check_cor_for_bandit(api,  6, iepoch) | ||||
|     #  check_cor_for_bandit(api, 12, iepoch) | ||||
|     check_cor_for_bandit_v2(api, 6, True, True) | ||||
|     check_cor_for_bandit_v2(api, 12, True, True) | ||||
|     check_cor_for_bandit_v2(api, 12, False, True) | ||||
|     check_cor_for_bandit_v2(api, 24, False, True) | ||||
|     check_cor_for_bandit_v2(api, 100, False, True) | ||||
|     check_cor_for_bandit_v2(api, 150, False, True) | ||||
|     check_cor_for_bandit_v2(api, 175, False, True) | ||||
|     check_cor_for_bandit_v2(api, 200, False, True) | ||||
|     print("----") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user