diff --git a/exps/experimental/example-nas-bench.py b/exps/experimental/example-nas-bench.py new file mode 100644 index 0000000..8142f1d --- /dev/null +++ b/exps/experimental/example-nas-bench.py @@ -0,0 +1,50 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # +########################################################################################################################################################### +# Before run these commands, the files must be properly put. +# +# python exps/experimental/example-nas-bench.py --api_path $HOME/.torch/NAS-Bench-201-v1_1-096897.pth --archive_path $HOME/.torch/NAS-Bench-201-v1_1-archive +########################################################################################################################################################### +import os, gc, sys, math, argparse, psutil +import numpy as np +import torch +from pathlib import Path +from collections import OrderedDict +import matplotlib +import seaborn as sns +matplotlib.use('agg') +import matplotlib.pyplot as plt +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from nas_201_api import NASBench201API, NASBench301API +from log_utils import time_string +from models import get_cell_based_tiny_net +from utils import weight_watcher + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument('--api_path' , type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.') + parser.add_argument('--archive_path', type=str, default=None, help='The path to the NAS-Bench-201 weight dir.') + args = parser.parse_args() + + meta_file = Path(args.api_path) + weight_dir = Path(args.archive_path) + assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) + assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir) + + + api = NASBench201API(meta_file, verbose=True) + + arch_index = 3 # query the 3-th architecture + api.reload(weight_dir, arch_index) # reload the data of 3-th from archive dir + + data = 'cifar10' # query the info from CIFAR-10 + config = api.get_net_config(arch_index, data) + net = get_cell_based_tiny_net(config) + meta_info = api.query_meta_info_by_index(arch_index, hp='200') # all info about this architecture + params = meta_info.get_net_param(data, 888) + + net.load_state_dict(params) + _, summary = weight_watcher.analyze(net, alphas=False) + print('The summary of {:}-th architecture:\n{:}'.format(arch_index, summary))