201 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			201 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 | |
| ########################################################
 | |
| # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
 | |
| ########################################################
 | |
| import sys, argparse
 | |
| import numpy as np
 | |
| from copy import deepcopy
 | |
| from tqdm import tqdm
 | |
| import torch
 | |
| from pathlib import Path
 | |
| 
 | |
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | |
| if str(lib_dir) not in sys.path:
 | |
|     sys.path.insert(0, str(lib_dir))
 | |
| from log_utils import time_string
 | |
| from models import CellStructure
 | |
| from nas_201_api import NASBench201API as API
 | |
| 
 | |
| 
 | |
def check_unique_arch(meta_file):
    """Print validity and uniqueness statistics for the benchmark architectures.

    The benchmark API is loaded from ``meta_file``; every string in
    ``api.meta_archs`` is converted to a ``CellStructure``. The function then
    prints the number of architectures whose ``check_valid()`` is truthy and,
    for each ``to_unique_str`` mode (``None``, ``False``, ``True``), the number
    of distinct architectures found.

    Args:
        meta_file: value passed (after ``str()``) to the ``API`` constructor.
    """
    api = API(str(meta_file))
    structures = [
        CellStructure.str2structure(arch_str)
        for arch_str in deepcopy(api.meta_archs)
    ]

    def get_unique_matrix(archs, consider_zero):
        """Group ``archs`` by their unique-string and label each group.

        Returns a tuple ``(sm_matrix, unique_ids, unique_num)`` where
        ``sm_matrix[i, j]`` is True when architectures ``i`` and ``j`` share
        the same unique-string, ``unique_ids[i]`` is the group number of
        architecture ``i`` and ``unique_num`` is the number of groups.
        """
        keys = [arch.to_unique_str(consider_zero) for arch in archs]
        print(
            "{:} create unique-string ({:}/{:}) done".format(
                time_string(), len(set(keys)), len(keys)
            )
        )

        # Map every unique-string to the positions of the archs producing it.
        groups = dict()
        for position, key in enumerate(keys):
            groups.setdefault(key, list()).append(position)

        # Pairwise "same architecture" matrix; the diagonal is always True.
        sm_matrix = torch.eye(len(archs)).bool()
        for members in groups.values():
            for row in members:
                for col in members:
                    sm_matrix[row, col] = True

        # Walk the rows in order and give each still-unlabelled group the
        # next free identifier.
        unique_ids = [-1] * len(archs)
        unique_num = 0
        for row in range(len(archs)):
            if unique_ids[row] >= 0:
                continue
            for member in sm_matrix[row].nonzero().view(-1).tolist():
                assert unique_ids[member] == -1, "impossible"
                unique_ids[member] = unique_num
            unique_num += 1
        return sm_matrix, unique_ids, unique_num

    valid_count = sum(arch.check_valid() for arch in structures)
    print("There are {:} valid-archs".format(valid_count))

    # One report per ``consider_zero`` mode, in the original order.
    reports = (
        (None, "{:} There are {:} unique architectures (considering nothing)."),
        (False, "{:} There are {:} unique architectures (not considering zero)."),
        (True, "{:} There are {:} unique architectures (considering zero)."),
    )
    for consider_zero, template in reports:
        sm_matrix, uniqueIDs, unique_num = get_unique_matrix(
            structures, consider_zero
        )
        print(template.format(time_string(), unique_num))
 | |
| 
 | |
| 
 | |
def check_cor_for_bandit(
    meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
    """Correlate an early cifar10-valid accuracy with six reference accuracies.

    For every entry of the benchmark, the validation accuracy returned by
    ``api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not,
    is_rand)`` is collected, together with the accuracies returned for
    cifar10-valid, cifar10, cifar100 and ImageNet16-120 when the epoch
    argument is ``None`` and the fourth argument is ``False``.

    Args:
        meta_file: an ``API`` instance, or a value passed (after ``str()``)
            to the ``API`` constructor.
        test_epoch: 1-based epoch; ``test_epoch - 1`` is given to the API.
        use_less_or_not: forwarded to ``get_more_info`` for the early query;
            also selects the "012"/"200" label in the printed message.
        is_rand: forwarded to every ``get_more_info`` call.
        need_print: when true, print one line per correlation.

    Returns:
        A list of six floats: the ``np.corrcoef`` coefficient between the
        early accuracies and, in order, C-010-V, C-010-T, C-100-V, C-100-T,
        I16-V and I16-T.
    """
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))

    early_valid = []
    # Insertion order of this dict defines the order of the returned values.
    references = {
        "C-010-V": [],
        "C-010-T": [],
        "C-100-V": [],
        "C-100-T": [],
        "I16-V": [],
        "I16-T": [],
    }
    for index, _ in enumerate(api):
        info = api.get_more_info(
            index, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
        )
        early_valid.append(info["valid-accuracy"])

        info = api.get_more_info(index, "cifar10-valid", None, False, is_rand)
        references["C-010-V"].append(info["valid-accuracy"])

        info = api.get_more_info(index, "cifar10", None, False, is_rand)
        references["C-010-T"].append(info["test-accuracy"])

        info = api.get_more_info(index, "cifar100", None, False, is_rand)
        references["C-100-T"].append(info["test-accuracy"])
        references["C-100-V"].append(info["valid-accuracy"])

        info = api.get_more_info(index, "ImageNet16-120", None, False, is_rand)
        references["I16-T"].append(info["test-accuracy"])
        references["I16-V"].append(info["valid-accuracy"])

    budget = "012" if use_less_or_not else "200"
    cors = []
    for label, accuracies in references.items():
        # Off-diagonal entry of the 2x2 correlation matrix.
        correlation = float(np.corrcoef(early_valid, accuracies)[0, 1])
        if need_print:
            print(
                "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
                    test_epoch, budget, label, correlation
                )
            )
        cors.append(correlation)
    return cors
 | |
| 
 | |
| 
 | |
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
    """Run ``check_cor_for_bandit`` 100 times and print per-target statistics.

    All arguments are forwarded unchanged to ``check_cor_for_bandit`` (with
    ``need_print`` set to ``False``). After the runs, a header line and, for
    each of the six targets, the mean and standard deviation of the
    collected correlations are printed, followed by an empty line.
    """
    trials = [
        check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
        for _ in tqdm(range(100))
    ]
    # Rows are runs, columns are the six correlation targets.
    correlations = np.array(trials)

    budget = "012" if use_less_or_not else "200"
    print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, budget))

    labels = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
    for column, label in enumerate(labels):
        values = correlations[:, column]
        mean, std = values.mean(), values.std()
        print(
            "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
                label, mean, std, mean, std
            )
        )
    print("")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: parse the command line, load the benchmark once and
    # print correlation statistics for a fixed set of training budgets.
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search-cell-nas-bench-201/visuals",
        help="The base-name of folder to save checkpoints and log.",
    )
    # Required: without it the script previously crashed with a TypeError
    # raised by Path(None) instead of a usage message.
    parser.add_argument(
        "--api_path",
        type=str,
        required=True,
        help="The path to the NAS-Bench-201 benchmark file.",
    )
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    meta_file = Path(args.api_path)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not meta_file.exists():
        parser.error("invalid path for api : {:}".format(meta_file))

    # check_unique_arch(meta_file) can be called here for uniqueness stats.
    api = API(str(meta_file))
    # (test_epoch, use_less_or_not) pairs, evaluated in this order.
    settings = (
        (6, True),
        (12, True),
        (12, False),
        (24, False),
        (100, False),
        (150, False),
        (175, False),
        (200, False),
    )
    for test_epoch, use_less_or_not in settings:
        check_cor_for_bandit_v2(api, test_epoch, use_less_or_not, True)
    print("----")
 |