Fix bugs in test-ww

This commit is contained in:
D-X-Y 2020-03-21 18:24:48 -07:00
parent 87545c4477
commit b29bba159a

View File

@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
############################################################################################### ###############################################################################################
import os, gc, sys, argparse, psutil import os, gc, sys, math, argparse, psutil
import numpy as np import numpy as np
import torch import torch
from pathlib import Path from pathlib import Path
@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad(): with torch.no_grad():
net.load_state_dict(param) net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False) _, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append(summary['lognorm']) cur_norms.append(-summary['lognorm'])
norms.append( float(np.mean(cur_norms)) ) cur_norm = float(np.mean(cur_norms))
if math.isnan(cur_norm):
print (' IGNORE {:} due to nan.'.format(idx))
continue
norms.append(cur_norm)
api.clear_params(idx, None) api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api): if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api)) head = '{:05d}/{:05d}'.format(idx, len(api))