Fix minor bugs in test-ww.py

This commit is contained in:
D-X-Y 2020-03-21 12:13:13 -07:00
parent 22025887f1
commit 87545c4477
2 changed files with 7 additions and 4 deletions

View File

@ -1,7 +1,9 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
##################################################### #####################################################
# [2020.03.09] Upgrade to v1.2 # [2020.02.25] Initialize the API as v1.1
# [2020.03.09] Upgrade the API to v1.2
# [2020.03.16] Upgrade the API to v1.3
import os import os
from setuptools import setup from setuptools import setup
@ -13,7 +15,7 @@ def read(fname='README.md'):
setup( setup(
name = "nas_bench_201", name = "nas_bench_201",
version = "1.2", version = "1.3",
author = "Xuanyi Dong", author = "Xuanyi Dong",
author_email = "dongxuanyi888@gmail.com", author_email = "dongxuanyi888@gmail.com",
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",

View File

@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)): for idx in range(len(api)):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) # info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
# import pdb; pdb.set_trace()
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']: for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False) info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
if key == 'cifar10-valid': if key == 'cifar10-valid':
@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
config = api.get_net_config(idx, data) config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config) net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx) api.reload(weight_dir, idx)
params = api.get_net_param(idx, data, None) params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
cur_norms = [] cur_norms = []
for seed, param in params.items(): for seed, param in params.items():
with torch.no_grad(): with torch.no_grad():