#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
########################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
########################################################
import sys, argparse
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
from pathlib import Path

from xautodl.log_utils import time_string
from xautodl.models import CellStructure
from nas_201_api import NASBench201API as API

# Labels of the six reference accuracy lists, in the exact order in which
# `check_cor_for_bandit` returns their correlations.
TARGET_LABELS = ("C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T")


def check_unique_arch(meta_file):
    """Print how many architectures of the benchmark are valid and unique.

    The benchmark at `meta_file` is loaded, every architecture string in
    `api.meta_archs` is converted with `CellStructure.str2structure`, and
    the number of unique architectures is reported three times, once for
    each `consider_zero` setting (None, False, True) passed to
    `to_unique_str`.

    Args:
        meta_file: path of the benchmark file; it is passed to `API` as a
            string.
    """
    api = API(str(meta_file))
    arch_strs = deepcopy(api.meta_archs)
    xarchs = [CellStructure.str2structure(x) for x in arch_strs]

    def get_unique_matrix(archs, consider_zero):
        """Group `archs` by their unique-string.

        Returns:
            A tuple `(sm_matrix, unique_ids, unique_num)`: a boolean
            len(archs) x len(archs) matrix that is True where two
            architectures share a unique-string, the group id of every
            architecture, and the number of groups.
        """
        unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
        print(
            "{:} create unique-string ({:}/{:}) done".format(
                time_string(), len(set(unique_strs)), len(unique_strs)
            )
        )
        # Map every unique-string to the indexes of the archs that produce it.
        str2indexes = dict()
        for index, xstr in enumerate(unique_strs):
            str2indexes.setdefault(xstr, []).append(index)
        sm_matrix = torch.eye(len(archs)).bool()
        for members in str2indexes.values():
            if len(members) < 2:
                continue  # the diagonal already covers single-member groups
            idx = torch.tensor(members, dtype=torch.long)
            # One broadcast write marks all (i, j) pairs of the group,
            # instead of len(members)**2 single-element assignments.
            sm_matrix[idx.unsqueeze(1), idx] = True
        unique_ids, unique_num = [-1] * len(archs), 0
        for i in range(len(unique_ids)):
            if unique_ids[i] > -1:
                continue  # already labelled through an earlier group member
            neighbours = sm_matrix[i].nonzero().view(-1).tolist()
            for nghb in neighbours:
                assert unique_ids[nghb] == -1, "impossible"
                unique_ids[nghb] = unique_num
            unique_num += 1
        return sm_matrix, unique_ids, unique_num

    print(
        "There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
    )
    for consider_zero, description in (
        (None, "considering nothing"),
        (False, "not considering zero"),
        (True, "considering zero"),
    ):
        sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, consider_zero)
        print(
            "{:} There are {:} unique architectures ({:}).".format(
                time_string(), unique_num, description
            )
        )


def check_cor_for_bandit(
    meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
    """Correlate early cifar10-valid accuracy with six reference accuracies.

    For every architecture index of the benchmark, the "valid-accuracy"
    returned by `get_more_info(idx, "cifar10-valid", test_epoch - 1,
    use_less_or_not, is_rand)` is collected.  That list is then correlated
    (Pearson, via `np.corrcoef`) with the accuracies returned for
    cifar10-valid, cifar10, cifar100 and ImageNet16-120 when the epoch
    argument is None and the fourth argument is False.
    NOTE(review): the meaning of the positional `get_more_info` arguments is
    defined by nas_201_api, not by this file -- confirm it there.

    Args:
        meta_file: an already loaded `API` object, or the path of a
            benchmark file to load.
        test_epoch: 1-based epoch; `test_epoch - 1` is passed to the API.
        use_less_or_not: forwarded to the API; it also selects the "012"
            or "200" label in the printed message.
        is_rand: forwarded to every `get_more_info` call.
        need_print: print one line per correlation when True.

    Returns:
        A list of six floats ordered like `TARGET_LABELS`.
    """
    if isinstance(meta_file, API):
        api = meta_file
    else:
        api = API(str(meta_file))
    cifar10_currs = []
    cifar10_valid = []
    cifar10_test = []
    cifar100_valid = []
    cifar100_test = []
    imagenet_test = []
    imagenet_valid = []
    for idx, _ in enumerate(api):
        results = api.get_more_info(
            idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
        )
        cifar10_currs.append(results["valid-accuracy"])
        # --->>>>>
        results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
        cifar10_valid.append(results["valid-accuracy"])
        results = api.get_more_info(idx, "cifar10", None, False, is_rand)
        cifar10_test.append(results["test-accuracy"])
        results = api.get_more_info(idx, "cifar100", None, False, is_rand)
        cifar100_test.append(results["test-accuracy"])
        cifar100_valid.append(results["valid-accuracy"])
        results = api.get_more_info(idx, "ImageNet16-120", None, False, is_rand)
        imagenet_test.append(results["test-accuracy"])
        imagenet_valid.append(results["valid-accuracy"])

    def get_cor(A, B):
        """Return the Pearson correlation coefficient of A and B."""
        return float(np.corrcoef(A, B)[0, 1])

    # Must stay in the same order as TARGET_LABELS.
    targets = [
        cifar10_valid,
        cifar10_test,
        cifar100_valid,
        cifar100_test,
        imagenet_valid,
        imagenet_test,
    ]
    cors = []
    for basestr, xlist in zip(TARGET_LABELS, targets):
        correlation = get_cor(cifar10_currs, xlist)
        if need_print:
            print(
                "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
                    test_epoch,
                    "012" if use_less_or_not else "200",
                    basestr,
                    correlation,
                )
            )
        cors.append(correlation)
    return cors


def check_cor_for_bandit_v2(
    meta_file, test_epoch, use_less_or_not, is_rand, num_trials=100
):
    """Repeat `check_cor_for_bandit` and print mean/std of each correlation.

    Args:
        meta_file: an already loaded `API` object, or the path of a
            benchmark file to load.
        test_epoch: forwarded to `check_cor_for_bandit`.
        use_less_or_not: forwarded to `check_cor_for_bandit`.
        is_rand: forwarded to `check_cor_for_bandit`.
        num_trials: number of repetitions to aggregate (default 100).

    Raises:
        ValueError: if `num_trials` is smaller than 1.
    """
    if num_trials < 1:
        raise ValueError(
            "num_trials must be at least 1, got {:}".format(num_trials)
        )
    # Load the benchmark once so a path argument is not re-read every trial.
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))
    corrs = [
        check_cor_for_bandit(api, test_epoch, use_less_or_not, is_rand, False)
        for _ in tqdm(range(num_trials))
    ]
    correlations = np.array(corrs)
    print(
        "------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(
            test_epoch, "012" if use_less_or_not else "200"
        )
    )
    for idx, xstr in enumerate(TARGET_LABELS):
        mean, std = correlations[:, idx].mean(), correlations[:, idx].std()
        print(
            "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
                xstr, mean, std, mean, std
            )
        )
    print("")


def main():
    """Parse the command line and print the correlation statistics."""
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search-cell-nas-bench-201/visuals",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument(
        "--api_path",
        type=str,
        default=None,
        help="The path to the NAS-Bench-201 benchmark file.",
    )
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    # Report a usage error instead of letting Path(None) raise a TypeError.
    if args.api_path is None:
        parser.error("--api_path must point to the NAS-Bench-201 benchmark file")
    meta_file = Path(args.api_path)
    # An explicit check (not `assert`) so validation survives `python -O`.
    if not meta_file.exists():
        parser.error("invalid path for api : {:}".format(meta_file))

    # check_unique_arch(meta_file)
    api = API(str(meta_file))
    # for iepoch in [11, 25, 50, 100, 150, 175, 200]:
    #  check_cor_for_bandit(api,  6, iepoch)
    #  check_cor_for_bandit(api, 12, iepoch)
    # (test_epoch, use_less_or_not) pairs, evaluated in this order.
    schedule = (
        (6, True),
        (12, True),
        (12, False),
        (24, False),
        (100, False),
        (150, False),
        (175, False),
        (200, False),
    )
    for test_epoch, use_less_or_not in schedule:
        check_cor_for_bandit_v2(api, test_epoch, use_less_or_not, True)
    print("----")


if __name__ == "__main__":
    main()