Fix bugs in test-ww
This commit is contained in:
parent
87545c4477
commit
b29bba159a
@ -6,7 +6,7 @@
|
||||
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
||||
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
||||
###############################################################################################
|
||||
import os, gc, sys, argparse, psutil
|
||||
import os, gc, sys, math, argparse, psutil
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(param)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
cur_norms.append(summary['lognorm'])
|
||||
norms.append( float(np.mean(cur_norms)) )
|
||||
cur_norms.append(-summary['lognorm'])
|
||||
cur_norm = float(np.mean(cur_norms))
|
||||
if math.isnan(cur_norm):
|
||||
print (' IGNORE {:} due to nan.'.format(idx))
|
||||
continue
|
||||
norms.append(cur_norm)
|
||||
api.clear_params(idx, None)
|
||||
if idx % 200 == 199 or idx + 1 == len(api):
|
||||
head = '{:05d}/{:05d}'.format(idx, len(api))
|
||||
|
Loading…
Reference in New Issue
Block a user