Fix minor bugs in test-ww.py
This commit is contained in:
parent
22025887f1
commit
87545c4477
@ -1,7 +1,9 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# [2020.03.09] Upgrade to v1.2
|
# [2020.02.25] Initialize the API as v1.1
|
||||||
|
# [2020.03.09] Upgrade the API to v1.2
|
||||||
|
# [2020.03.16] Upgrade the API to v1.3
|
||||||
import os
|
import os
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
@ -13,7 +15,7 @@ def read(fname='README.md'):
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name = "nas_bench_201",
|
name = "nas_bench_201",
|
||||||
version = "1.2",
|
version = "1.3",
|
||||||
author = "Xuanyi Dong",
|
author = "Xuanyi Dong",
|
||||||
author_email = "dongxuanyi888@gmail.com",
|
author_email = "dongxuanyi888@gmail.com",
|
||||||
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
||||||
|
@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
|||||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||||
for idx in range(len(api)):
|
for idx in range(len(api)):
|
||||||
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
# info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
|
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
|
||||||
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
|
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
|
||||||
if key == 'cifar10-valid':
|
if key == 'cifar10-valid':
|
||||||
@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
|||||||
config = api.get_net_config(idx, data)
|
config = api.get_net_config(idx, data)
|
||||||
net = get_cell_based_tiny_net(config)
|
net = get_cell_based_tiny_net(config)
|
||||||
api.reload(weight_dir, idx)
|
api.reload(weight_dir, idx)
|
||||||
params = api.get_net_param(idx, data, None)
|
params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
|
||||||
cur_norms = []
|
cur_norms = []
|
||||||
for seed, param in params.items():
|
for seed, param in params.items():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Loading…
Reference in New Issue
Block a user