Update NATS-Bench (tss version 1.0) and remove the trace of 301

This commit is contained in:
D-X-Y 2020-09-16 08:28:27 +00:00
parent bd9288f45d
commit 9db28392c2
10 changed files with 169 additions and 249 deletions

View File

@ -26,7 +26,7 @@ from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
def test_api(api, is_301=True):
def test_api(api, sss_or_tss=True):
print('{:} start testing the api : {:}'.format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
@ -39,7 +39,7 @@ def test_api(api, is_301=True):
info = api.query_by_index(113, 'cifar100')
print('{:}\n'.format(info))
info = api.query_meta_info_by_index(115, '90' if is_301 else '200')
info = api.query_meta_info_by_index(115, '90' if sss_or_tss else '200')
print('{:}\n'.format(info))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
@ -48,6 +48,7 @@ def test_api(api, is_301=True):
print('')
params = api.get_net_param(12, 'cifar10', None)
import pdb; pdb.set_trace()
# Obtain the config and create the network
config = api.get_net_config(12, 'cifar10')
print('{:}\n'.format(config))
@ -74,7 +75,7 @@ def test_api(api, is_301=True):
print('{:}\n'.format(info))
print('{:} finish testing the api : {:}'.format(time_string(), api))
if not is_301:
if not sss_or_tss:
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
matrix = api.str2matrix(arch_str)
print('Compute the adjacency matrix of {:}'.format(arch_str))
@ -88,13 +89,13 @@ if __name__ == '__main__':
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
api201 = create(None, 'tss', fast_mode=fast_mode, verbose=True)
api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True)
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
test_api(api201, False)
test_api(api_nats_tss, False)
for fast_mode in [True, False]:
for verbose in [True, False]:
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api301))
test_api(api301, True)
api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)

View File

@ -0,0 +1,129 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-organize all checkpoints (created by main-tss.py) #
# into a single benchmark file. Besides, for each architecture, we merge the #
# information of all its trials into a single file. #
# #
# Usage: #
# python exps/NATS-Bench/tss-collect-patcher.py #
##############################################################################
import os, re, sys, time, shutil, random, argparse, collections
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets import get_datasets
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
from nas_201_api import NASBench201API
NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28
def simplify(save_dir, save_name, nets, total, sup_config):
    """Collect the per-architecture result files into one benchmark file.

    The function first checks, for each hyper-parameter set ('12' and '200'),
    that the number of raw checkpoints per seed matches `total`. It then loads
    the already-simplified pickle of every architecture, stores all of them in
    a single file via `pickle_save`, and finally renames the benchmark file and
    the two data directories so that their names carry the checksum returned
    by `get_md5_file`.

    Args:
        save_dir: a `pathlib.Path` pointing to the base output directory.
        save_name: the base name used for the intermediate file/directories.
        nets: the list of architecture strings; it must contain `total` items.
        total: the expected number of architectures.
        sup_config: unused in this function; kept for interface compatibility.

    Raises:
        AssertionError: if the number of checkpoints does not match `total`.
    """
    hps = ['12', '200']
    for hp in hps:
        sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
        ckps = sorted(sub_save_dir.glob('arch-*-seed-*.pth'))
        seed2names = defaultdict(list)
        for ckp in ckps:
            # 'arch-<index>-seed-<seed>.pth' is split on '-' and '.', so that
            # parts[3] is the seed. A raw string avoids the invalid '\.' escape.
            parts = re.split(r'-|\.', ckp.name)
            seed2names[parts[3]].append(ckp.name)
        print('DIR : {:}'.format(sub_save_dir))
        nums = []
        for seed, xlist in seed2names.items():
            nums.append(len(xlist))
            print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist)))
        # default=0 lets an empty directory reach the informative message
        # below instead of failing inside max() with "arg is an empty sequence".
        max_num = max(nums, default=0)
        assert len(nets) == total == max_num, 'there are some missed files : {:} vs {:}'.format(max_num, total)
    print('{:} start simplify the checkpoint.'.format(time_string()))

    # Create the directory to save the processed data
    # full_save_dir contains all benchmark files with trained weights.
    # simple_save_dir contains all benchmark files without trained weights.
    full_save_dir = save_dir / (save_name + '-FULL')
    simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
    full_save_dir.mkdir(parents=True, exist_ok=True)
    simple_save_dir.mkdir(parents=True, exist_ok=True)

    # all data in memory
    arch2infos, evaluated_indexes = dict(), set()
    # tqdm already reports the progress and the remaining time.
    for index in tqdm(range(total)):
        simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
        arch2infos[index] = pickle_load(simple_save_path)
        evaluated_indexes.add(index)
    print('{:} {:} done.'.format(time_string(), save_name))

    final_infos = {'meta_archs' : nets,
                   'total_archs': total,
                   'arch2infos' : arch2infos,
                   'evaluated_indexes': evaluated_indexes}
    save_file_name = save_dir / '{:}.pickle'.format(save_name)
    pickle_save(final_infos, str(save_file_name))

    # move the benchmark file to a new path
    # NOTE(review): pickle_save appears to write '<name>.pbz2' -- confirm.
    compressed_file_name = str(save_file_name) + '.pbz2'
    hd5sum = get_md5_file(compressed_file_name)
    hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(compressed_file_name, str(hd5_file_name))
    print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))

    # move the directory to a new path
    # str() keeps the calls consistent with the file move above and avoids the
    # path-like issues of shutil.move in Python versions before 3.9.
    hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
    hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(full_save_dir), str(hd5_full_save_dir))
    shutil.move(str(simple_save_dir), str(hd5_simple_save_dir))
def traverse_net(max_node):
    """Enumerate, shuffle, and return all architecture strings of the search space.

    The candidates are generated by `CellStructure.gen_all` from the
    'nats-bench' cell search space, shuffled with a fixed random seed, and
    checked against three known architecture strings so that the ordering
    stays reproducible.

    Args:
        max_node: the maximum number of nodes in a cell.
            NOTE(review): the hard-coded reference strings below look like
            4-node cells, so other values presumably fail the checks -- confirm.

    Returns:
        A list with the string form (`tostr()`) of every architecture.

    Raises:
        AssertionError: if the shuffled order differs from the reference.
    """
    aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench')
    archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
    # (max_node-1)*max_node is always even, so the integer division is exact
    # and the expected count is reported as an int rather than a float.
    num_expected = len(aa_nas_bench_ss) ** ((max_node - 1) * max_node // 2)
    print ('There are {:} archs vs {:}.'.format(len(archs), num_expected))

    random.seed( 88 ) # please do not change this line for reproducibility
    random.shuffle( archs )
    references = (
        (0, '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'),
        (9, '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'),
        (123, '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|'),
    )
    for position, arch_str in references:
        assert archs[position].tostr() == arch_str, 'please check the {:}-th architecture : {:}'.format(position, archs[position])
    return [x.tostr() for x in archs]
if __name__ == '__main__':
    # Command-line entry: enumerate the search space, verify its size, and
    # pack all per-architecture results into a single benchmark file.
    parser = argparse.ArgumentParser(
        description='NATS-Bench (topology search space)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # (flag, type, default, help) -- registered in this order.
    cli_options = (
        ('--base_save_dir', str, './output/NATS-Bench-topology', 'The base-name of folder to save checkpoints and log.'),
        ('--max_node', int, 4, 'The maximum node in a cell.'),
        ('--channel', int, 16, 'The number of channels.'),
        ('--num_cells', int, 5, 'The number of cells in one stage.'),
        ('--check_N', int, 15625, 'For safety.'),
        ('--save_name', str, 'process', 'The save directory.'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    nets = traverse_net(args.max_node)
    # Safety check: the enumerated space must have exactly the expected size.
    num_nets = len(nets)
    if num_nets != args.check_N:
        raise ValueError('Pre-num-check failed : {:} vs {:}'.format(num_nets, args.check_N))

    save_dir = Path(args.base_save_dir)
    sup_config = {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}
    simplify(save_dir, args.save_name, nets, args.check_N, sup_config)

View File

@ -10,7 +10,7 @@
# Usage: #
# python exps/NATS-Bench/tss-collect.py #
##############################################################################
import os, re, sys, time, random, argparse, collections
import os, re, sys, time, shutil, random, argparse, collections
import numpy as np
from copy import deepcopy
import torch
@ -26,6 +26,7 @@ from datasets import get_datasets
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
from nas_201_api import NASBench201API

View File

@ -64,7 +64,7 @@ def get_search_spaces(xtype, name) -> List[Text]:
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
elif xtype == 'sss': # The size search space.
if name == 'nas-bench-301' or name == 'nats-bench' or name == 'nats-bench-size':
if name in ['nats-bench', 'nats-bench-size']:
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
'numbers': 5}
else:

View File

@ -27,7 +27,6 @@ DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5',
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nats-bench' : NAS_BENCH_201,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-301': NAS_BENCH_201,
'darts' : DARTS_SPACE}

View File

@ -1,11 +1,15 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
#####################################################################
# This API will be updated after 2020.09.16. #
# Please use our new API for NATS-Bench, which is #
# more efficient and contains info of more architecture candidates. #
#####################################################################
from .api_utils import ArchResults, ResultsCount
from .api_201 import NASBench201API
from .api_301 import NASBench301API
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]

View File

@ -1,222 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
############################################################################################
# NAS-Bench-301, coming soon.
############################################################################################
# The history of benchmark files:
# [2020.06.30] NAS-Bench-301-v1_0
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-301-v1_0-363be7.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-301-v1_0-archive']
def print_information(information, extra_info=None, show=False):
    """Build (and optionally print) a textual summary of one result object.

    Args:
        information: an object providing `arch_str`, `get_dataset_names()`,
            `get_compute_costs(dataset)` and `get_metrics(dataset, setname)`.
        extra_info: any extra value that is placed into the header line.
        show: if True, the summary lines are printed, joined by newlines.

    Returns:
        A list of strings: the architecture string, a header line, and then
        two lines (costs and metrics) for each dataset.
    """
    def describe(info):
        # One metric dict -> 'loss = ... & top1 = ...%'.
        return 'loss = {:.3f} & top1 = {:.2f}%'.format(info['loss'], info['accuracy'])

    dataset_names = information.get_dataset_names()
    header = 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
    strings = [information.arch_str, header]
    for dataset in dataset_names:
        costs = information.get_compute_costs(dataset)
        flop, param, latency = costs['flops'], costs['params'], costs['latency']
        # The latency is shown in milliseconds only when it is a positive number.
        if latency is not None and latency > 0:
            latency_ms = '{:.2f}'.format(latency*1000)
        else:
            latency_ms = None
        cost_line = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, latency_ms)

        train_info = information.get_metrics(dataset, 'train')
        if dataset == 'cifar10':
            # Only the train and test results are reported for 'cifar10'.
            test__info = information.get_metrics(dataset, 'ori-test')
            metric_line = '{:14s} train : [{:}], test : [{:}]'.format(
                dataset, describe(train_info), describe(test__info))
        else:
            valid_info = information.get_metrics(dataset, 'x-valid')
            test_set = 'ori-test' if dataset == 'cifar10-valid' else 'x-test'
            test__info = information.get_metrics(dataset, test_set)
            metric_line = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
                dataset, describe(train_info), describe(valid_info), describe(test__info))
        strings.extend((cost_line, metric_line))
    if show:
        print('\n'.join(strings))
    return strings
"""
This is the class for the API of NAS-Bench-301.
"""
class NASBench301API(NASBenchMetaAPI):
    """The API class used to query the results stored in a NAS-Bench-301 file."""

    def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
        """Create the API from a benchmark file path or from an already-loaded dict.

        Args:
            file_path_or_dict: a path (str or Path) of the benchmark file, a dict
                with the keys 'meta_archs', 'arch2infos' and 'evaluated_indexes',
                or None to use the latest default file under $TORCH_HOME.
            verbose: whether to print more logs.

        Raises:
            ValueError: if `file_path_or_dict` is neither a path nor a dict.
            AssertionError: if the file does not exist or a key is missing.
        """
        self.filename = None
        self.reset_time()
        if file_path_or_dict is None:
            file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
            print ('Try to use the default NAS-Bench-301 path from {:}.'.format(file_path_or_dict))
        if isinstance(file_path_or_dict, (str, Path)):
            file_path_or_dict = str(file_path_or_dict)
            if verbose: print('try to create the NAS-Bench-301 api from {:}'.format(file_path_or_dict))
            assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
            self.filename = Path(file_path_or_dict).name
            # NOTE: torch.load unpickles the file; only load benchmark files from a trusted source.
            file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
        elif isinstance(file_path_or_dict, dict):
            file_path_or_dict = copy.deepcopy( file_path_or_dict )
        else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
        assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
        self.verbose = verbose # [TODO] a flag indicating whether to print more logs
        keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
        for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
        self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
        # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
        self.arch2infos_dict = OrderedDict()
        self._avaliable_hps = set()
        for xkey in sorted(file_path_or_dict['arch2infos'].keys()):
            all_infos = file_path_or_dict['arch2infos'][xkey]
            hp2archres = OrderedDict()
            for hp_key, results in all_infos.items():
                hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
                self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
            self.arch2infos_dict[xkey] = hp2archres
        self.evaluated_indexes = sorted(file_path_or_dict['evaluated_indexes'])
        # Map each architecture string to its index; duplicated strings are not allowed.
        self.archstr2index = {}
        for idx, arch in enumerate(self.meta_archs):
            assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
            self.archstr2index[ arch ] = idx
        if self.verbose:
            print('Create NAS-Bench-301 done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))

    def reload(self, archive_root: Text = None, index: int = None):
        """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.

        If index is None, overwrite all ckps.
        If archive_root is None, the latest default archive directory under $TORCH_HOME is used.
        """
        if self.verbose:
            print('Call reload with archive_root={:} and index={:}'.format(archive_root, index))
        if archive_root is None:
            archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
        assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
        if index is None:
            indexes = list(range(len(self)))
        else:
            indexes = [index]
        for idx in indexes:
            assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
            # Try the zero-padded file name first and then the plain one.
            xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
            if not os.path.isfile(xfile_path):
                xfile_path = os.path.join(archive_root, '{:d}-FULL.pth'.format(idx))
            assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
            # NOTE: torch.load unpickles the file; only load archives from a trusted source.
            xdata = torch.load(xfile_path, map_location='cpu')
            assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
            hp2archres = OrderedDict()
            for hp_key, results in xdata.items():
                hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
            self.arch2infos_dict[idx] = hp2archres

    def query_info_str_by_arch(self, arch, hp: Text='12'):
        """ This function is used to query the information of a specific architecture
        'arch' can be an architecture index or an architecture string
        When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config'
        When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
        When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
        The difference between these three configurations is the number of training epochs.
        """
        if self.verbose:
            print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
        return self._query_info_str_by_arch(arch, hp, print_information)

    def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
        """This function will return the metric for the `index`-th architecture
        `dataset` indicates the dataset:
            'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
            'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
            'cifar100'       : using the proposed train set of CIFAR-100 as the training set
            'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
        `iepoch` indicates the index of the training epoch (starting from 0).
            When iepoch=None, it will return the metric for the last training epoch
            When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
        `hp` indicates different hyper-parameters for training
            When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
            When hp=12, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
            When hp=90, it trains the network with 90 epochs and the LR decayed from 0.1 to 0 within 90 epochs
        `is_random`
            When is_random=True, the performance of a randomly selected trial (seed) will be returned
            When is_random=False, the performance of all trials will be averaged.
            Any non-bool value is passed to `get_metrics` as it is.
        It returns a dict with the keys '<split>-loss', '<split>-accuracy',
        '<split>-per-time' and '<split>-all-time', where <split> is 'train' and,
        if available, 'valid', 'test' and 'valtest'.
        It raises ValueError if the architecture is not in `arch2infos_dict`.
        """
        if self.verbose:
            print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
        index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
        if index not in self.arch2infos_dict:
            raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
        archresult = self.arch2infos_dict[index][str(hp)]
        # if randomly select one trial, select the seed at first
        if isinstance(is_random, bool) and is_random:
            seeds = archresult.get_dataset_seeds(dataset)
            is_random = random.choice(seeds)

        def optional_metrics(setname):
            # Best-effort query: a missing split yields None. `Exception` is
            # caught rather than everything, so that KeyboardInterrupt and
            # SystemExit are not swallowed.
            try:
                return archresult.get_metrics(dataset, setname, iepoch=iepoch, is_random=is_random)
            except Exception:
                return None

        # collect the training information
        train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
        total = train_info['iepoch'] + 1
        xinfo = {'train-loss'    : train_info['loss'],
                 'train-accuracy': train_info['accuracy'],
                 'train-per-time': train_info['all_time'] / total,
                 'train-all-time': train_info['all_time']}
        # collect the evaluation information
        if dataset == 'cifar10-valid':
            # the validation results are mandatory for this dataset
            valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
            test_info = optional_metrics('ori-test')
            valtest_info = None
        else:
            # collect results on the proposed test set
            test_info = optional_metrics('ori-test' if dataset == 'cifar10' else 'x-test')
            # collect results on the proposed validation set
            valid_info = optional_metrics('x-valid')
            valtest_info = optional_metrics('ori-test') if dataset != 'cifar10' else None
        for prefix, info in (('valid', valid_info), ('test', test_info), ('valtest', valtest_info)):
            if info is None:
                continue
            xinfo['{:}-loss'.format(prefix)] = info['loss']
            xinfo['{:}-accuracy'.format(prefix)] = info['accuracy']
            xinfo['{:}-per-time'.format(prefix)] = info['all_time'] / total
            xinfo['{:}-all-time'.format(prefix)] = info['all_time']
        return xinfo

    def show(self, index: int = -1) -> None:
        """
        This function will print the information of a specific (or all) architecture(s).

        :param index: If the index < 0: it will loop for all architectures and print their information one by one.
                      else: it will print the information of the 'index'-th architecture.
        :return: nothing
        """
        self._show(index, print_information)

View File

@ -716,7 +716,7 @@ class ResultsCount(object):
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
# In this case, this is NAS-Bench-301
# In this case, this is to handle the size search space.
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
@ -726,7 +726,7 @@ class ResultsCount(object):
'N' : self.arch_config['num_cells'],
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else:
# In this case, this is NAS-Bench-301
# In this case, this is to handle the size search space.
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}

View File

@ -68,7 +68,7 @@ class NATSsize(NASBenchMetaAPI):
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
print ('Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(self._fast_mode, file_path_or_dict))
print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose:
@ -125,10 +125,15 @@ class NATSsize(NASBenchMetaAPI):
If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:

View File

@ -4,7 +4,7 @@
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
#####################################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
# [2020.08.31] #
# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
#####################################################################################
import os, copy, random, numpy as np
from pathlib import Path
@ -19,14 +19,14 @@ from .api_utils import remap_dataset_set_names
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-xxxxx']
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
def print_information(information, extra_info=None, show=False):
dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
def metric2str(loss, acc):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
metric = information.get_compute_costs(dataset)
@ -61,12 +61,15 @@ class NATStopology(NASBenchMetaAPI):
self._archive_dir = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose:
print('{:} Try to create the NATS-Bench (topology) api from {:}'.format(time_string(), file_path_or_dict))
print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
@ -82,7 +85,7 @@ class NATStopology(NASBenchMetaAPI):
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
self.verbose = verbose
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
@ -91,13 +94,13 @@ class NATStopology(NASBenchMetaAPI):
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = list(file_path_or_dict['evaluated_indexes'])
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
@ -116,7 +119,7 @@ class NATStopology(NASBenchMetaAPI):
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(