From b29bba159a46d20bd1835c51c6c578f5d6a36b7a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 21 Mar 2020 18:24:48 -0700 Subject: [PATCH] Fix bugs in test-ww --- exps/NAS-Bench-201/test-weights.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/exps/NAS-Bench-201/test-weights.py b/exps/NAS-Bench-201/test-weights.py index abc973f..20db0c0 100644 --- a/exps/NAS-Bench-201/test-weights.py +++ b/exps/NAS-Bench-201/test-weights.py @@ -6,7 +6,7 @@ # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 ############################################################################################### -import os, gc, sys, argparse, psutil +import os, gc, sys, math, argparse, psutil import numpy as np import torch from pathlib import Path @@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): with torch.no_grad(): net.load_state_dict(param) _, summary = weight_watcher.analyze(net, alphas=False) - cur_norms.append(summary['lognorm']) - norms.append( float(np.mean(cur_norms)) ) + cur_norms.append(-summary['lognorm']) + cur_norm = float(np.mean(cur_norms)) + if math.isnan(cur_norm): + print (' IGNORE {:} due to nan.'.format(idx)) + continue + norms.append(cur_norm) api.clear_params(idx, None) if idx % 200 == 199 or idx + 1 == len(api): head = '{:05d}/{:05d}'.format(idx, len(api))