Sync NATS-Bench v1.0 and update algorithm names

D-X-Y committed on 2020-10-15 21:56:10 +11:00
parent 10e5f05935
commit 7d55192d83
7 changed files with 28 additions and 26 deletions

@@ -6,3 +6,4 @@
 - [2019.01.31] [13e908f] GDAS codes were publicly released.
 - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
 - [2020.09.16] [7052265] Create NATS-BENCH.
+- [2020.10.15] [ ] Update NATS-BENCH to version 1.0

@@ -7,6 +7,7 @@ We analyze the validity of our benchmark in terms of various criteria and perfor
 We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
 This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
+**You can use `pip install nats_bench` to install the library of NATS-Bench.**
 The structure of this Markdown file:
 - [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
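
As a quick sanity check of the newly advertised `pip install nats_bench` route, a minimal query sketch (assumptions: the benchmark file has been downloaded as the README describes, and `get_more_info` keeps this signature in v1.0):

```
from nats_bench import create

# Build the query API for the size search space ('sss').
# Passing None lets the library look up the default benchmark file.
api = create(None, 'sss', fast_mode=True, verbose=False)

# Query the recorded results of architecture index 12 on CIFAR-10.
info = api.get_more_info(12, 'cifar10')
print(info['test-accuracy'])
```
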
@@ -175,18 +176,18 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO
 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
-Run the channel search strategy in FBNet-V2
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
-Run the channel search strategy in TuNAS:
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax:
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
+Run the channel search strategy in TuNAS -- masking + sampling:
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
 ```
 ### Final Discovered Architectures for Each Algorithm
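
For readers comparing the two renamed algorithms above, here is a rough standalone sketch of what `mask_gumbel` and `mask_rl` do, distilled from the model code later in this commit (toy channel counts, not the repository's exact tensors):

```
import torch
import torch.nn as nn

candidate_Cs = [8, 16, 24]                       # toy candidate channel counts
arch_param = nn.Parameter(1e-3 * torch.randn(1, len(candidate_Cs)))
masks = torch.zeros(len(candidate_Cs), max(candidate_Cs))
for i, c in enumerate(candidate_Cs):
  masks[i, :c] = 1                               # row i keeps the first c channels

# mask_gumbel (FBNetV2-style): a differentiable soft mask via Gumbel-Softmax.
weights = nn.functional.gumbel_softmax(arch_param, tau=1.0, dim=-1)
soft_mask = torch.matmul(weights, masks)         # shape (1, 24)

# mask_rl (TuNAS-style): sample one candidate; the log-prob feeds REINFORCE.
prob = nn.functional.softmax(arch_param, dim=-1)
dist = torch.distributions.Categorical(prob)
action = dist.sample()                           # shape (1,)
hard_mask = masks[action.item()]                 # shape (24,)
log_prob = dist.log_prob(action)                 # scales the reward signal
```
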
@@ -250,7 +251,7 @@ GDAS:
 If you find that NATS-Bench helps your research, please consider citing it:
 ```
 @article{dong2020nats,
-  title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+  title={{NATS-Bench}: Benchmarking NAS algorithms for Architecture Topology and Size},
   author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
   journal={arXiv preprint arXiv:2009.00437},
   year={2020}

@@ -43,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
 from nats_bench import create
-# Ad-hoc for TuNAS
+# Ad-hoc for RL algorithms.
 class ExponentialMovingAverage(object):
   """Class that maintains an exponential moving average."""

@@ -44,8 +44,8 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
   # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
   # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
   alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
-  alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
-  alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
+  alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
+  alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
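
A hedged usage sketch for the path templates built above; the remaining `{:}` slot is the seed, and the seed value below is illustrative:

```
import torch

# 'seed-{:}-last-info.pth' is the tail of every alg2path entry, so
# formatting with a concrete seed yields a loadable checkpoint path.
path = alg2path['masking + sampling'].format(777)
last_info = torch.load(path, map_location='cpu')
```
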

@@ -3,8 +3,8 @@
 #####################################################
 # Here, we utilized three techniques to search for the number of channels:
 # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
 from typing import List, Text, Any
 import random, torch
 import torch.nn as nn
@@ -52,10 +52,10 @@ class GenericNAS301Model(nn.Module):
   def set_algo(self, algo: Text):
     # used for searching
     assert self._algo is None, 'This function can only be called once.'
-    assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
+    assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo)
     self._algo = algo
     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
-    # if algo == 'fbv2' or algo == 'tunas':
+    # if algo == 'mask_gumbel' or algo == 'mask_rl':
     self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
     for i in range(len(self._candidate_Cs)):
       self._masks.data[i, :self._candidate_Cs[i]] = 1
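
For concreteness, a toy view of the buffer that this loop fills (illustrative numbers only):

```
# With self._candidate_Cs = [2, 3, 4], the registered buffer becomes
#   _masks = [[1, 1, 0, 0],
#             [1, 1, 1, 0],
#             [1, 1, 1, 1]]
# i.e. row i keeps the first candidate_Cs[i] channels and zeroes the rest.
```
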
@@ -130,7 +130,7 @@ class GenericNAS301Model(nn.Module):
       else:
         mask = self._masks[random.randint(0, len(self._masks)-1)]
       feature = feature * mask.view(1, -1, 1, 1)
-    elif self._algo == 'fbv2':
+    elif self._algo == 'mask_gumbel':
       weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
       mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
       feature = feature * mask
@@ -148,7 +148,7 @@ class GenericNAS301Model(nn.Module):
       else:
         miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
         feature = torch.cat((out, miss), dim=1)
-    elif self._algo == 'tunas':
+    elif self._algo == 'mask_rl':
       prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
       dist = torch.distributions.Categorical(prob)
       action = dist.sample()
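
The hunk stops right after the sample is drawn. A hypothetical continuation (an assumption about code outside this diff, not a quote of it) shows why the branch is now called `mask_rl`: the sampled index picks a hard mask and its log-probability feeds a REINFORCE update:

```
# Hypothetical continuation of the mask_rl branch:
mask = self._masks[action.item()].view(1, -1, 1, 1)   # hard channel mask
feature = feature * mask
log_prob = dist.log_prob(action)                       # kept for the RL update
# Later, with reward = validation accuracy and the EMA baseline from above:
#   loss_rl = -log_prob * (reward - baseline.value())
```
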

@@ -939,9 +939,9 @@ class ArchResults(object):
     x.load_state_dict(state_dict)
     return x
-  # This function is used to clear the weights saved in each 'result'
-  # This can help reduce the memory footprint.
   def clear_params(self):
+    """Clear the weights saved in each 'result'."""
+    # NOTE(xuanyidong): This can help reduce the memory footprint.
     for unused_key, result in self.all_results.items():
       del result.net_state_dict
       result.net_state_dict = None
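
A brief, hedged usage note for the relocated documentation: once an `ArchResults` instance has been loaded, the call below frees the bulk of its memory while keeping every metric:

```
# Assuming arch_result is an ArchResults instance loaded elsewhere:
arch_result.clear_params()
# Each result keeps its metrics, but result.net_state_dict is now None,
# so re-serializing arch_result produces a much smaller file.
```
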

@@ -23,11 +23,11 @@ CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset
 CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
 #
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
 #
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}