Fix minor bugs in test-ww.py
This commit is contained in:
parent
22025887f1
commit
87545c4477
@ -1,7 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
# [2020.03.09] Upgrade to v1.2
|
||||
# [2020.02.25] Initialize the API as v1.1
|
||||
# [2020.03.09] Upgrade the API to v1.2
|
||||
# [2020.03.16] Upgrade the API to v1.3
|
||||
import os
|
||||
from setuptools import setup
|
||||
|
||||
@ -13,7 +15,7 @@ def read(fname='README.md'):
|
||||
|
||||
setup(
|
||||
name = "nas_bench_201",
|
||||
version = "1.2",
|
||||
version = "1.3",
|
||||
author = "Xuanyi Dong",
|
||||
author_email = "dongxuanyi888@gmail.com",
|
||||
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
||||
|
@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||
# info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||
# import pdb; pdb.set_trace()
|
||||
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
|
||||
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
|
||||
if key == 'cifar10-valid':
|
||||
@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
config = api.get_net_config(idx, data)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
api.reload(weight_dir, idx)
|
||||
params = api.get_net_param(idx, data, None)
|
||||
params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
|
||||
cur_norms = []
|
||||
for seed, param in params.items():
|
||||
with torch.no_grad():
|
||||
|
Loading…
Reference in New Issue
Block a user