Update weight watcher codes
This commit is contained in:
parent
6facc39a42
commit
7d02870bf8
@ -4,3 +4,4 @@
|
|||||||
- [2019.12.20] [69ca086] Release NAS-Bench-201.
|
- [2019.12.20] [69ca086] Release NAS-Bench-201.
|
||||||
- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released.
|
- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released.
|
||||||
- [2019.01.31] [13e908f] GDAS codes were publicly released.
|
- [2019.01.31] [13e908f] GDAS codes were publicly released.
|
||||||
|
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
|
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
|
||||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
|
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
|
||||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
|
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
|
||||||
|
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10
|
||||||
###########################################################################################################################################################
|
###########################################################################################################################################################
|
||||||
import os, gc, sys, math, argparse, psutil
|
import os, gc, sys, math, argparse, psutil
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -411,7 +411,11 @@ class ArchResults(object):
|
|||||||
x_seeds = self.dataset_seed[dataset]
|
x_seeds = self.dataset_seed[dataset]
|
||||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||||
else:
|
else:
|
||||||
return self.all_results[(dataset, seed)].get_net_param()
|
xkey = (dataset, seed)
|
||||||
|
if xkey in self.all_results:
|
||||||
|
return self.all_results[xkey].get_net_param()
|
||||||
|
else:
|
||||||
|
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
|
||||||
|
|
||||||
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
||||||
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
||||||
|
Loading…
Reference in New Issue
Block a user