Fix bugs in test-ww
This commit is contained in:
parent
87545c4477
commit
b29bba159a
@ -6,7 +6,7 @@
|
|||||||
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
||||||
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
||||||
###############################################################################################
|
###############################################################################################
|
||||||
import os, gc, sys, argparse, psutil
|
import os, gc, sys, math, argparse, psutil
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
net.load_state_dict(param)
|
net.load_state_dict(param)
|
||||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||||
cur_norms.append(summary['lognorm'])
|
cur_norms.append(-summary['lognorm'])
|
||||||
norms.append( float(np.mean(cur_norms)) )
|
cur_norm = float(np.mean(cur_norms))
|
||||||
|
if math.isnan(cur_norm):
|
||||||
|
print (' IGNORE {:} due to nan.'.format(idx))
|
||||||
|
continue
|
||||||
|
norms.append(cur_norm)
|
||||||
api.clear_params(idx, None)
|
api.clear_params(idx, None)
|
||||||
if idx % 200 == 199 or idx + 1 == len(api):
|
if idx % 200 == 199 or idx + 1 == len(api):
|
||||||
head = '{:05d}/{:05d}'.format(idx, len(api))
|
head = '{:05d}/{:05d}'.format(idx, len(api))
|
||||||
|
Loading…
Reference in New Issue
Block a user