update NAS-Bench-102

D-X-Y 2019-12-26 23:29:36 +11:00
parent 1d5e8debad
commit d791622b63
3 changed files with 116 additions and 2 deletions

View File

@@ -16,7 +16,8 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
You can move it to anywhere you want and send its path to our API for initialization.
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- v1.0: The full data of each architecture can be downloaded from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or train models by yourself, you need these data.
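For example, a minimal sketch (not part of this commit; it only assumes the recommended `$TORCH_HOME` layout above) of resolving the benchmark file before handing its path to the API:
```
import os
# Fall back to ~/.torch/ when $TORCH_HOME is unset, matching the default mentioned above.
torch_home = os.environ.get('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.torch'))
benchmark_file = os.path.join(torch_home, 'NAS-Bench-102-v1_0-e61699.pth')
assert os.path.isfile(benchmark_file), 'cannot find the benchmark file : {:}'.format(benchmark_file)
```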
@@ -108,8 +109,12 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss
`NASBench102API` is the topmost-level API. Please see the following usages:
```
from nas_102_api import NASBench102API as API
api = API('NAS-Bench-102-v1_0-e61699.pth') # This will load all the information of NAS-Bench-102 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth')) # The same as the above line; I usually save NAS-Bench-102-v1_0-e61699.pth in ~/.torch/.
api.show(-1) # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-102-4-v1.0-archive'), 3) # This will reload the information of the 3rd architecture with the trained weights
weights = api.get_net_param(3, 'cifar10', None) # Obtain the weights of all trials for the 3rd architecture on cifar10. It returns a dict, where the key is the seed and the value is the trained weights.
```
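As a follow-up sketch (illustrative, not from this commit; the network construction is deliberately left out), the dict returned by `get_net_param` can be iterated per seed and loaded with the standard `load_state_dict`:
```
weights = api.get_net_param(3, 'cifar10', None)  # {seed: state_dict}
for seed, state_dict in weights.items():
  print('seed {:} -> {:} parameter tensors'.format(seed, len(state_dict)))
  # net.load_state_dict(state_dict)  # `net` must be built with the topology from api.arch(3); omitted here
```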

View File

@@ -0,0 +1,84 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-102/check.py --base_save_dir
##################################################
import os, sys, time, argparse, collections
from shutil import copyfile
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time

def check_files(save_dir, meta_file, basestr):
  # check which architectures have been evaluated: collect every arch-*-seed-*.pth checkpoint under save_dir and count the seeds per architecture
  meta_infos     = torch.load(meta_file, map_location='cpu')
  meta_archs     = meta_infos['archs']
  meta_num_archs = meta_infos['total']
  meta_max_node  = meta_infos['max_node']
  assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

  sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
  print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

  subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
  num_seeds = defaultdict(lambda: 0)
  for index, sub_dir in enumerate(sub_model_dirs):
    xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
    #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth'))
    arch_indexes = set()
    for checkpoint in xcheckpoints:
      temp_names = checkpoint.name.split('-')
      assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
      arch_indexes.add( temp_names[1] )
    subdir2archs[sub_dir] = sorted(list(arch_indexes))
    num_evaluated_arch   += len(arch_indexes)
    # count the number of seeds for each architecture
    for arch_index in arch_indexes:
      num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
  print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
  for key in sorted( list( num_seeds.keys() ) ):
    print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

  dir2ckps, dir2ckp_exists = dict(), dict()
  start_time, epoch_time = time.time(), AverageMeter()
  for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
    seeds = [777, 888, 999]
    numrs = defaultdict(lambda: 0)
    all_checkpoints, all_ckp_exists = [], []
    for arch_index in arch_indexes:
      checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
      ckp_exists  = [(sub_dir/x).exists() for x in checkpoints]
      arch_index  = int(arch_index)
      assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
      all_checkpoints += checkpoints
      all_ckp_exists  += ckp_exists
      numrs[ sum(ckp_exists) ] += 1
    dir2ckps[ str(sub_dir) ]       = all_checkpoints
    dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
    print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS Benchmark 102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-102-4', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--max_node',      type=int, default=4,  help='The maximum node in a cell.')
  parser.add_argument('--channel',       type=int, default=16, help='The number of channels.')
  parser.add_argument('--num_cells',     type=int, default=5,  help='The number of cells in one stage.')
  args = parser.parse_args()

  save_dir  = Path( args.base_save_dir )
  meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
  assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir)
  assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
  print ('check NAS-Bench-102 in {:}'.format(save_dir))

  basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
  check_files(save_dir, meta_path, basestr)
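As an aside, a tiny hypothetical check of the `arch-*-seed-*.pth` naming convention that `check_files` parses (the index and seed values are illustrative, not from this commit):
```
name = 'arch-{:}-seed-{:04d}.pth'.format(15624, 777)
assert name == 'arch-15624-seed-0777.pth'
temp_names = name.split('-')  # ['arch', '15624', 'seed', '0777.pth'] -- the 4 fields asserted in check_files
```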

View File

@@ -78,6 +78,16 @@ class NASBench102API(object):
      else : arch_index = -1
    else: arch_index = -1
    return arch_index

  # reload the full information (including the trained weights) of the `index`-th architecture from the archive directory
  def reload(self, archive_root, index):
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
    assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
    xdata = torch.load(xfile_path)
    assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
    self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
    self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )

  def query_by_arch(self, arch, use_12epochs_result=False):
    if isinstance(arch, int):
@@ -125,10 +135,18 @@
        best_index, highest_accuracy = idx, accuracy
    return best_index

  # return the topology structure of the `index`-th architecture
  def arch(self, index):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  # obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
  def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
    return archresult.get_net_param(dataset, seed)

  def get_more_info(self, index, dataset, use_12epochs_result=False):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
@@ -238,6 +256,13 @@ class ArchResults(object):
  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  # return the trained parameters: a single state dict for a given seed, or a {seed: state dict} dict over all trials when seed is None
  def get_net_param(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
    else:
      return self.all_results[(dataset, seed)].get_net_param()

  def query(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
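Finally, an illustrative sketch of the dict-vs-single behavior of `ArchResults.get_net_param` added above, using the `archRes` object from the README snippet quoted in the hunk header (the seed values mirror `[777, 888, 999]` from check.py; nothing here is part of the diff itself):
```
params_all = archRes.get_net_param('cifar10')       # seed=None -> a dict like {777: state_dict, 888: state_dict, 999: state_dict}
params_one = archRes.get_net_param('cifar10', 777)  # the state dict of the single trial trained with seed 777
```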