Fix bugs in test-ww

This commit is contained in:
D-X-Y 2020-03-21 18:24:48 -07:00
parent 87545c4477
commit b29bba159a

View File

@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, argparse, psutil
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append(summary['lognorm'])
norms.append( float(np.mean(cur_norms)) )
cur_norms.append(-summary['lognorm'])
cur_norm = float(np.mean(cur_norms))
if math.isnan(cur_norm):
print (' IGNORE {:} due to nan.'.format(idx))
continue
norms.append(cur_norm)
api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api))