Sync NATS-Bench's d11018d

This commit is contained in:
D-X-Y 2020-10-15 19:56:20 +11:00
parent c7a54fd08b
commit bc0ac65882
5 changed files with 1065 additions and 543 deletions

View File

@ -3,15 +3,18 @@
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
############################################################################## ##############################################################################
# The official Application Programming Interface (API) for NATS-Bench. # """The official Application Programming Interface (API) for NATS-Bench."""
############################################################################## from nats_bench.api_size import NATSsize
from .api_utils import pickle_save, pickle_load from nats_bench.api_topology import NATStopology
from .api_utils import ArchResults, ResultsCount from nats_bench.api_utils import ArchResults
from .api_topology import NATStopology from nats_bench.api_utils import pickle_load
from .api_size import NATSsize from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31] NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
def version(): def version():
@ -24,13 +27,43 @@ def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
Args: Args:
file_path_or_dict: None or a file path or a directory path. file_path_or_dict: None or a file path or a directory path.
search_space: This is a string indicates the search space in NATS-Bench. search_space: This is a string indicates the search space in NATS-Bench.
fast_mode: If True, we will not load all the data at initialization, instead, the data for each candidate architecture will be loaded when quering it; fast_mode: If True, we will not load all the data at initialization,
If False, we will load all the data during initialization. instead, the data for each candidate architecture will be loaded when
quering it; If False, we will load all the data during initialization.
verbose: This is a flag to indicate whether log additional information. verbose: This is a flag to indicate whether log additional information.
Raises:
ValueError: If not find the matched serach space description.
Returns:
The created NATS-Bench API.
""" """
if search_space in ['tss', 'topology']: if search_space in NATS_BENCH_TSS_NAMEs:
return NATStopology(file_path_or_dict, fast_mode, verbose) return NATStopology(file_path_or_dict, fast_mode, verbose)
elif search_space in ['sss', 'size']: elif search_space in NATS_BENCH_SSS_NAMEs:
return NATSsize(file_path_or_dict, fast_mode, verbose) return NATSsize(file_path_or_dict, fast_mode, verbose)
else: else:
raise ValueError('invalid search space : {:}'.format(search_space)) raise ValueError('invalid search space : {:}'.format(search_space))
def search_space_info(main_tag, aux_tag):
  """Return the description of a benchmark search space.

  Args:
    main_tag: the benchmark name, either 'nats-bench' or 'nas-bench-201'.
    aux_tag: for 'nats-bench', a name contained in NATS_BENCH_SSS_NAMEs (size
      search space) or NATS_BENCH_TSS_NAMEs (topology search space); for
      'nas-bench-201', it must be None.

  Returns:
    A dict with the keys `candidates` and `num_layers` for the size search
    space, or `op_names` and `num_nodes` for the topology search space.

  Raises:
    ValueError: If main_tag is unknown or aux_tag is not valid for main_tag.
  """
  size_space = {
      'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
      'num_layers': 5,
  }
  topology_space = {
      'op_names': [
          'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3',
          'avg_pool_3x3'
      ],
      'num_nodes': 4,
  }
  if main_tag == 'nats-bench':
    if aux_tag in NATS_BENCH_SSS_NAMEs:
      return size_space
    if aux_tag in NATS_BENCH_TSS_NAMEs:
      return topology_space
    raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag))
  if main_tag == 'nas-bench-201':
    # NAS-Bench-201 only has the topology search space, so no auxiliary tag
    # is accepted.
    if aux_tag is None:
      return topology_space
    raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.')
  raise ValueError('Unknown main tag: {:}'.format(main_tag))

View File

@ -1,65 +1,84 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##################################################################################### ##############################################################################
# The history of benchmark files (the name is NATS-sss-[version]-[md5].pickle.pbz2) # # The history of benchmark files are as follows, #
# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 # # where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
##################################################################################### # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
import os, copy, random, numpy as np ##############################################################################
from typing import List, Text, Union, Dict, Optional # pylint: disable=line-too-long
from collections import OrderedDict, defaultdict """The API for size search space in NATS-Bench."""
from .api_utils import time_string import collections
from .api_utils import pickle_load import copy
from .api_utils import ArchResults import os
from .api_utils import NASBenchMetaAPI import random
from .api_utils import remap_dataset_set_names from typing import Dict, Optional, Text, Union, Any
from .api_utils import nats_is_dir
from .api_utils import nats_is_file from nats_bench.api_utils import ArchResults
from .api_utils import PICKLE_EXT from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262'] ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
def print_information(information, extra_info=None, show=False):
  """Summarize the results of one architecture as a list of text lines.

  Args:
    information: an object providing `arch_str`, `get_dataset_names()`,
      `get_compute_costs(dataset)` and `get_metrics(dataset, setname)`.
    extra_info: any value; it is embedded into the second summary line.
    show: if true, the lines are also printed, joined by newlines.

  Returns:
    A list of strings: the architecture string, a header line, and then two
    lines (compute cost and loss/accuracy) for every dataset.
  """

  def metric2str(loss, acc):
    return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)

  names = information.get_dataset_names()
  lines = [
      information.arch_str,
      'datasets : {:}, extra-info : {:}'.format(names, extra_info)
  ]
  for name in names:
    costs = information.get_compute_costs(name)
    flops, params, latency = costs['flops'], costs['params'], costs['latency']
    # The latency is scaled by 1000 for the "ms" display; a missing or
    # non-positive latency is shown as None.
    if latency is not None and latency > 0:
      latency_text = '{:.2f}'.format(latency * 1000)
    else:
      latency_text = None
    cost_line = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
        name, flops, params, latency_text)
    # Pick the splits to report (and the matching template) per dataset.
    if name == 'cifar10-valid':
      template = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'
      setnames = ('train', 'x-valid', 'ori-test')
    elif name == 'cifar10':
      template = '{:14s} train : [{:}], test : [{:}]'
      setnames = ('train', 'ori-test')
    else:
      template = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'
      setnames = ('train', 'x-valid', 'x-test')
    infos = [information.get_metrics(name, setname) for setname in setnames]
    texts = [metric2str(info['loss'], info['accuracy']) for info in infos]
    lines.extend((cost_line, template.format(name, *texts)))
  if show:
    print('\n'.join(lines))
  return lines
"""
This is the class for the API of size search space in NATS-Bench.
"""
class NATSsize(NASBenchMetaAPI): class NATSsize(NASBenchMetaAPI):
"""This is the class for the API of size search space in NATS-Bench."""
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self,
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
self.ALL_BASE_NAMES = ALL_BASE_NAMES fast_mode: bool = False,
verbose: bool = True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
self._all_base_names = ALL_BASE_NAMES
self.filename = None self.filename = None
self._search_space_name = 'size' self._search_space_name = 'size'
self._fast_mode = fast_mode self._fast_mode = fast_mode
@ -67,25 +86,36 @@ class NATSsize(NASBenchMetaAPI):
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else: else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) file_path_or_dict = os.path.join(
print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) os.environ['TORCH_HOME'], '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (size) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
file_path_or_dict))
if isinstance(file_path_or_dict, str): if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict) file_path_or_dict = str(file_path_or_dict)
if verbose: if verbose:
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) print('{:} Try to create the NATS-Bench (size) api '
if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): 'from {:} with fast_mode={:}'.format(
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) time_string(), file_path_or_dict, fast_mode))
if not nats_is_file(file_path_or_dict) and not nats_is_dir(
file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(
file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict) self.filename = os.path.basename(file_path_or_dict)
if fast_mode: if fast_mode:
if nats_is_file(file_path_or_dict): if nats_is_file(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for directory '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
self._archive_dir = file_path_or_dict self._archive_dir = file_path_or_dict
else: else:
if nats_is_dir(file_path_or_dict): if nats_is_dir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for file '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
file_path_or_dict = pickle_load(file_path_or_dict) file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict): elif isinstance(file_path_or_dict, dict):
@ -93,68 +123,95 @@ class NATSsize(NASBenchMetaAPI):
self.verbose = verbose self.verbose = verbose
if isinstance(file_path_or_dict, dict): if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) for key in keys:
if key not in file_path_or_dict:
raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
self.arch2infos_dict = OrderedDict() # where the key is #epochs and the value is ArchResults
self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey] all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict() hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items(): for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results) hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None: elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT)) benchmark_meta = pickle_load('{:}/meta.{:}'.format(
self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict() self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
self.evaluated_indexes = set() self.evaluated_indexes = set()
else: else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict))) raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {} self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs): for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) if arch in self.archstr2index:
raise ValueError('This [{:}]-th arch {:} already in the '
'dict ({:}).'.format(
idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx self.archstr2index[arch] = idx
if self.verbose: if self.verbose:
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format( print('{:} Create NATS-Bench (size) done with {:}/{:} architectures '
time_string(), len(self.evaluated_indexes), len(self.meta_archs))) 'avaliable.'.format(time_string(),
len(self.evaluated_indexes),
len(self.meta_archs)))
def query_info_str_by_arch(self, arch, hp: Text = '12'):
  """Query the information of a specific architecture.

  Args:
    arch: it can be an architecture index or an architecture string.
    hp: the hyperparameter indicator, could be 01, 12, or 90. The difference
      between these three configurations is the number of training epochs.

  Returns:
    The result of `self._query_info_str_by_arch(arch, hp, print_information)`.
  """
  if self.verbose:
    # Adjacent string literals are concatenated without any separator, so the
    # space in front of "and" has to be part of the message itself.
    print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(
        time_string(), arch, hp))
  return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True): def get_more_info(self,
"""This function will return the metric for the `index`-th architecture index,
`dataset` indicates the dataset: dataset,
iepoch=None,
hp: Text = '12',
is_random: bool = True):
"""Return the metric for the `index`-th architecture.
Args:
index: the architecture index.
dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
'cifar100' : using the proposed train set of CIFAR-100 as the training set 'cifar100' : using the proposed train set of CIFAR-100 as the training set
'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
`iepoch` indicates the index of training epochs from 0 to 11/199. iepoch: the index of training epochs from 0 to 11/199.
When iepoch=None, it will return the metric for the last training epoch When iepoch=None, it will return the metric for the last training epoch
When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0) When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
`hp` indicates different hyper-parameters for training hp: indicates different hyper-parameters for training
When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs
When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs
`is_random` is_random:
When is_random=True, the performance of a random architecture will be returned When is_random=True, the performance of a random architecture will be returned
When is_random=False, the performanceo of all trials will be averaged. When is_random=False, the performanceo of all trials will be averaged.
Returns:
a dict, where key is the metric name and value is its value.
""" """
if self.verbose: if self.verbose:
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format( print('{:} Call the get_more_info function with index={:}, dataset={:}, '
time_string(), index, dataset, iepoch, hp, is_random)) 'iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index) self._prepare_info(index)
if index not in self.arch2infos_dict: if index not in self.arch2infos_dict:
@ -165,38 +222,47 @@ class NATSsize(NASBenchMetaAPI):
seeds = archresult.get_dataset_seeds(dataset) seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds) is_random = random.choice(seeds)
# collect the training information # collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) train_info = archresult.get_metrics(
dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1 total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'], xinfo = {
'train-accuracy': train_info['accuracy'], 'train-loss': train_info['loss'],
'train-per-time': train_info['all_time'] / total, 'train-accuracy': train_info['accuracy'],
'train-all-time': train_info['all_time']} 'train-per-time': train_info['all_time'] / total,
'train-all-time': train_info['all_time']
}
# collect the evaluation information # collect the evaluation information
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(
dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try: try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
except: dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
valtest_info = None valtest_info = None
else: else:
try: # collect results on the proposed test set try: # collect results on the proposed test set
if dataset == 'cifar10': if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
except: dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
try: # collect results on the proposed validation set try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(
except: dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
valid_info = None valid_info = None
try: try:
if dataset != 'cifar10': if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) valtest_info = archresult.get_metrics(
dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
valtest_info = None valtest_info = None
except: except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None valtest_info = None
if valid_info is not None: if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss'] xinfo['valid-loss'] = valid_info['loss']
@ -216,11 +282,5 @@ class NATSsize(NASBenchMetaAPI):
return xinfo return xinfo
def show(self, index: int = -1) -> None:
  """Print the information of a specific (or all) architecture(s).

  The work is delegated to `self._show`, with `print_information` of this
  module used as the formatter.

  Args:
    index: if negative, the information of all architectures is printed one
      by one; otherwise only the `index`-th architecture is printed.
  """
  formatter = print_information
  self._show(index, formatter)

View File

@ -0,0 +1,59 @@
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
"""This file is used to quickly test the API."""
import random
from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology
def test_nats_bench_tss(benchmark_dir):
  """Run the quick API test on the topology search space (tss)."""
  return test_nats_bench(benchmark_dir, is_tss=True)
def test_nats_bench_sss(benchmark_dir):
  """Run the quick API test on the size search space (sss)."""
  return test_nats_bench(benchmark_dir, is_tss=False)
def test_nats_bench(benchmark_dir, is_tss, verbose=False):
  """Exercise the main API queries on ten randomly drawn architectures.

  Args:
    benchmark_dir: the benchmark path handed to the API constructor.
    is_tss: if true, NATStopology is tested; otherwise NATSsize.
    verbose: forwarded to the API constructor.
  """
  api_class = NATStopology if is_tss else NATSsize
  # The second positional argument switches on fast_mode.
  api = api_class(benchmark_dir, True, verbose)

  test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]

  # (dataset key used by the API, name used in the printed messages)
  datasets = (
      ('cifar10', 'CIFAR-10'),
      ('cifar100', 'CIFAR-100'),
      ('ImageNet16-120', 'ImageNet16-120'),
  )

  for arch_index in test_indexes:
    print('\n\nEvaluate the {:5d}-th architecture.'.format(arch_index))

    for key, title in datasets:
      # Loss / accuracy / time of the candidate on the current dataset; the
      # result is a dict with self-explanatory keys.
      performance = api.get_more_info(arch_index, key)
      print(' -->> The performance on {:}: {:}'.format(title, performance))

      # FLOPs, parameters and latency, again as a dict.
      cost = api.get_cost_info(arch_index, key)
      print(' -->> The cost info on {:}: {:}'.format(title, cost))

      # Simulate the training of the candidate with the 12-epoch setting.
      accuracy, latency, step_cost, total_cost = api.simulate_train_eval(
          arch_index, dataset=key, hp='12')
      message = (' -->> The validation accuracy={:}, latency={:}, '
                 'the current time cost={:} s, accumulated time cost={:} s')
      print(message.format(accuracy, latency, step_cost, total_cost))

      # Configuration of the candidate on the current dataset.
      net_config = api.get_net_config(arch_index, key)
      print(' -->> The configuration on {:} is {:}'.format(title, net_config))

    # Once per architecture, after every dataset has been queried.
    api.show(arch_index)

View File

@ -2,61 +2,83 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##################################################################################### ##############################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) # # The history of benchmark files are as follows, #
# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 # # where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
##################################################################################### # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
import os, copy, random, numpy as np ##############################################################################
from typing import List, Text, Union, Dict, Optional # pylint: disable=line-too-long
from collections import OrderedDict, defaultdict """The API for topology search space in NATS-Bench."""
import warnings import collections
from .api_utils import time_string import copy
from .api_utils import pickle_load import os
from .api_utils import ArchResults import random
from .api_utils import NASBenchMetaAPI from typing import Any, Dict, List, Optional, Text, Union
from .api_utils import remap_dataset_set_names
from .api_utils import nats_is_dir from nats_bench.api_utils import ArchResults
from .api_utils import nats_is_file from nats_bench.api_utils import NASBenchMetaAPI
from .api_utils import PICKLE_EXT from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string
import numpy as np
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9'] ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
def print_information(information, extra_info=None, show=False):
  """Summarize the results of one architecture as a list of text lines.

  Args:
    information: an object providing `arch_str`, `get_dataset_names()`,
      `get_compute_costs(dataset)` and `get_metrics(dataset, setname)`.
    extra_info: any value; it is embedded into the second summary line.
    show: if true, the lines are also printed, joined by newlines.

  Returns:
    A list of strings: the architecture string, a header line, and then two
    lines (compute cost and loss/accuracy) for every dataset.
  """

  def metric2str(loss, acc):
    return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)

  names = information.get_dataset_names()
  lines = [
      information.arch_str,
      'datasets : {:}, extra-info : {:}'.format(names, extra_info)
  ]
  for name in names:
    costs = information.get_compute_costs(name)
    flops, params, latency = costs['flops'], costs['params'], costs['latency']
    # The latency is scaled by 1000 for the "ms" display; a missing or
    # non-positive latency is shown as None.
    if latency is not None and latency > 0:
      latency_text = '{:.2f}'.format(latency * 1000)
    else:
      latency_text = None
    cost_line = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
        name, flops, params, latency_text)
    # Pick the splits to report (and the matching template) per dataset; for
    # 'cifar10-valid' only the train and validation splits are reported.
    if name == 'cifar10-valid':
      template = '{:14s} train : [{:}], valid : [{:}]'
      setnames = ('train', 'x-valid')
    elif name == 'cifar10':
      template = '{:14s} train : [{:}], test : [{:}]'
      setnames = ('train', 'ori-test')
    else:
      template = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'
      setnames = ('train', 'x-valid', 'x-test')
    infos = [information.get_metrics(name, setname) for setname in setnames]
    texts = [metric2str(info['loss'], info['accuracy']) for info in infos]
    lines.extend((cost_line, template.format(name, *texts)))
  if show:
    print('\n'.join(lines))
  return lines
"""
This is the class for the API of topology search space in NATS-Bench.
"""
class NATStopology(NASBenchMetaAPI): class NATStopology(NASBenchMetaAPI):
"""This is the class for the API of topology search space in NATS-Bench."""
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self,
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
self.ALL_BASE_NAMES = ALL_BASE_NAMES fast_mode: bool = False,
verbose: bool = True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
self._all_base_names = ALL_BASE_NAMES
self.filename = None self.filename = None
self._search_space_name = 'topology' self._search_space_name = 'topology'
self._fast_mode = fast_mode self._fast_mode = fast_mode
@ -64,25 +86,35 @@ class NATStopology(NASBenchMetaAPI):
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else: else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) file_path_or_dict = os.path.join(
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) os.environ['TORCH_HOME'], '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (topology) path '
'from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str): if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict) file_path_or_dict = str(file_path_or_dict)
if verbose: if verbose:
print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) print('{:} Try to create the NATS-Bench (topology) api '
if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): 'from {:} with fast_mode={:}'.format(
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) time_string(), file_path_or_dict, fast_mode))
if not nats_is_file(file_path_or_dict) and not nats_is_dir(
file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(
file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict) self.filename = os.path.basename(file_path_or_dict)
if fast_mode: if fast_mode:
if nats_is_file(file_path_or_dict): if nats_is_file(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for directory '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
self._archive_dir = file_path_or_dict self._archive_dir = file_path_or_dict
else: else:
if nats_is_dir(file_path_or_dict): if nats_is_dir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for file '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
file_path_or_dict = pickle_load(file_path_or_dict) file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict): elif isinstance(file_path_or_dict, dict):
@ -90,65 +122,73 @@ class NATStopology(NASBenchMetaAPI):
self.verbose = verbose self.verbose = verbose
if isinstance(file_path_or_dict, dict): if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) for key in keys:
if key not in file_path_or_dict:
raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
self.arch2infos_dict = OrderedDict() # where the key is #epochs and the value is ArchResults
self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey] all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict() hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items(): for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results) hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None: elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT)) benchmark_meta = pickle_load('{:}/meta.{:}'.format(
self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict() self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
self.evaluated_indexes = set() self.evaluated_indexes = set()
else: else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict))) raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {} self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs): for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) if arch in self.archstr2index:
raise ValueError('This [{:}]-th arch {:} already in the '
'dict ({:}).'.format(
idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx self.archstr2index[arch] = idx
if self.verbose: if self.verbose:
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format( print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures '
time_string(), len(self.evaluated_indexes), len(self.meta_archs))) 'avaliable.'.format(time_string(),
len(self.evaluated_indexes),
len(self.meta_archs)))
def query_info_str_by_arch(self, arch, hp: Text = '12'):
  """Query the information string of a specific architecture.

  Args:
    arch: an architecture index or an architecture string.
    hp: the hyper-parameter indicator, '12' or '200'. With hp='12' the
      model was trained with the 12-epoch configuration and with hp='200'
      with the 200-epoch configuration; the configurations differ in the
      number of training epochs.

  Returns:
    Whatever `self._query_info_str_by_arch` returns when it is given
    `print_information` as the formatter.
    NOTE(review): presumably the formatted information string(s) of the
    architecture -- confirm against the base class.
  """
  if self.verbose:
    # The two adjacent literals are concatenated by the parser, so the
    # separating space has to live inside one of them.
    print('{:} Call query_info_str_by_arch with arch={:} '
          'and hp={:}'.format(time_string(), arch, hp))
  return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture def get_more_info(self,
# `dataset` indicates the dataset: index,
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set dataset,
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set iepoch=None,
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set hp: Text = '12',
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set is_random: bool = True):
# `iepoch` indicates the index of training epochs from 0 to 11/199. """Return the metric for the `index`-th architecture."""
# When iepoch=None, it will return the metric for the last training epoch
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
# `use_12epochs_result` indicates different hyper-parameters for training
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose: if self.verbose:
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format( print('{:} Call the get_more_info function with index={:}, dataset={:}, '
time_string(), index, dataset, iepoch, hp, is_random)) 'iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index) self._prepare_info(index)
if index not in self.arch2infos_dict: if index not in self.arch2infos_dict:
@ -161,36 +201,43 @@ class NATStopology(NASBenchMetaAPI):
# collect the training information # collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1 total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'], xinfo = {
'train-accuracy': train_info['accuracy'], 'train-loss':
'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None, train_info['loss'],
'train-all-time': train_info['all_time']} 'train-accuracy':
train_info['accuracy'],
'train-per-time':
train_info['all_time'] /
total if train_info['all_time'] is not None else None,
'train-all-time':
train_info['all_time']
}
# collect the evaluation information # collect the evaluation information
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try: try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
valtest_info = None valtest_info = None
else: else:
try: # collect results on the proposed test set try: # collect results on the proposed test set
if dataset == 'cifar10': if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
try: # collect results on the proposed validation set try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
valid_info = None valid_info = None
try: try:
if dataset != 'cifar10': if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
valtest_info = None valtest_info = None
except: except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None valtest_info = None
if valid_info is not None: if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss'] xinfo['valid-loss'] = valid_info['loss']
@ -214,46 +261,52 @@ class NATStopology(NASBenchMetaAPI):
self._show(index, print_information) self._show(index, print_information)
@staticmethod
def str2lists(arch_str: Text) -> List[Any]:
  """Parse the string-based architecture encoding into nested tuples.

  It is the same as the `str2structure` func in AutoDL-Projects:
  `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py`

  Args:
    arch_str: a string that indicates the architecture topology, such as
      |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
      Nodes are separated by '+', the inputs of one node by '|', and each
      input is written as `op~input_node_index`.

  Returns:
    A list with one entry per '+'-separated node; each entry is a tuple of
    (op, input_node_index) pairs, where op is a string and
    input_node_index is an int.

  Raises:
    ValueError: if an input is not of the form `op~index`, or if the index
      is not an integer literal.

  [USAGE]
  ```
  arch = api.str2lists('|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
  print('there are {:} nodes in this arch'.format(len(arch)+1))  # arch is a list
  for i, node in enumerate(arch):
    print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
  ```
  """
  genotypes = []
  for node_str in arch_str.split('+'):
    # Splitting on '|' leaves empty strings at both ends; drop them.
    inputs = [xinput for xinput in node_str.split('|') if xinput]
    pairs = [xinput.split('~') for xinput in inputs]
    # Validate the whole node first; an explicit raise (rather than an
    # assert) keeps the check alive when Python runs with -O.
    for xinput, pair in zip(inputs, pairs):
      if len(pair) != 2:
        raise ValueError('invalid input length : {:}'.format(xinput))
    genotypes.append(tuple((op, int(idx)) for op, idx in pairs))
  return genotypes
@staticmethod
def str2matrix(arch_str: Text,
               search_space: List[Text] = ('none', 'skip_connect',
                                           'nor_conv_1x1', 'nor_conv_3x3',
                                           'avg_pool_3x3')) -> np.ndarray:
  """Convert the string-based encoding to the NAS-Bench-101 style matrix.

  Args:
    arch_str: a string that indicates the architecture topology, such as
      |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
    search_space: a sequence of operation strings; the position of an
      operation in this sequence is the value written into the matrix. The
      default is the topology search space for NATS-BENCH and should be
      consistent with this line
      https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
      The default is an immutable tuple so that it cannot be modified
      across calls.

  Returns:
    The numpy matrix (2-D np.ndarray created by np.zeros) representing the
    DAG of this architecture topology.

  Raises:
    ValueError: if an input is not of the form `op~index`, if `op` is not
      in `search_space`, or if the index is not an integer literal.

  [USAGE]
    matrix = api.str2matrix('|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
    This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the
    lower left triangle is useful).
      [ [0, 0, 0, 0],  # the first line represents the input (0-th) node
        [2, 0, 0, 0],  # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
        [0, 0, 0, 0],  # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
        [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
    In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op
    is 'skip_connect', 2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3',
    4-th-op is 'avg_pool_3x3'.

  [NOTE]
    If a node has two input-edges from the same node, this function does
    not work. One edge will be overlapped.
  """
  node_strs = arch_str.split('+')
  # One extra row/column for the input (0-th) node.
  num_nodes = len(node_strs) + 1
  matrix = np.zeros((num_nodes, num_nodes))
  for i, node_str in enumerate(node_strs):
    # Splitting on '|' leaves empty strings at both ends; drop them.
    inputs = [xinput for xinput in node_str.split('|') if xinput]
    pairs = [xinput.split('~') for xinput in inputs]
    # Validate the whole node first; an explicit raise (rather than an
    # assert) keeps the check alive when Python runs with -O and matches
    # the ValueError used for unknown operations below.
    for xinput, pair in zip(inputs, pairs):
      if len(pair) != 2:
        raise ValueError('invalid input length : {:}'.format(xinput))
    for op, idx in pairs:
      if op not in search_space:
        raise ValueError('this op ({:}) is not in {:}'.format(
            op, search_space))
      # Row i+1 is the node being defined, the column is its input node.
      matrix[i + 1, int(idx)] = search_space.index(op)
  return matrix

File diff suppressed because it is too large Load Diff