This commit is contained in:
Jack Turner
2021-02-26 16:12:51 +00:00
parent c895924c99
commit b74255e1f3
74 changed files with 11326 additions and 537 deletions

21
LICENCE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2020 BayesWatch
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,63 +1,31 @@
# [Neural Architecture Search Without Training](https://arxiv.org/abs/2006.04647)
# Neural Architecture Search Without Training
This repository contains code for replicating our paper, [NAS Without Training](https://arxiv.org/abs/2006.04647).
> :warning: Note: this repository has been updated to reflect the second version of the paper, which will appear on arXiv on 1 March. :warning:
## Setup
## Usage
1. Download the [datasets](https://drive.google.com/drive/folders/1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7).
2. Download [NAS-Bench-201](https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view).
3. Install the requirements in a conda environment with `conda env create -f environment.yml`.
Create a conda environment using the `env.yml` file:
```bash
conda env create -f env.yml
```
We also refer the reader to instructions in the official [NAS-Bench-201 README](https://github.com/D-X-Y/NAS-Bench-201).
## Reproducing our results
To reproduce our results:
```
conda activate nas-wot
./reproduce.sh 3   # average accuracy over 3 runs
./reproduce.sh 500 # average accuracy over 500 runs (this will take longer)
```
Each command will finish by calling `process_results.py`, which will print a table. `./reproduce.sh 3` should print the following table:
Activate the environment and follow the instructions below for the remaining installation steps.
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10) | 1.75 | 89.50 +- 0.51 | 92.98 +- 0.82 | 69.80 +- 2.46 | 69.86 +- 2.21 | 42.35 +- 1.19 | 42.38 +- 1.37 |
| Ours (N=100) | 17.76 | 87.44 +- 1.45 | 92.27 +- 1.53 | 70.26 +- 1.09 | 69.86 +- 0.60 | 43.30 +- 1.62 | 43.51 +- 1.40 |
Install nasbench (see https://github.com/google-research/nasbench)
`./reproduce.sh 500` will produce the following table:
Download the NDS data from https://github.com/facebookresearch/nds and place the json files in naswot-codebase/nds_data/
Download the NASbench101 data (see https://github.com/google-research/nasbench)
Download the NASbench201 data (see https://github.com/D-X-Y/NAS-Bench-201)
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10) | 1.67 | 88.61 +- 1.58 | 91.58 +- 1.70 | 67.03 +- 3.01 | 67.15 +- 3.08 | 39.74 +- 4.17 | 39.76 +- 4.39 |
| Ours (N=100) | 17.12 | 88.43 +- 1.67 | 91.24 +- 1.70 | 67.04 +- 2.91 | 67.12 +- 2.98 | 40.68 +- 3.41 | 40.67 +- 3.55 |
To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes on [this line](https://github.com/BayesWatch/nas-without-training/blob/master/process_results.py#L51) of `process_results.py`.
Note that search times may differ from the reported results depending on your hardware setup.
Reproduce all of the results by running
```bash
./scorehook.sh
```
## Plotting histograms
In order to plot the histograms in Figure 1 of the paper, run:
```
python plot_histograms.py
```
to produce:
![alt text](results/histograms_cifar10val_batch256.png)
The code is licensed under the MIT licence.
## Acknowledgements
This repository makes liberal use of code from the [AutoDL](https://github.com/D-X-Y/AutoDL-Projects) library. We also rely on [NAS-Bench-201](https://github.com/D-X-Y/NAS-Bench-201).
## Citing us
If you use or build on our work, please consider citing us:

1
autodl/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,11 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
from .api_utils import ArchResults, ResultsCount
from .api_201 import NASBench201API
from .api_301 import NASBench301API
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]
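This `__init__.py` simply re-exports the benchmark API classes. As a quick orientation, a minimal usage sketch might look like the following; the package name `nas_201_api` is an assumption (the file path is not shown in this view), and the benchmark `.pth` file must already be downloaded:
```python
# Minimal sketch (not part of this commit): build the API from a downloaded
# benchmark file and inspect the search space. Package name and file path are assumptions.
from nas_201_api import NASBench201API

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
print(len(api))     # number of architectures (15,625 for NAS-Bench-201)
print(api.arch(0))  # string encoding of the 0-th architecture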

View File

@@ -0,0 +1,274 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
#
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201.
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']
def print_information(information, extra_info=None, show=False):
dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
def metric2str(loss, acc):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
"""
This is the class for the API of NAS-Bench-201.
"""
class NASBench201API(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
verbose: bool=True):
self.filename = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set(['12', '200'])
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
# self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
# self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
"""
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
hp2archres = OrderedDict()
hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
The difference between these configurations is the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
# `iepoch` indicates the index of training epochs from 0 to 11/199.
# When iepoch=None, it will return the metric for the last training epoch
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
# `use_12epochs_result` indicates different hyper-parameters for training
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performance of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index)  # in case the input is a string or an instance of an arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if we randomly select one trial, pick the seed first
if isinstance(is_random, bool) and is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
'train-all-time': train_info['all_time']}
# collect the evaluation information
if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
valtest_info = None
else:
try: # collect results on the proposed test set
if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
except:
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
xinfo['valid-accuracy'] = valid_info['accuracy']
xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None
xinfo['valid-all-time'] = valid_info['all_time']
if test_info is not None:
xinfo['test-loss'] = test_info['loss']
xinfo['test-accuracy'] = test_info['accuracy']
xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None
xinfo['test-all-time'] = test_info['all_time']
if valtest_info is not None:
xinfo['valtest-loss'] = valtest_info['loss']
xinfo['valtest-accuracy'] = valtest_info['accuracy']
xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None
xinfo['valtest-all-time'] = valtest_info['all_time']
return xinfo
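As a rough illustration of the comment block above, a sketch of calling `get_more_info` might look like this; the architecture index, dataset, and file path are placeholder assumptions:
```python
# Minimal sketch, assuming the benchmark file has been downloaded; index 123
# and the CIFAR-100 dataset are arbitrary examples.
from nas_201_api import NASBench201API  # package name is an assumption

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
info = api.get_more_info(123, 'cifar100', iepoch=None, hp='200', is_random=False)
print(info['train-accuracy'], info['valid-accuracy'], info['test-accuracy'])
print(info['train-all-time'])  # total training time in seconds
```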
def show(self, index: int = -1) -> None:
"""This function will print the information of a specific (or all) architecture(s)."""
self._show(index, print_information)
@staticmethod
def str2lists(arch_str: Text) -> List[tuple]:
"""
This function shows how to read the string-based architecture encoding.
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
:param
arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
:return: a list of tuple, contains multiple (op, input_node_index) pairs.
:usage
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
"""
node_strs = arch_str.split('+')
genotypes = []
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return genotypes
@staticmethod
def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
"""
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
:param
arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
search_space: a list of operation string, the default list is the search space for NAS-Bench-201
the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
:return
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
:usage
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This is a 4-by-4 matrix representing a cell with 4 nodes (only the lower-left triangle is used).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
:(NOTE)
If a node has two input-edges from the same node, this function does not work; one edge will overwrite the other.
"""
node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs:
op, idx = xi.split('~')
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx
return matrix
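Because `str2lists` and `str2matrix` are static methods, they can be tried without any benchmark file. A small sketch (package name assumed as before):
```python
# Minimal sketch: decode an architecture string with the two @staticmethod
# helpers; no benchmark .pth file is needed. The package name is an assumption.
from nas_201_api import NASBench201API as API

arch = '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|'
print(API.str2lists(arch))   # [(('nor_conv_1x1', 0),), (('none', 0), ('none', 1)), ...]
print(API.str2matrix(arch))  # 4x4 lower-triangular matrix of operation indices
```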

View File

@@ -0,0 +1,222 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
############################################################################################
# NAS-Bench-301, coming soon.
############################################################################################
# The history of benchmark files:
# [2020.06.30] NAS-Bench-301-v1_0
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-301-v1_0-363be7.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-301-v1_0-archive']
def print_information(information, extra_info=None, show=False):
dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
def metric2str(loss, acc):
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(valid_info['loss'], valid_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
"""
This is the class for the API of NAS-Bench-301.
"""
class NASBench301API(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
self.filename = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NAS-Bench-301 path from {:}.'.format(file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NAS-Bench-301 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the available hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
if self.verbose:
print('Create NAS-Bench-301 done with {:}/{:} architectures available.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('Call reload with archive_root={:} and index={:}'.format(archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}-FULL.pth'.format(idx))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
hp2archres = OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string
When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config'
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
The difference between these three configurations is the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
"""This function will return the metric for the `index`-th architecture
`dataset` indicates the dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
'cifar100' : using the proposed train set of CIFAR-100 as the training set
'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
`iepoch` indicates the index of training epochs from 0 to 11/199.
When iepoch=None, it will return the metric for the last training epoch
When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
`hp` indicates different hyper-parameters for training
When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
When hp=12, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
When hp=90, it trains the network with 90 epochs and the LR decayed from 0.1 to 0 within 90 epochs
`is_random`
When is_random=True, the performance of a random architecture will be returned
When is_random=False, the performance of all trials will be averaged.
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index)  # in case the input is a string or an instance of an arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if we randomly select one trial, pick the seed first
if isinstance(is_random, bool) and is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': train_info['all_time'] / total,
'train-all-time': train_info['all_time']}
# collect the evaluation information
if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
valtest_info = None
else:
try: # collect results on the proposed test set
if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
except:
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
xinfo['valid-accuracy'] = valid_info['accuracy']
xinfo['valid-per-time'] = valid_info['all_time'] / total
xinfo['valid-all-time'] = valid_info['all_time']
if test_info is not None:
xinfo['test-loss'] = test_info['loss']
xinfo['test-accuracy'] = test_info['accuracy']
xinfo['test-per-time'] = test_info['all_time'] / total
xinfo['test-all-time'] = test_info['all_time']
if valtest_info is not None:
xinfo['valtest-loss'] = valtest_info['loss']
xinfo['valtest-accuracy'] = valtest_info['accuracy']
xinfo['valtest-per-time'] = valtest_info['all_time'] / total
xinfo['valtest-all-time'] = valtest_info['all_time']
return xinfo
def show(self, index: int = -1) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
self._show(index, print_information)

View File

@@ -0,0 +1,750 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
############################################################################################
# History:
# [2020.06.30] The first version.
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
"""re-map the metric_on_set to internal keys"""
if verbose:
print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
if dataset == 'cifar10' and metric_on_set == 'valid':
dataset, metric_on_set = 'cifar10-valid', 'x-valid'
elif dataset == 'cifar10' and metric_on_set == 'test':
dataset, metric_on_set = 'cifar10', 'ori-test'
elif dataset == 'cifar10' and metric_on_set == 'train':
dataset, metric_on_set = 'cifar10', 'train'
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
metric_on_set = 'x-valid'
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
metric_on_set = 'x-test'
if verbose:
print(' return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
return dataset, metric_on_set
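A short sketch of the remapping behaviour; the module path `nas_201_api.api_utils` is an assumption:
```python
# Minimal sketch: how user-facing (dataset, set) names map to the internal keys.
# The module path is an assumption.
from nas_201_api.api_utils import remap_dataset_set_names

print(remap_dataset_set_names('cifar10', 'valid'))         # ('cifar10-valid', 'x-valid')
print(remap_dataset_set_names('cifar100', 'test'))         # ('cifar100', 'x-test')
print(remap_dataset_set_names('ImageNet16-120', 'valid'))  # ('ImageNet16-120', 'x-valid')
```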
class NASBenchMetaAPI(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
def __getitem__(self, index: int):
return copy.deepcopy(self.meta_archs[index])
def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
if self.verbose:
print('Call the arch function with index={:}'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
def __len__(self):
return len(self.meta_archs)
def __repr__(self):
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
@property
def avaliable_hps(self):
return list(copy.deepcopy(self._avaliable_hps))
@property
def used_time(self):
return self._used_time
def reset_time(self):
self._used_time = 0
def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
if dataset == 'cifar10':
info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
else:
info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
latency = self.get_latency(index, dataset)
if account_time:
self._used_time += time_cost
return valid_acc, latency, time_cost, self._used_time
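A sketch of how `simulate_train_eval` could drive a simulated random search under a wall-clock budget; the 12,000-second budget, the `hp='12'` setting, and the file path are illustrative assumptions, not part of this commit:
```python
# Minimal sketch: simulated random search with a 12,000-second budget.
from nas_201_api import NASBench201API  # package name is an assumption

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
api.reset_time()
best_acc, best_index = -1.0, -1
while api.used_time < 12000:
    index = api.random()                     # sample a random architecture index
    valid_acc, latency, time_cost, total_time = api.simulate_train_eval(
        index, 'cifar10', hp='12')           # simulated train/eval with time accounting
    if valid_acc > best_acc:
        best_acc, best_index = valid_acc, index
print('best arch {:} with simulated validation accuracy {:.2f}%'.format(best_index, best_acc))
```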
def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)
def query_index_by_arch(self, arch):
""" This function is used to query the index of an architecture in the search space.
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
or an instance that has the 'tostr' function that can generate the architecture string;
or it can directly be an architecture index, in which case we check whether it is valid.
This function will return the index.
If -1 is returned, this architecture is not in the search space.
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
"""
if self.verbose:
print('Call query_index_by_arch with arch={:}'.format(arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
else:
raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
elif isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
else : arch_index = -1
elif hasattr(arch, 'tostr'):
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
else : arch_index = -1
else: arch_index = -1
return arch_index
def query_by_arch(self, arch, hp):
# This is to make the current version compatible with the old version.
return self.query_info_str_by_arch(arch, hp)
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
def clear_params(self, index: int, hp: Optional[Text]=None):
"""Remove the architecture's weights to save memory.
:arg
index: the index of the target architecture
hp: a flag to control how to clear the parameters.
-- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
if hp is None:
for key, result in self.arch2infos_dict[index].items():
result.clear_params()
else:
if str(hp) not in self.arch2infos_dict[index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
self.arch2infos_dict[index][str(hp)].clear_params()
@abc.abstractmethod
def query_info_str_by_arch(self, arch, hp: Text='12'):
"""This function is used to query the information of a specific architecture."""
def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
arch_index = self.query_index_by_arch(arch)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
info = self.arch2infos_dict[arch_index][hp]
strings = print_information(info, 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
print ('Found this arch-index : {:}, but this arch has not been evaluated.'.format(arch_index))
return None
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
"""Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
if self.verbose:
print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
info = self.arch2infos_dict[arch_index][hp]
else:
raise ValueError('arch_index [{:}] is not in arch2infos'.format(arch_index))
return copy.deepcopy(info)
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
""" This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
------
If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
------
If dataname is None, return the ArchResults
else, return a dict with all trials on that dataset (the key is the seed)
Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
"""
if self.verbose:
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
if dataname is None: return info
else:
if dataname not in info.get_dataset_names():
raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
return info.query(dataname)
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
for i, arch_index in enumerate(self.evaluated_indexes):
arch_info = self.arch2infos_dict[arch_index][hp]
info = arch_info.get_compute_costs(dataset) # the information of costs
flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue
xinfo = arch_info.get_metrics(dataset, metric_on_set) # the information of loss and accuracy
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1:
best_index, highest_accuracy = arch_index, accuracy
elif highest_accuracy < accuracy:
best_index, highest_accuracy = arch_index, accuracy
if self.verbose:
print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
return best_index, highest_accuracy
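A sketch of `find_best` under a parameter constraint; the 1.0 MB cap, dataset, and file path are arbitrary illustrations:
```python
# Minimal sketch: best architecture on the CIFAR-100 validation set among
# models with fewer than 1.0 MB of parameters (the cap is an arbitrary example).
from nas_201_api import NASBench201API  # package name is an assumption

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
best_index, best_acc = api.find_best('cifar100', 'valid', Param_max=1.0, hp='12')
print('index {:} reaches {:.2f}% validation accuracy'.format(best_index, best_acc))
```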
def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
"""
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
Args [seed]:
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
-- an integer : return the weights of a specific trial, whose seed is this integer.
Args [hp]:
-- 01 : train the model by 01 epochs
-- 12 : train the model by 12 epochs
-- 90 : train the model by 90 epochs
-- 200 : train the model by 200 epochs
"""
if self.verbose:
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
def get_net_config(self, index: int, dataset: Text):
"""
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
This function will return a dict.
========= Some examples for using this function:
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
else:
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
info = next(iter(info.values()))
results = info.query(dataset, None)
results = next(iter(results.values()))
return results.get_config(None)
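Following the docstring example, a sketch of fetching a configuration together with the trained weights; the index, dataset, and file path are placeholders, and the network constructor itself lives in the AutoDL library, outside this file:
```python
# Minimal sketch: fetch the training config and the trained weights of one
# architecture (index 128 as in the docstring example above).
from nas_201_api import NASBench201API  # package name is an assumption

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
config = api.get_net_config(128, 'cifar10')             # dict describing the network
params = api.get_net_param(128, 'cifar10', seed=None)   # {seed: state_dict} for every trial;
                                                        # weights may require api.reload() with the full archive
print(config)
print(list(params.keys()))
# Building the network itself requires the model constructor from the AutoDL
# library (outside this file), to which `config` would be passed.
```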
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
"""
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
:param index: the index of the target architecture
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
:return: return a float value in seconds
"""
if self.verbose:
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']
@abc.abstractmethod
def show(self, index=-1):
"""This function will print the information of a specific (or all) architecture(s)."""
def _show(self, index=-1, print_information=None) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
for key, result in self.arch2infos_dict[idx].items():
strings = print_information(result)
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
if 0 <= index < len(self.meta_archs):
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or was not saved.'.format(index))
else:
arch_info = self.arch2infos_dict[index]
for key, result in self.arch2infos_dict[index].items():
strings = print_information(result)
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
"""This function will count the number of total trials."""
if self.verbose:
print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
if dataset not in valid_datasets:
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
nums, hp = defaultdict(lambda: 0), str(hp)
for index in range(len(self)):
archInfo = self.arch2infos_dict[index][hp]
dataset_seed = archInfo.dataset_seed
if dataset not in dataset_seed:
nums[0] += 1
else:
nums[len(dataset_seed[dataset])] += 1
return dict(nums)
class ArchResults(object):
def __init__(self, arch_index, arch_str):
self.arch_index = int(arch_index)
self.arch_str = copy.deepcopy(arch_str)
self.all_results = dict()
self.dataset_seed = dict()
self.clear_net_done = False
def get_compute_costs(self, dataset):
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
flops = [result.flop for result in results]
params = [result.params for result in results]
latencies = [result.get_latency() for result in results]
latencies = [x for x in latencies if x > 0]
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
time_infos = defaultdict(list)
for result in results:
time_info = result.get_times()
for key, value in time_info.items(): time_infos[key].append( value )
info = {'flops' : np.mean(flops),
'params' : np.mean(params),
'latency': mean_latency}
for key, value in time_infos.items():
if len(value) > 0 and value[0] is not None:
info[key] = np.mean(value)
else: info[key] = None
return info
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
"""
This `get_metrics` function is used to obtain the loss, accuracy, etc. on a specific dataset.
If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
If some args return None or raise an error, then it is not available.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all available trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
infos = defaultdict(list)
for result in results:
if setname == 'train':
info = result.get_train(iepoch)
else:
info = result.get_eval(setname, iepoch)
for key, value in info.items(): infos[key].append( value )
return_info = dict()
if isinstance(is_random, bool) and is_random: # randomly select one
index = random.randint(0, len(results)-1)
for key, value in infos.items(): return_info[key] = value[index]
elif isinstance(is_random, bool) and not is_random: # average
for key, value in infos.items():
if len(value) > 0 and value[0] is not None:
return_info[key] = np.mean(value)
else: return_info[key] = None
elif isinstance(is_random, int): # specify the seed
if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
index = x_seeds.index(is_random)
for key, value in infos.items(): return_info[key] = value[index]
else:
raise ValueError('invalid value for is_random: {:}'.format(is_random))
return return_info
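Tying `get_metrics` back to the API, a sketch of averaging one architecture's test metrics over all trials; the index, dataset, and file path are placeholders:
```python
# Minimal sketch: average test accuracy over all trials of the 7-th architecture
# on CIFAR-100 (index, dataset, and hp are arbitrary examples).
from nas_201_api import NASBench201API  # package name is an assumption

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
arch_results = api.query_meta_info_by_index(7, hp='200')
metrics = arch_results.get_metrics('cifar100', 'x-test', is_random=False)  # averaged over trials
print(metrics['accuracy'], metrics['loss'])
```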
def show(self, is_print=False):
return print_information(self, None, is_print)
def get_dataset_names(self):
return list(self.dataset_seed.keys())
def get_dataset_seeds(self, dataset):
return copy.deepcopy( self.dataset_seed[dataset] )
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
"""
This function will return the trained network's weights on the 'dataset'.
:arg
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
seed: an integer indicating the seed value, or None to return all trials.
"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
else:
xkey = (dataset, seed)
if xkey in self.all_results:
return self.all_results[xkey].get_net_param()
else:
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].update_latency([latency])
else:
self.all_results[(dataset, seed)].update_latency([latency])
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
def get_latency(self, dataset: Text) -> float:
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
latencies = []
for seed in self.dataset_seed[dataset]:
latency = self.all_results[(dataset, seed)].get_latency()
if not isinstance(latency, float) or latency <= 0:
raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
latencies.append(latency)
return sum(latencies) / len(latencies)
def get_total_epoch(self, dataset=None):
"""Return the total number of training epochs."""
if dataset is None:
epochss = []
for xdata, x_seeds in self.dataset_seed.items():
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
elif isinstance(dataset, str):
x_seeds = self.dataset_seed[dataset]
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
else:
raise ValueError('invalid dataset={:}'.format(dataset))
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
return epochss[-1]
def query(self, dataset, seed=None):
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
else:
return self.all_results[(dataset, seed)]
def arch_idx_str(self):
return '{:06d}'.format(self.arch_index)
def update(self, dataset_name, seed, result):
if dataset_name not in self.dataset_seed:
self.dataset_seed[dataset_name] = []
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
self.dataset_seed[ dataset_name ].append( seed )
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
assert (dataset_name, seed) not in self.all_results
self.all_results[ (dataset_name, seed) ] = result
self.clear_net_done = False
def state_dict(self):
state_dict = dict()
for key, value in self.__dict__.items():
if key == 'all_results': # contain the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
xvalue[_k] = _v.state_dict()
else:
xvalue = value
state_dict[key] = xvalue
return state_dict
def load_state_dict(self, state_dict):
new_state_dict = dict()
for key, value in state_dict.items():
if key == 'all_results': # to convert to the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
else: xvalue = value
new_state_dict[key] = xvalue
self.__dict__.update(new_state_dict)
@staticmethod
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
state_dict = torch.load(state_dict_or_file, map_location='cpu')
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
x.load_state_dict(state_dict)
return x
# This function is used to clear the weights saved in each 'result'
# This can help reduce the memory footprint.
def clear_params(self):
for key, result in self.all_results.items():
del result.net_state_dict
result.net_state_dict = None
self.clear_net_done = True
def debug_test(self):
"""This function is used for me to debug and test, which will call most methods."""
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
for dataset in all_dataset:
print('---->>>> {:}'.format(dataset))
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
for seed in self.dataset_seed[dataset]:
result = self.all_results[(dataset, seed)]
print(' ==>> result = {:}'.format(result))
print(' ==>> cost = {:}'.format(result.get_times()))
def __repr__(self):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
"""
This class (ResultsCount) is used to save the information of one trial for a single architecture.
I did not write many comments for this class, because it is the lowest-level class in the NAS-Bench-201 API and is rarely called directly.
If you have any questions regarding this class, please open an issue or email me.
"""
class ResultsCount(object):
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
self.name = name
self.net_state_dict = state_dict
self.train_acc1es = copy.deepcopy(train_accs)
self.train_acc5es = None
self.train_losses = copy.deepcopy(train_losses)
self.train_times = None
self.arch_config = copy.deepcopy(arch_config)
self.params = params
self.flop = flop
self.seed = seed
self.epochs = epochs
self.latency = latency
# evaluation results
self.reset_eval()
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
self.train_acc1es = train_acc1es
self.train_acc5es = train_acc5es
self.train_losses = train_losses
self.train_times = train_times
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
"""Assign the training times."""
train_times = OrderedDict()
for i in range(self.epochs):
train_times[i] = estimated_per_epoch_time
self.train_times = train_times
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
"""Assign the evaluation times."""
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
for i in range(self.epochs):
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
def reset_eval(self):
self.eval_names = []
self.eval_acc1es = {}
self.eval_times = {}
self.eval_losses = {}
def update_latency(self, latency):
self.latency = copy.deepcopy( latency )
def get_latency(self) -> float:
"""Return the latency value in seconds. -1 means not available; otherwise it is a float value."""
if self.latency is None: return -1.0
else: return sum(self.latency) / len(self.latency)
def update_eval(self, accs, losses, times): # new version
data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names:
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
self.eval_names.append( data_name )
for iepoch in range(self.epochs):
xkey = '{:}@{:}'.format(data_name, iepoch)
self.eval_acc1es[ xkey ] = accs[ xkey ]
self.eval_losses[ xkey ] = losses[ xkey ]
self.eval_times [ xkey ] = times[ xkey ]
def update_OLD_eval(self, name, accs, losses): # old version
assert name not in self.eval_names, '{:} has already been added'.format(name)
self.eval_names.append( name )
for iepoch in range(self.epochs):
if iepoch in accs:
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
def __repr__(self):
num_eval = len(self.eval_names)
set_name = '[' + ', '.join(self.eval_names) + ']'
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
def get_total_epoch(self):
return copy.deepcopy(self.epochs)
def get_times(self):
"""Obtain the information regarding both training and evaluation time."""
if self.train_times is not None and isinstance(self.train_times, dict):
train_times = list( self.train_times.values() )
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
else:
time_info = {'T-train@epoch': None, 'T-train@total': None }
for name in self.eval_names:
try:
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
except:
time_info['T-{:}@epoch'.format(name)] = None
time_info['T-{:}@total'.format(name)] = None
return time_info
def get_eval_set(self):
return self.eval_names
# get the training information
def get_train(self, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if self.train_times is not None:
xtime = self.train_times[iepoch]
atime = sum([self.train_times[i] for i in range(iepoch+1)])
else: xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.train_losses[iepoch],
'accuracy': self.train_acc1es[iepoch],
'cur_time': xtime,
'all_time': atime}
def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
def _internal_query(xname):
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
else:
xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
'cur_time': xtime,
'all_time': atime}
if name == 'valid':
return _internal_query('x-valid')
else:
return _internal_query(name)
def get_net_param(self, clone=False):
if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
# In this case, this is NAS-Bench-301
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
# In this case, this is NAS-Bench-201
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else:
# In this case, this is NAS-Bench-301
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
# In this case, this is NAS-Bench-201
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
def state_dict(self):
_state_dict = {key: value for key, value in self.__dict__.items()}
return _state_dict
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
@staticmethod
def create_from_state_dict(state_dict):
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
x.load_state_dict(state_dict)
return x


@@ -0,0 +1,25 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {'basic' : basic_train, \
'search': search_train,'Simple-KD': simple_KD_train, \
'search-v2': search_train_v2}
valid_funcs = {'basic' : basic_valid, \
'search': search_valid,'Simple-KD': simple_KD_valid, \
'search-v2': search_valid}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func
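
A quick sketch of how this selector is meant to be called (the 'basic' key is one of the entries in the dictionaries above):

```python
# Hypothetical: pick the training/validation functions for the plain procedure
train_func, valid_func = get_procedures('basic')
# train_func is basic_train and valid_func is basic_valid, per the dicts above
```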


@@ -0,0 +1,75 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
with torch.no_grad():
loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger)
return loss, acc1, acc5
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
#logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1))
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
features, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, 'logits must have {:} items instead of {:}'.format(2, len(logits))
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
loss = criterion(logits, targets)
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == 'train':
loss.backward()
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
if scheduler is not None:
Sstr += ' {:}'.format(scheduler.get_min_info())
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg


@@ -0,0 +1,203 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, time, copy, torch, pathlib
import datasets
from config_utils import load_config
from autodl.procedures import prepare_seed, get_optim_scheduler
from autodl.utils import get_model_infos, obtain_accuracy
from autodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies, device = [], torch.cuda.current_device()
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(device=device, non_blocking=True)
inputs = inputs.cuda(device=device, non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append( batch_time.val - data_time.val )
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2: latencies = latencies[1:]
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train' : network.train()
elif mode == 'valid': network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
device = torch.cuda.current_device()
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(device=device, non_blocking=True)
if mode == 'train': optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(arch_config)
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, opt_config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
default_device = torch.cuda.current_device()
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
criterion = criterion.cuda(device=default_device)
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
train_times , valid_times, lrs = {}, {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
lr = min(scheduler.get_lr())
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times [epoch] = train_tm
lrs[epoch] = lr
with torch.no_grad():
for key, xloader in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloader, network, criterion, None, None, 'valid')
valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1
valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr))
info_seed = {'flop' : flop,
'param': param,
'arch_config' : arch_config._asdict(),
'opt_config' : opt_config._asdict(),
'total_epoch' : total_epoch ,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times' : train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times' : valid_times,
'learning_rates': lrs,
'net_state_dict': net.state_dict(),
'net_string' : '{:}'.format(net),
'finish-train': True
}
return info_seed
def get_nas_bench_loaders(workers):
torch.set_num_threads(workers)
root_dir = (pathlib.Path(__file__).parent / '..' / '..').resolve()
torch_dir = pathlib.Path(os.environ['TORCH_HOME'])
# cifar
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
cifar_config = load_config(cifar_config_path, None, None)
get_datasets = datasets.get_datasets # a function to return the dataset
break_line = '-' * 150
print ('{:} Create data-loader for all datasets'.format(time_string()))
print (break_line)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
temp_dataset.transform = VALID_CIFAR10.transform
# data loader
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
print (break_line)
# CIFAR-100
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
print (break_line)
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
imagenet16_config = load_config(imagenet16_config_path, None, None)
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
# 'cifar10', 'cifar100', 'ImageNet16-120'
loaders = {'cifar10@trainval': trainval_cifar10_loader,
'cifar10@train' : train_cifar10_loader,
'cifar10@valid' : valid_cifar10_loader,
'cifar10@test' : test__cifar10_loader,
'cifar100@train' : train_cifar100_loader,
'cifar100@valid' : valid_cifar100_loader,
'cifar100@test' : test__cifar100_loader,
'ImageNet16-120@train': train_imagenet_loader,
'ImageNet16-120@valid': valid_imagenet_loader,
'ImageNet16-120@test' : test__imagenet_loader}
return loaders
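
A minimal sketch of calling this loader factory. It assumes TORCH_HOME points at the downloaded datasets and that the configs/nas-benchmark/*.config files referenced above are in place:

```python
# Hypothetical usage; requires the datasets and config files described above
loaders = get_nas_bench_loaders(workers=4)
train_loader = loaders['cifar10@train']
for inputs, targets in train_loader:
    print(inputs.shape, targets.shape)   # one mini-batch of CIFAR-10 training data
    break
```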


@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ''
def __repr__(self):
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
+ ', {:})'.format(self.extra_repr()))
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
def get_min_lr(self):
return min( self.get_lr() )
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
#if last_epoch < self.T_max:
#if last_epoch < self.max_epochs:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
#else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]: lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
lr = base_lr * (1-ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def get_optim_scheduler(parameters, config):
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
if config.optim == 'SGD':
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
elif config.optim == 'RMSprop':
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
else:
raise ValueError('invalid optim : {:}'.format(config.optim))
if config.scheduler == 'cos':
T_max = getattr(config, 'T_max', config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
elif config.scheduler == 'multistep':
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
elif config.scheduler == 'exponential':
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == 'linear':
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
else:
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
if config.criterion == 'Softmax':
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == 'SmoothSoftmax':
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError('invalid criterion : {:}'.format(config.criterion))
return optim, scheduler, criterion
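
A self-contained sketch of driving get_optim_scheduler with a hand-built config. The SimpleNamespace fields below mirror the attributes the function reads; they are illustrative values, not one of the shipped .config files:

```python
from types import SimpleNamespace
import torch.nn as nn

# Illustrative config; field names follow the attribute accesses in get_optim_scheduler
config = SimpleNamespace(optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
                         scheduler='cos', warmup=0, epochs=200, eta_min=0.0,
                         criterion='Softmax')
model = nn.Linear(8, 2)
optim, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
scheduler.update(0, 0.0)      # epoch 0, start of the epoch
print(scheduler.get_lr())     # [0.1] at the first cosine step (no warmup)
```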


@@ -0,0 +1,126 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean( expected_flop )
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = - torch.log( expected_flop )
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log( expected_flop )
else: # Required FLOP
loss = None
if loss is None: return 0, 0
else : return loss, loss.item()
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
network.train()
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
end = time.time()
network.apply( change_key('search_mode', 'search') )
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
#network.apply( change_key('search_mode', 'basic') )
#features, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update (prec1.item(), base_inputs.size(0))
top5.update (prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop('genotype', None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step+1) == len(search_loader):
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
#print(network.module.get_arch_info())
#print(network.module.width_attentions[0])
#print(network.module.width_attentions[1])
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
network.apply( change_key('search_mode', 'search') )
end = time.time()
#logger.log('Starting evaluating {:}'.format(epoch_info))
with torch.no_grad():
for i, (inputs, targets) in enumerate(xloader):
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
logits, expected_flop = network(inputs)
loss = criterion(logits, targets)
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg
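
To make the FLOP regulariser concrete, a small toy example for get_flop_loss defined above; all numbers are made up:

```python
import torch

# Made-up numbers: the current architecture is well below the FLOP budget,
# so the penalty is -log(expected_flop), pushing the expected FLOPs up.
expected_flop = torch.tensor([150.0])
loss, loss_value = get_flop_loss(expected_flop, flop_cur=120.0, flop_need=200.0, flop_tolerant=20.0)
print(loss_value)   # -log(150) is roughly -5.01

# Within the tolerated band no penalty is applied and (0, 0) is returned.
zero_loss, zero_value = get_flop_loss(expected_flop, flop_cur=190.0, flop_need=200.0, flop_tolerant=20.0)
```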


@@ -0,0 +1,87 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean( expected_flop )
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = - torch.log( expected_flop )
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log( expected_flop )
else: # Required FLOP
loss = None
if loss is None: return 0, 0
else : return loss, loss.item()
def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
network.train()
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
end = time.time()
network.apply( change_key('search_mode', 'search') )
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update (prec1.item(), base_inputs.size(0))
top5.update (prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop('genotype', None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step+1) == len(search_loader):
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
#num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
#print(network.module.get_arch_info())
#print(network.module.width_attentions[0])
#print(network.module.width_attentions[1])
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg


@@ -0,0 +1,94 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import os, sys, time, torch
import torch.nn.functional as F
# our modules
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
with torch.no_grad():
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
basic_loss = criterion(student_logits, targets) * (1. - alpha)
log_student= F.log_softmax(student_logits / temperature, dim=1)
sof_teacher= F.softmax (teacher_logits / temperature, dim=1)
KD_loss = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature)
return basic_loss + KD_loss
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
teacher.eval()
logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature))
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
student_f, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, 'logits must have {:} items instead of {:}'.format(2, len(logits))
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
with torch.no_grad():
teacher_f, teacher_logits = teacher(inputs)
loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == 'train':
loss.backward()
optimizer.step()
# record
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (sprec1.item(), inputs.size(0))
top5.update (sprec5.item(), inputs.size(0))
# teacher
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
Ttop1.update (tprec1.item(), inputs.size(0))
Ttop5.update (tprec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
if scheduler is not None:
Sstr += ' {:}'.format(scheduler.get_min_info())
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg
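
A small sketch exercising loss_KD_fn above with random logits; shapes and hyper-parameters are illustrative:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
# The feature arguments are accepted but unused by loss_KD_fn, so None is fine here.
loss = loss_KD_fn(criterion, student_logits, teacher_logits, None, None,
                  targets, alpha=0.9, temperature=4.0)
print(loss.item())   # weighted sum of the hard-label CE and the soft-target KL term
```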


@@ -0,0 +1,64 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy( xargs )
from autodl.log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log('Main Function with logger : {:}'.format(logger))
logger.log('Arguments : -------------------------------')
for name, value in args._get_kwargs():
logger.log('{:16} : {:}'.format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
info+= "\nPillow Version : {:}".format(PIL.__version__)
info+= "\nPyTorch Version : {:}".format(torch.__version__)
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if 'CUDA_VISIBLE_DEVICES' in os.environ:
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
else:
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, 'log'): logger.log('Found existing {:}, deleting it before saving'.format(filename))
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), 'saving to {:} failed : the file is not found afterwards.'.format(filename)
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, 'log'): logger.log('Found existing {:}, deleting it before saving'.format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))

autodl/utils/__init__.py Normal file

@@ -0,0 +1,5 @@
from .evaluation_utils import obtain_accuracy
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos, count_parameters_in_MB
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image


@@ -0,0 +1,125 @@
# functions for affine transformation
import math, torch
import numpy as np
import torch.nn.functional as F
def identity2affine(full=False):
if not full:
parameters = torch.zeros((2,3))
parameters[0, 0] = parameters[1, 1] = 1
else:
parameters = torch.zeros((3,3))
parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
return parameters
def normalize_L(x, L):
return -1. + 2. * x / (L-1)
def denormalize_L(x, L):
return (x + 1.0) / 2.0 * (L-1)
def crop2affine(crop_box, W, H):
assert len(crop_box) == 4, 'Invalid crop-box : {:}'.format(crop_box)
parameters = torch.zeros(3,3)
x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H)
x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H)
parameters[0,0] = (x2-x1)/2
parameters[0,2] = (x2+x1)/2
parameters[1,1] = (y2-y1)/2
parameters[1,2] = (y2+y1)/2
parameters[2,2] = 1
return parameters
def scale2affine(scalex, scaley):
parameters = torch.zeros(3,3)
parameters[0,0] = scalex
parameters[1,1] = scaley
parameters[2,2] = 1
return parameters
def offset2affine(offx, offy):
parameters = torch.zeros(3,3)
parameters[0,0] = parameters[1,1] = parameters[2,2] = 1
parameters[0,2] = offx
parameters[1,2] = offy
return parameters
def horizontalmirror2affine():
parameters = torch.zeros(3,3)
parameters[0,0] = -1
parameters[1,1] = parameters[2,2] = 1
return parameters
# clockwise rotate image = counterclockwise rotate the rectangle
# degree is between [0, 360]
def rotate2affine(degree):
assert degree >= 0 and degree <= 360, 'Invalid degree : {:}'.format(degree)
degree = degree / 180 * math.pi
parameters = torch.zeros(3,3)
parameters[0,0] = math.cos(-degree)
parameters[0,1] = -math.sin(-degree)
parameters[1,0] = math.sin(-degree)
parameters[1,1] = math.cos(-degree)
parameters[2,2] = 1
return parameters
# shape is a tuple [H, W]
def normalize_points(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
points[0, :] = normalize_L(points[0,:], W)
points[1, :] = normalize_L(points[1,:], H)
return points
# shape is a tuple [H, W]
def normalize_points_batch(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.size(-1) == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
x = normalize_L(points[...,0], W)
y = normalize_L(points[...,1], H)
return torch.stack((x,y), dim=-1)
# shape is a tuple [H, W]
def denormalize_points(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
points[0, :] = denormalize_L(points[0,:], W)
points[1, :] = denormalize_L(points[1,:], H)
return points
# shape is a tuple [H, W]
def denormalize_points_batch(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[-1] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
x = denormalize_L(points[...,0], W)
y = denormalize_L(points[...,1], H)
return torch.stack((x,y), dim=-1)
# make target * theta = source
def solve2theta(source, target):
source, target = source.clone(), target.clone()
oks = source[2, :] == 1
assert torch.sum(oks).item() >= 3, 'valid points : {:} is short'.format(oks)
if target.size(0) == 2: target = torch.cat((target, oks.unsqueeze(0).float()), dim=0)
source, target = source[:, oks], target[:, oks]
source, target = source.transpose(1,0), target.transpose(1,0)
assert source.size(1) == target.size(1) == 3
#X, residual, rank, s = np.linalg.lstsq(target.numpy(), source.numpy())
#theta = torch.Tensor(X.T[:2, :])
X_, qr = torch.gels(source, target)
theta = X_[:3, :2].transpose(1, 0)
return theta
# shape = [H,W]
def affine2image(image, theta, shape):
C, H, W = image.size()
theta = theta[:2, :].unsqueeze(0)
grid_size = torch.Size([1, C, shape[0], shape[1]])
grid = F.affine_grid(theta, grid_size)
affI = F.grid_sample(image.unsqueeze(0), grid, mode='bilinear', padding_mode='border')
return affI.squeeze(0)
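
A short sketch composing the helpers above to rotate an image tensor; sizes and the angle are illustrative:

```python
import torch

image = torch.rand(3, 32, 32)          # C x H x W
theta = rotate2affine(90)              # 3x3 affine parameters for a 90-degree rotation
rotated = affine2image(image, theta, (32, 32))
print(rotated.shape)                   # torch.Size([3, 32, 32])
```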


@@ -0,0 +1,16 @@
import torch
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
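
For illustration, a toy call to obtain_accuracy with random logits (batch of 8, 10 classes); the numbers are meaningless but show the call pattern:

```python
import torch

logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
top1, top5 = obtain_accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())   # top-1 / top-5 accuracy in percent over the batch
```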


@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
#model = model.cuda()
model.eval()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
def add_flops_counting_methods( model ):
model.__batch_counter__ = 0
add_batch_counter_hook_function( model )
model.apply( add_flops_counter_variable_or_reset )
model.apply( add_flops_counter_hook_function )
return model
def compute_average_flops_cost(model):
"""
A method that will be available after add_flops_counting_methods() is called on a desired net object.
Returns current mean flops consumption per image.
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
batch_size = inputs[0].size(0)
kernel_size = pool_module.kernel_size
out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
overall_flops = batch_size * xin * xout
if fc_module.bias is not None:
overall_flops += batch_size * xout
fc_module.__flops__ += overall_flops
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
active_elements_count = batch_size * output_height * output_width
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def batch_counter_hook(module, inputs, output):
# Can have multiple inputs, getting the first one
inputs = inputs[0]
batch_size = inputs.shape[0]
module.__batch_counter__ += batch_size
def add_batch_counter_hook_function(module):
if not hasattr(module, '__batch_counter_handle__'):
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(fc_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
keys = ['__flops__', '__batch_counter__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)
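As a quick check of the counters above, here is a minimal usage sketch (the torchvision model is only illustrative; any nn.Module works):

```python
import torchvision.models as tvm

# get_model_infos registers the hooks, runs one forward pass on a random input
# of the given shape, and returns (FLOPs in millions, parameter count in millions).
net = tvm.resnet18()
flops, params = get_model_infos(net, shape=(1, 3, 224, 224))
print('FLOPs: {:.1f}M  Params: {:.2f}M'.format(flops, params))
```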


@@ -0,0 +1,70 @@
import os
class GPUManager():
queries = ('index', 'gpu_name', 'memory.free', 'memory.used', 'memory.total', 'power.draw', 'power.limit')
def __init__(self):
all_gpus = self.query_gpu(False)
def get_info(self, ctype):
cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(ctype)
lines = os.popen(cmd).readlines()
lines = [line.strip('\n') for line in lines]
return lines
def query_gpu(self, show=True):
num_gpus = len( self.get_info('index') )
all_gpus = [ {} for i in range(num_gpus) ]
for query in self.queries:
infos = self.get_info(query)
for idx, info in enumerate(infos):
all_gpus[idx][query] = info
if 'CUDA_VISIBLE_DEVICES' in os.environ:
CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
selected_gpus = []
for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES):
find = False
for gpu in all_gpus:
if gpu['index'] == CUDA_VISIBLE_DEVICE:
assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
find = True
selected_gpus.append( gpu.copy() )
selected_gpus[-1]['index'] = '{}'.format(idx)
assert find, 'Does not find the device : {}'.format(CUDA_VISIBLE_DEVICE)
all_gpus = selected_gpus
if show:
allstrings = ''
for gpu in all_gpus:
string = '| '
for query in self.queries:
if query.find('memory') == 0: xinfo = '{:>9}'.format(gpu[query])
else: xinfo = gpu[query]
string = string + query + ' : ' + xinfo + ' | '
allstrings = allstrings + string + '\n'
return allstrings
else:
return all_gpus
def select_by_memory(self, numbers=1):
all_gpus = self.query_gpu(False)
assert numbers <= len(all_gpus), 'Require {} gpus more than you have'.format(numbers)
alls = []
for idx, gpu in enumerate(all_gpus):
free_memory = gpu['memory.free']
free_memory = free_memory.split(' ')[0]
free_memory = int(free_memory)
index = gpu['index']
alls.append((free_memory, index))
alls.sort(reverse = True)
alls = [ int(alls[i][1]) for i in range(numbers) ]
return sorted(alls)
"""
if __name__ == '__main__':
manager = GPUManager()
manager.query_gpu(True)
indexes = manager.select_by_memory(3)
print (indexes)
"""

autodl/utils/nas_utils.py (new file, 57 lines)

@@ -0,0 +1,57 @@
# This file is for experimental usage
import torch, random
import numpy as np
from copy import deepcopy
import torch.nn as nn
# from utils import obtain_accuracy
from models import CellStructure
from log_utils import time_string
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
print ('This is an old version of the code for the NAS-Bench-201 API and should be modified to align with the new version. Please contact me for more details if you use this function.')
weights = deepcopy(model.state_dict())
model.train(cal_mode)
with torch.no_grad():
logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
loader_iter = iter(xloader)
random.seed(seed)
random.shuffle(archs)
for idx, arch in enumerate(archs):
arch_index = api.query_index_by_arch( arch )
metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
gt_accs_10_valid.append( metrics['valid-accuracy'] )
metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
gt_accs_10_test.append( metrics['test-accuracy'] )
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = '{:}<-{:}'.format(i+1, xin)
op_index = model.op_names.index(op)
select_logits.append( logits[model.edge2index[node_str], op_index] )
cur_prob = sum(select_logits).item()
probs.append( cur_prob )
cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1]
cor_prob_test = np.corrcoef(probs, gt_accs_10_test )[0,1]
print ('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))
for idx, arch in enumerate(archs):
model.set_cal_mode('dynamic', arch)
try:
inputs, targets = next(loader_iter)
except StopIteration:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, logits = model(inputs.cuda())
_, preds = torch.max(logits, dim=-1)
correct = (preds == targets.cuda() ).float()
accuracies.append( correct.mean().item() )
if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1]
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1]
print ('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
model.load_state_dict(weights)
return archs, probs, accuracies


@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.03 #
#####################################################
# Reformulate the codes in https://github.com/CalculatedContent/WeightWatcher
#####################################################
import numpy as np
from typing import List
import torch.nn as nn
from collections import OrderedDict
from sklearn.decomposition import TruncatedSVD
def available_module_types():
return (nn.Conv2d, nn.Linear)
def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
"""
Extract W slices from a 4 index conv2D tensor of shape: (N,M,i,j) or (M,N,i,j).
Return ij (N x M) matrices
"""
mats = []
N, M, imax, jmax = tensor.shape
assert N + M >= imax + jmax, 'invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)'.format(N, M, imax, jmax)
for i in range(imax):
for j in range(jmax):
w = tensor[:, :, i, j]
if N < M: w = w.T
mats.append(w)
return mats
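For instance, a hypothetical 64x32 convolution with a 3x3 kernel is sliced into nine 64x32 matrices:

```python
import numpy as np

# one matrix per spatial position (i, j) of the kernel
mats = get_conv2D_Wmats(np.random.randn(64, 32, 3, 3))
print(len(mats), mats[0].shape)  # 9 (64, 32)
```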
def glorot_norm_check(W, N, M, rf_size, lower=0.5, upper=1.5):
"""Check if this layer needs Glorot Normalization Fix"""
kappa = np.sqrt(2 / ((N + M) * rf_size))
norm = np.linalg.norm(W)
check1 = norm / np.sqrt(N * M)
check2 = norm / (kappa * np.sqrt(N * M))
if (rf_size > 1) and (check2 > lower) and (check2 < upper):
return check2, True
elif (check1 > lower) & (check1 < upper):
return check1, True
else:
if rf_size > 1: return check2, False
else: return check1, False
def glorot_norm_fix(w, n, m, rf_size):
"""Apply Glorot Normalization Fix."""
kappa = np.sqrt(2 / ((n + m) * rf_size))
w = w / kappa
return w
def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix):
results = OrderedDict()
count = len(weights)
if count == 0: return results
for i, weight in enumerate(weights):
M, N = np.min(weight.shape), np.max(weight.shape)
Q = N / M
results[i] = cur_res = OrderedDict(N=N, M=M, Q=Q)
check, checkTF = glorot_norm_check(weight, N, M, count)
cur_res['check'] = check
cur_res['checkTF'] = checkTF
# assume receptive field size is count
if glorot_fix:
weight = glorot_norm_fix(weight, N, M, count)
else:
# probably never needed since we always fix for glorot
weight = weight * np.sqrt(count / 2.0)
if spectralnorms: # spectralnorm is the max eigenvalues
svd = TruncatedSVD(n_components=1, n_iter=7, random_state=10)
svd.fit(weight)
sv = svd.singular_values_
sv_max = np.max(sv)
if normalize:
evals = sv * sv / N
else:
evals = sv * sv
lambda0 = evals[0]
cur_res["spectralnorm"] = lambda0
cur_res["logspectralnorm"] = np.log10(lambda0)
else:
lambda0 = None
if M < min_size:
summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(i + 1, count, M, N, min_size)
cur_res["summary"] = summary
continue
elif max_size > 0 and M > max_size:
summary = "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(i + 1, count, M, N, max_size)
cur_res["summary"] = summary
continue
else:
summary = []
if alphas:
import powerlaw
svd = TruncatedSVD(n_components=M - 1, n_iter=7, random_state=10)
svd.fit(weight.astype(float))
sv = svd.singular_values_
if normalize: evals = sv * sv / N
else: evals = sv * sv
lambda_max = np.max(evals)
fit = powerlaw.Fit(evals, xmax=lambda_max, verbose=False)
alpha = fit.alpha
cur_res["alpha"] = alpha
D = fit.D
cur_res["D"] = D
cur_res["lambda_min"] = np.min(evals)
cur_res["lambda_max"] = lambda_max
alpha_weighted = alpha * np.log10(lambda_max)
cur_res["alpha_weighted"] = alpha_weighted
tolerance = lambda_max * M * np.finfo(np.max(sv)).eps
cur_res["rank_loss"] = np.count_nonzero(sv > tolerance, axis=-1)
logpnorm = np.log10(np.sum([ev ** alpha for ev in evals]))
cur_res["logpnorm"] = logpnorm
summary.append(
"Weight matrix {}/{} ({},{}): Alpha: {}, Alpha Weighted: {}, D: {}, pNorm {}".format(i + 1, count, M, N, alpha,
alpha_weighted, D,
logpnorm))
if lognorms:
norm = np.linalg.norm(weight) # Frobenius Norm
cur_res["norm"] = norm
lognorm = np.log10(norm)
cur_res["lognorm"] = lognorm
X = np.dot(weight.T, weight)
if normalize: X = X / N
normX = np.linalg.norm(X) # Frobenius Norm
cur_res["normX"] = normX
lognormX = np.log10(normX)
cur_res["lognormX"] = lognormX
summary.append(
"Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(i + 1, count, M, N, lognorm, lognormX))
if softranks:
softrank = norm ** 2 / sv_max ** 2
softranklog = np.log10(softrank)
softranklogratio = lognorm / np.log10(sv_max)
cur_res["softrank"] = softrank
cur_res["softranklog"] = softranklog
cur_res["softranklogratio"] = softranklogratio
summary.append("Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(softrank, softranklog, softranklogratio))
cur_res["summary"] = "\n".join(summary)
return results
def compute_details(results):
"""
Aggregate the per-layer results into a single summary dict (averages of each metric).
"""
final_summary = OrderedDict()
metrics = {
# key in "results" : pretty print name
"check": "Check",
"checkTF": "CheckTF",
"norm": "Norm",
"lognorm": "LogNorm",
"normX": "Norm X",
"lognormX": "LogNorm X",
"alpha": "Alpha",
"alpha_weighted": "Alpha Weighted",
"spectralnorm": "Spectral Norm",
"logspectralnorm": "Log Spectral Norm",
"softrank": "Softrank",
"softranklog": "Softrank Log",
"softranklogratio": "Softrank Log Ratio",
"sigma_mp": "Marchenko-Pastur (MP) fit sigma",
"numofSpikes": "Number of spikes per MP fit",
"ratio_numofSpikes": "aka, percent_mass, Number of spikes / total number of evals",
"softrank_mp": "Softrank for MP fit",
"logpnorm": "alpha pNorm"
}
metrics_stats = []
for metric in metrics:
metrics_stats.append("{}_min".format(metric))
metrics_stats.append("{}_max".format(metric))
metrics_stats.append("{}_avg".format(metric))
metrics_stats.append("{}_compound_min".format(metric))
metrics_stats.append("{}_compound_max".format(metric))
metrics_stats.append("{}_compound_avg".format(metric))
columns = ["layer_id", "layer_type", "N", "M", "layer_count", "slice",
"slice_count", "level", "comment"] + [*metrics] + metrics_stats
metrics_values = {}
metrics_values_compound = {}
for metric in metrics:
metrics_values[metric] = []
metrics_values_compound[metric] = []
layer_count = 0
for layer_id, result in results.items():
layer_count += 1
layer_type = np.NAN
if "layer_type" in result:
layer_type = str(result["layer_type"]).replace("LAYER_TYPE.", "")
compounds = {} # temp var
for metric in metrics:
compounds[metric] = []
slice_count, Ntotal, Mtotal = 0, 0, 0
for slice_id, summary in result.items():
if not str(slice_id).isdigit():
continue
slice_count += 1
N = np.NAN
if "N" in summary:
N = summary["N"]
Ntotal += N
M = np.NAN
if "M" in summary:
M = summary["M"]
Mtotal += M
data = {"layer_id": layer_id, "layer_type": layer_type, "N": N, "M": M, "slice": slice_id, "level": "SLICE",
"comment": "Slice level"}
for metric in metrics:
if metric in summary:
value = summary[metric]
if value is not None:
metrics_values[metric].append(value)
compounds[metric].append(value)
data[metric] = value
data = {"layer_id": layer_id, "layer_type": layer_type, "N": Ntotal, "M": Mtotal, "slice_count": slice_count,
"level": "LAYER", "comment": "Layer level"}
# Compute the compound value over the slices
for metric, value in compounds.items():
count = len(value)
if count == 0:
continue
compound = np.mean(value)
metrics_values_compound[metric].append(compound)
data[metric] = compound
data = {"layer_count": layer_count, "level": "NETWORK", "comment": "Network Level"}
for metric, metric_name in metrics.items():
if metric not in metrics_values or len(metrics_values[metric]) == 0:
continue
values = metrics_values[metric]
minimum = min(values)
maximum = max(values)
avg = np.mean(values)
final_summary[metric] = avg
# print("{}: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
data["{}_min".format(metric)] = minimum
data["{}_max".format(metric)] = maximum
data["{}_avg".format(metric)] = avg
values = metrics_values_compound[metric]
minimum = min(values)
maximum = max(values)
avg = np.mean(values)
final_summary["{}_compound".format(metric)] = avg
# print("{} compound: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
data["{}_compound_min".format(metric)] = minimum
data["{}_compound_max".format(metric)] = maximum
data["{}_compound_avg".format(metric)] = avg
return final_summary
def analyze(model: nn.Module, min_size=50, max_size=0,
alphas: bool = False, lognorms: bool = True, spectralnorms: bool = False,
softranks: bool = False, normalize: bool = False, glorot_fix: bool = False):
"""
Analyze the weight matrices of a model.
:param model: A PyTorch model
:param min_size: The minimum weight matrix size to analyze.
:param max_size: The maximum weight matrix size to analyze (0 = no limit).
:param alphas: Compute the power laws (alpha) of the weight matrices.
Time consuming so disabled by default (use lognorm if you want speed)
:param lognorms: Compute the log norms of the weight matrices.
:param spectralnorms: Compute the spectral norm (max eigenvalue) of the weight matrices.
:param softranks: Compute the soft norm (i.e. StableRank) of the weight matrices.
:param normalize: Normalize or not.
:param glorot_fix:
:return: (a dict of all layers' results, a dict of the summarized info)
"""
names, modules = [], []
for name, module in model.named_modules():
if isinstance(module, available_module_types()):
names.append(name)
modules.append(module)
# print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
all_results = OrderedDict()
for index, module in enumerate(modules):
if isinstance(module, nn.Linear):
weights = [module.weight.cpu().detach().numpy()]
else:
weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
results = analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix)
results['id'] = index
results['type'] = type(module)
all_results[index] = results
summary = compute_details(all_results)
return all_results, summary
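A minimal usage sketch of analyze() (the two-layer toy model is only illustrative; any module containing Conv2d/Linear layers can be passed):

```python
import torch.nn as nn

# conv kernels are sliced into 2-D matrices first, then each matrix gets its
# per-matrix statistics; the summary dict holds per-metric averages over all
# analyzed matrices.
toy = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Linear(64, 128))
all_results, summary = analyze(toy, min_size=2, lognorms=True, spectralnorms=True)
print(summary['lognorm'], summary['logspectralnorm'])
```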

File diff suppressed because one or more lines are too long


@@ -3,3 +3,4 @@
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .data import get_data

datasets/data.py (new file, 69 lines)

@@ -0,0 +1,69 @@
from datasets import get_datasets
from config_utils import load_config
import torch
import torchvision
class AddGaussianNoise(object):
def __init__(self, mean=0., std=0.001):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class RepeatSampler(torch.utils.data.sampler.Sampler):
def __init__(self, samp, repeat):
self.samp = samp
self.repeat = repeat
def __iter__(self):
for i in self.samp:
for j in range(self.repeat):
yield i
def __len__(self):
return self.repeat*len(self.samp)
def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin_memory=True):
train_data, valid_data, xshape, class_num = get_datasets(dataset, data_loc, cutout=0)
if augtype == 'gaussnoise':
train_data.transform.transforms = train_data.transform.transforms[2:]
train_data.transform.transforms.append(AddGaussianNoise(std=args.sigma))
elif augtype == 'cutout':
train_data.transform.transforms = train_data.transform.transforms[2:]
train_data.transform.transforms.append(torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04)))
elif augtype == 'none':
train_data.transform.transforms = train_data.transform.transforms[2:]
if dataset == 'cifar10':
acc_type = 'ori-test'
val_acc_type = 'x-valid'
else:
acc_type = 'x-test'
val_acc_type = 'x-valid'
if trainval and 'cifar10' in dataset:
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
if repeat > 0:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
num_workers=0, pin_memory=pin_memory, sampler= RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(train_split), repeat))
else:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
num_workers=0, pin_memory=pin_memory, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
else:
if repeat > 0:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, #shuffle=True,
num_workers=0, pin_memory=pin_memory, sampler= RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(range(len(train_data))), repeat))
else:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
num_workers=0, pin_memory=pin_memory)
return train_loader
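An illustrative call (the CIFAR-10 root and batch size are placeholders; `args` only needs a `sigma` field when `augtype='gaussnoise'`):

```python
from types import SimpleNamespace

# returns a training DataLoader over CIFAR-10 with the chosen augmentation policy
args = SimpleNamespace(sigma=0.05)
train_loader = get_data('cifar10', './cifardata', trainval=False, batch_size=128,
                        augtype='none', repeat=1, args=args)
inputs, targets = next(iter(train_loader))
```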


@@ -16,7 +16,9 @@ from config_utils import load_config
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'fake':10,
'imagenet-1k-s':1000,
'imagenette2' : 10,
'imagenet-1k' : 1000,
'ImageNet16' : 1000,
'ImageNet16-150': 150,
@@ -98,8 +100,13 @@ def get_datasets(name, root, cutout):
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name == 'fake':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('imagenet-1k'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('imagenette'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('ImageNet16'):
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
@@ -113,6 +120,12 @@ def get_datasets(name, root, cutout):
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name == 'fake':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('ImageNet16'):
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
@@ -125,6 +138,15 @@ def get_datasets(name, root, cutout):
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('imagenette'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
xlists = []
xlists.append( transforms.ToTensor() )
xlists.append( normalize )
#train_transform = transforms.Compose(xlists)
train_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
xshape = (1, 3, 224, 224)
elif name.startswith('imagenet-1k'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if name == 'imagenet-1k':
@@ -156,6 +178,12 @@ def get_datasets(name, root, cutout):
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'fake':
train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform)
test_data = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform)
elif name.startswith('imagenette2'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)

env.yml (new file, 24 lines)

@@ -0,0 +1,24 @@
name: naswot2
channels:
- conda-forge
- pytorch
dependencies:
- python=3.7
- numpy
- matplotlib
- seaborn
- pandas
- xlrd
- scipy
- pip
- scikit-learn
- scikit-image
- pytorch::pytorch==1.6.0
- pytorch::torchvision==0.7.0
- cudatoolkit=9.2
- tqdm
- pip:
- tensorflow-gpu==1.15
- yacs
- simplejson
- "--editable=git+https://github.com/google-research/nasbench#egg=nasbench-master"


@@ -1,54 +0,0 @@
name: nas-wot
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- ca-certificates=2020.1.1=0
- certifi=2020.4.5.1=py38_0
- cudatoolkit=10.2.89=hfd86e86_1
- freetype=2.9.1=h8a8886c_1
- intel-openmp=2020.1=217
- jpeg=9b=h024ee3a_2
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- lz4-c=1.9.2=he6710b0_0
- mkl=2020.1=217
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.0.15=py38ha843d7b_0
- mkl_random=1.1.1=py38h0573a6f_0
- ncurses=6.2=he6710b0_1
- ninja=1.9.0=py38hfd86e86_0
- numpy=1.18.1=py38h4f9e942_0
- numpy-base=1.18.1=py38hde5b4d6_1
- olefile=0.46=py_0
- openssl=1.1.1g=h7b6447c_0
- pandas=1.0.3=py38h0573a6f_0
- pillow=7.1.2=py38hb39fc2d_0
- pip=20.0.2=py38_3
- python=3.8.3=hcff3b4d_0
- python-dateutil=2.8.1=py_0
- pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
- pytz=2020.1=py_0
- readline=8.0=h7b6447c_0
- setuptools=46.4.0=py38_0
- six=1.14.0=py38_0
- sqlite=3.31.1=h62c20be_1
- tk=8.6.8=hbc83047_0
- torchvision=0.6.0=py38_cu102
- tqdm=4.46.0=py_0
- wheel=0.34.2=py38_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.4=h0b5b093_3
- pip:
- argparse==1.4.0
- nas-bench-201==1.3
- tabulate==0.8.7


@@ -55,4 +55,4 @@ class TinyNetwork(nn.Module):
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
return logits, out

nas_101_api/__init__.py (new file, 1 line)

@@ -0,0 +1 @@

nas_101_api/base_ops.py (new file, 65 lines)

@@ -0,0 +1,65 @@
"""Base operations used by the modules in this search space."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(ConvBnRelu, self).__init__()
self.conv_bn_relu = nn.Sequential(
#nn.ReLU(),
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
#nn.ReLU(inplace=True)
nn.ReLU()
)
def forward(self, x):
return self.conv_bn_relu(x)
class Conv3x3BnRelu(nn.Module):
"""3x3 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels):
super(Conv3x3BnRelu, self).__init__()
self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
def forward(self, x):
x = self.conv3x3(x)
return x
class Conv1x1BnRelu(nn.Module):
"""1x1 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels):
super(Conv1x1BnRelu, self).__init__()
self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0)
def forward(self, x):
x = self.conv1x1(x)
return x
class MaxPool3x3(nn.Module):
"""3x3 max pool with no subsampling."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(MaxPool3x3, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)
#self.maxpool = nn.AvgPool2d(kernel_size, stride, padding)
def forward(self, x):
x = self.maxpool(x)
return x
# Commas should not be used in op names
OP_MAP = {
'conv3x3-bn-relu': Conv3x3BnRelu,
'conv1x1-bn-relu': Conv1x1BnRelu,
'maxpool3x3': MaxPool3x3
}
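Each registry entry is a module factory taking `(in_channels, out_channels)`; a quick sanity check:

```python
import torch

# build a 3x3 conv-bn-relu op and push a dummy batch through it
op = OP_MAP['conv3x3-bn-relu'](in_channels=16, out_channels=32)
print(op(torch.randn(2, 16, 8, 8)).shape)  # torch.Size([2, 32, 8, 8])
```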

nas_101_api/graph_util.py (new file, 167 lines)

@@ -0,0 +1,167 @@
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by generate_graph.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import itertools
import numpy as np
def gen_is_edge_fn(bits):
"""Generate a boolean function for the edge connectivity.
Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
[[0, A, B, D],
[0, 0, C, E],
[0, 0, 0, F],
[0, 0, 0, 0]]
Note that this function is agnostic to the actual matrix dimension due to
order in which elements are filled out (column-major, starting from least
significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
matrix is
[[0, A, B, D, 0],
[0, 0, C, E, 0],
[0, 0, 0, F, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
Args:
bits: integer which will be interpreted as a bit mask.
Returns:
vectorized function that returns True when an edge is present.
"""
def is_edge(x, y):
"""Is there an edge from x to y (0-indexed)?"""
if x >= y:
return 0
# Map x, y to index into bit string
index = x + (y * (y - 1) // 2)
return (bits >> index) % 2 == 1
return np.vectorize(is_edge)
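A worked example of the bit layout described in the docstring (column-major, starting from the least significant bit):

```python
import numpy as np

# FEDCBA = 0b100011 sets A, B and F, i.e. edges 0->1, 0->2 and 2->3
is_edge = gen_is_edge_fn(0b100011)
matrix = np.fromfunction(is_edge, (4, 4), dtype=np.int8).astype(np.int8)
# [[0 1 1 0]
#  [0 0 0 0]
#  [0 0 0 1]
#  [0 0 0 0]]
```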
def is_full_dag(matrix):
"""Full DAG == all vertices on a path from vert 0 to (V-1).
i.e. no disconnected or "hanging" vertices.
It is sufficient to check for:
1) no rows of 0 except for row V-1 (only output vertex has no out-edges)
2) no cols of 0 except for col 0 (only input vertex has no in-edges)
Args:
matrix: V x V upper-triangular adjacency matrix
Returns:
True if the there are no dangling vertices.
"""
shape = np.shape(matrix)
rows = matrix[:shape[0]-1, :] == 0
rows = np.all(rows, axis=1) # Any row with all 0 will be True
rows_bad = np.any(rows)
cols = matrix[:, 1:] == 0
cols = np.all(cols, axis=0) # Any col with all 0 will be True
cols_bad = np.any(cols)
return (not rows_bad) and (not cols_bad)
def num_edges(matrix):
"""Computes number of edges in adjacency matrix."""
return np.sum(matrix)
def hash_module(matrix, labeling):
"""Computes a graph-invariance MD5 hash of the matrix and label pair.
Args:
matrix: np.ndarray square upper-triangular adjacency matrix.
labeling: list of int labels of length equal to both dimensions of
matrix.
Returns:
MD5 hash of the matrix and labeling.
"""
vertices = np.shape(matrix)[0]
in_edges = np.sum(matrix, axis=0).tolist()
out_edges = np.sum(matrix, axis=1).tolist()
assert len(in_edges) == len(out_edges) == len(labeling)
hashes = list(zip(out_edges, in_edges, labeling))
hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
# Computing this up to the diameter is probably sufficient but since the
# operation is fast, it is okay to repeat more times.
for _ in range(vertices):
new_hashes = []
for v in range(vertices):
in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
new_hashes.append(hashlib.md5(
(''.join(sorted(in_neighbors)) + '|' +
''.join(sorted(out_neighbors)) + '|' +
hashes[v]).encode('utf-8')).hexdigest())
hashes = new_hashes
fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()
return fingerprint
def permute_graph(graph, label, permutation):
"""Permutes the graph and labels based on permutation.
Args:
graph: np.ndarray adjacency matrix.
label: list of labels of same length as graph dimensions.
permutation: a permutation list of ints of same length as graph dimensions.
Returns:
np.ndarray where vertex permutation[v] is vertex v from the original graph
"""
# vertex permutation[v] in new graph is vertex v in the old graph
forward_perm = zip(permutation, list(range(len(permutation))))
inverse_perm = [x[1] for x in sorted(forward_perm)]
edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
new_matrix = np.fromfunction(np.vectorize(edge_fn),
(len(label), len(label)),
dtype=np.int8)
new_label = [label[inverse_perm[i]] for i in range(len(label))]
return new_matrix, new_label
def is_isomorphic(graph1, graph2):
"""Exhaustively checks if 2 graphs are isomorphic."""
matrix1, label1 = np.array(graph1[0]), graph1[1]
matrix2, label2 = np.array(graph2[0]), graph2[1]
assert np.shape(matrix1) == np.shape(matrix2)
assert len(label1) == len(label2)
vertices = np.shape(matrix1)[0]
# Note: input and output in our constrained graphs always map to themselves
# but this script does not enforce that.
for perm in itertools.permutations(range(0, vertices)):
pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
return True
return False
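The fingerprint from hash_module is invariant to vertex relabelling, which is what is_isomorphic checks exhaustively; a small sketch:

```python
import numpy as np

# swapping vertices 1 and 2 changes the labelling order but not the fingerprint
matrix = np.array([[0, 1, 1, 0],
                   [0, 0, 0, 1],
                   [0, 0, 0, 1],
                   [0, 0, 0, 0]])
labels = [-1, 0, 1, -2]  # -1/-2 mark input/output, as in ModelSpec.hash_spec
pmatrix, plabels = permute_graph(matrix, labels, [0, 2, 1, 3])
assert is_isomorphic((matrix, labels), (pmatrix, plabels))
assert hash_module(matrix, labels) == hash_module(pmatrix, plabels)
```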

nas_101_api/model.py (new file, 252 lines)

@@ -0,0 +1,252 @@
"""Builds the Pytorch computational graph.
Tensors flowing into a single vertex are added together for all vertices
except the output, which is concatenated instead. Tensors flowing out of input
are always added.
If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing). Tensors flowing out of the input as always
projected instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
from .base_ops import *
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
def __init__(self, spec, args, searchspace=[]):
super(Network, self).__init__()
self.layers = nn.ModuleList([])
self.spec = spec
in_channels = 3
out_channels = args.stem_out_channels
# initial stem convolution
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
self.layers.append(stem_conv)
in_channels = out_channels
for stack_num in range(args.num_stacks):
if stack_num > 0:
#downsample = nn.MaxPool2d(kernel_size=3, stride=2)
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
#downsample = nn.AvgPool2d(kernel_size=2, stride=2)
#downsample = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=2)
self.layers.append(downsample)
out_channels *= 2
for module_num in range(args.num_modules_per_stack):
cell = Cell(spec, in_channels, out_channels)
self.layers.append(cell)
in_channels = out_channels
self.classifier = nn.Linear(out_channels, args.num_labels)
# for DARTS search
num_edge = np.shape(spec.matrix)[0]
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(searchspace)))
self._initialize_weights()
def forward(self, x, get_ints=True):
ints = []
for _, layer in enumerate(self.layers):
x = layer(x)
ints.append(x)
out = torch.mean(x, (2, 3))
ints.append(out)
out = self.classifier(out)
if get_ints:
return out, ints[-1]
else:
return out
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if m.bias is not None:
m.bias.data.zero_()
pass
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
pass
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
pass
def get_weights(self):
xlist = []
for m in self.modules():
xlist.append(m.parameters())
return xlist
def get_alphas(self):
return [self.arch_parameters]
def genotype(self):
return str(self.spec)
class Cell(nn.Module):
"""
Builds the model using the adjacency matrix and op labels specified. Channels
controls the module output channel count but the interior channels are
determined via equally splitting the channel count whenever there is a
concatenation of Tensors.
"""
def __init__(self, spec, in_channels, out_channels):
super(Cell, self).__init__()
self.spec = spec
self.num_vertices = np.shape(self.spec.matrix)[0]
# vertex_channels[i] = number of output channels of vertex i
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
#self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
# operation for each node
self.vertex_op = nn.ModuleList([None])
for t in range(1, self.num_vertices-1):
op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t])
self.vertex_op.append(op)
# operation for input on each vertex
self.input_op = nn.ModuleList([None])
for t in range(1, self.num_vertices):
if self.spec.matrix[0, t]:
self.input_op.append(Projection(in_channels, self.vertex_channels[t]))
else:
self.input_op.append(None)
def forward(self, x):
tensors = [x]
out_concat = []
for t in range(1, self.num_vertices-1):
fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]
fan_in_inds = [src for src in range(1, t) if self.spec.matrix[src, t]]
if self.spec.matrix[0, t]:
fan_in.append(self.input_op[t](x))
fan_in_inds = [0] + fan_in_inds
# perform operation on node
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
vertex_input = sum(fan_in)
#vertex_input = sum(fan_in) / len(fan_in)
vertex_output = self.vertex_op[t](vertex_input)
tensors.append(vertex_output)
if self.spec.matrix[t, self.num_vertices-1]:
out_concat.append(tensors[t])
if not out_concat: # empty list
assert self.spec.matrix[0, self.num_vertices-1]
outputs = self.input_op[self.num_vertices-1](tensors[0])
else:
if len(out_concat) == 1:
outputs = out_concat[0]
else:
outputs = torch.cat(out_concat, 1)
if self.spec.matrix[0, self.num_vertices-1]:
outputs += self.input_op[self.num_vertices-1](tensors[0])
#if self.spec.matrix[0, self.num_vertices-1]:
# out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
#outputs = sum(out_concat) / len(out_concat)
return outputs
def Projection(in_channels, out_channels):
"""1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
return ConvBnRelu(in_channels, out_channels, 1)
def Truncate(inputs, channels):
"""Slice the inputs to channels if necessary."""
input_channels = inputs.size()[1]
if input_channels < channels:
raise ValueError('input channel < output channels for truncate')
elif input_channels == channels:
return inputs # No truncation necessary
else:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert input_channels - channels == 1
return inputs[:, :channels, :, :]
def ComputeVertexChannels(in_channels, out_channels, matrix):
"""Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of
channels at each interior vertex. Interior vertices have the same number of
channels as the max of the channels of the vertices it feeds into. The output
channels are divided amongst the vertices that are directly connected to it.
When the division is not even, some vertices may receive an extra channel to
compensate.
Returns:
list of channel counts, in order of the vertices.
"""
num_vertices = np.shape(matrix)[0]
vertex_channels = [0] * num_vertices
vertex_channels[0] = in_channels
vertex_channels[num_vertices - 1] = out_channels
if num_vertices == 2:
# Edge case where module only has input and output vertices
return vertex_channels
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
in_degree = np.sum(matrix[1:], axis=0)
interior_channels = out_channels // in_degree[num_vertices - 1]
correction = out_channels % in_degree[num_vertices - 1] # Remainder to add
# Set channels of vertices that flow directly to output
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
vertex_channels[v] = interior_channels
if correction:
vertex_channels[v] += 1
correction -= 1
# Set channels for all other vertices to the max of the out edges, going
# backwards. (num_vertices - 2) index skipped because it only connects to
# output.
for v in range(num_vertices - 3, 0, -1):
if not matrix[v, num_vertices - 1]:
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
assert vertex_channels[v] > 0
# Sanity check, verify that channels never increase and final channels add up.
final_fan_in = 0
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
final_fan_in += vertex_channels[v]
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
assert vertex_channels[v] >= vertex_channels[dst]
assert final_fan_in == out_channels or num_vertices == 2
# num_vertices == 2 means only input/output nodes, so 0 fan-in
return vertex_channels
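A worked example of the channel split (a hypothetical 5-vertex cell with 16 input and 64 output channels): three vertices feed the output, so 64 splits into 22 + 21 + 21, with the remainder going to the first of them:

```python
import numpy as np

matrix = np.array([[0, 1, 1, 0, 0],
                   [0, 0, 0, 1, 1],
                   [0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0]])
print(ComputeVertexChannels(16, 64, matrix))  # [16, 22, 21, 21, 64]
```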

nas_101_api/model_spec.py (new file, 152 lines)

@@ -0,0 +1,152 @@
"""Model specification for module connectivity individuals.
This module handles pruning the unused parts of the computation graph but should
avoid creating any TensorFlow models (this is done inside model_builder.py).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from . import graph_util
# Graphviz is optional and only required for visualization.
try:
import graphviz # pylint: disable=g-import-not-at-top
except ImportError:
pass
class ModelSpec(object):
"""Model specification given adjacency matrix and labeling."""
def __init__(self, matrix, ops, data_format='channels_last'):
"""Initialize the module spec.
Args:
matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
ops: V-length list of labels for the base ops used. The first and last
elements are ignored because they are the input and output vertices
which have no operations. The elements are retained to keep consistent
indexing.
data_format: channels_last or channels_first.
Raises:
ValueError: invalid matrix or ops
"""
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix)
shape = np.shape(matrix)
if len(shape) != 2 or shape[0] != shape[1]:
raise ValueError('matrix must be square')
if shape[0] != len(ops):
raise ValueError('length of ops must match matrix dimensions')
if not is_upper_triangular(matrix):
raise ValueError('matrix must be upper triangular')
# Both the original and pruned matrices are deep copies of the matrix and
# ops so any changes to those after initialization are not recognized by the
# spec.
self.original_matrix = copy.deepcopy(matrix)
self.original_ops = copy.deepcopy(ops)
self.matrix = copy.deepcopy(matrix)
self.ops = copy.deepcopy(ops)
self.valid_spec = True
self._prune()
self.data_format = data_format
def _prune(self):
"""Prune the extraneous parts of the graph.
General procedure:
1) Remove parts of graph not connected to input.
2) Remove parts of graph not connected to output.
3) Reorder the vertices so that they are consecutive after steps 1 and 2.
These 3 steps can be combined by deleting the rows and columns of the
vertices that are not reachable from both the input and output (in reverse).
"""
num_vertices = np.shape(self.original_matrix)[0]
# DFS forward from input
visited_from_input = set([0])
frontier = [0]
while frontier:
top = frontier.pop()
for v in range(top + 1, num_vertices):
if self.original_matrix[top, v] and v not in visited_from_input:
visited_from_input.add(v)
frontier.append(v)
# DFS backward from output
visited_from_output = set([num_vertices - 1])
frontier = [num_vertices - 1]
while frontier:
top = frontier.pop()
for v in range(0, top):
if self.original_matrix[v, top] and v not in visited_from_output:
visited_from_output.add(v)
frontier.append(v)
# Any vertex that isn't connected to both input and output is extraneous to
# the computation graph.
extraneous = set(range(num_vertices)).difference(
visited_from_input.intersection(visited_from_output))
# If the non-extraneous graph is less than 2 vertices, the input is not
# connected to the output and the spec is invalid.
if len(extraneous) > num_vertices - 2:
self.matrix = None
self.ops = None
self.valid_spec = False
return
self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
for index in sorted(extraneous, reverse=True):
del self.ops[index]
def hash_spec(self, canonical_ops):
"""Computes the isomorphism-invariant graph hash of this spec.
Args:
canonical_ops: list of operations in the canonical ordering which they
were assigned (i.e. the order provided in the config['available_ops']).
Returns:
MD5 hash of this spec which can be used to query the dataset.
"""
# Invert the operations back to integer label indices used in graph gen.
labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
return graph_util.hash_module(self.matrix, labeling)
def visualize(self):
"""Creates a dot graph. Can be visualized in colab directly."""
num_vertices = np.shape(self.matrix)[0]
g = graphviz.Digraph()
g.node(str(0), 'input')
for v in range(1, num_vertices - 1):
g.node(str(v), self.ops[v])
g.node(str(num_vertices - 1), 'output')
for src in range(num_vertices - 1):
for dst in range(src + 1, num_vertices):
if self.matrix[src, dst]:
g.edge(str(src), str(dst))
return g
def is_upper_triangular(matrix):
"""True if matrix is 0 on diagonal and below."""
for src in range(np.shape(matrix)[0]):
for dst in range(0, src + 1):
if matrix[src, dst] != 0:
return False
return True
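Putting model.py and model_spec.py together, a minimal end-to-end sketch (the cell is the canonical NAS-Bench-101 example; the `args` values are illustrative and simply mirror the attributes Network reads):

```python
import torch
from argparse import Namespace
from nas_101_api.model_spec import ModelSpec
from nas_101_api.model import Network

matrix = [[0, 1, 1, 1, 0, 1, 0],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 0]]
ops = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu',
       'conv3x3-bn-relu', 'maxpool3x3', 'output']
spec = ModelSpec(matrix, ops)
assert spec.valid_spec  # pruning left the graph connected

args = Namespace(stem_out_channels=16, num_stacks=3,
                 num_modules_per_stack=3, num_labels=10)
net = Network(spec, args)
logits, feats = net(torch.randn(2, 3, 32, 32))
print(logits.shape, feats.shape)  # torch.Size([2, 10]) torch.Size([2, 64])
```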

nas_201_api/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################################
# This API will not be updated after 2020.09.16. #
# Please use our new API in NATS-Bench, which is #
# more efficient and contains info of more architecture candidates. #
#####################################################################
from .api_utils import ArchResults, ResultsCount
from .api_201 import NASBench201API
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]

nas_201_api/api.py (new file, 830 lines)

@@ -0,0 +1,830 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.08] Next version (coming soon)
#
#
import os, copy, random, torch, numpy as np
from typing import List, Text, Union, Dict
from collections import OrderedDict, defaultdict
def print_information(information, extra_info=None, show=False):
dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
def metric2str(loss, acc):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
"""
This is the class for the API of NAS-Bench-201.
"""
class NASBench201API(object):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
if isinstance(file_path_or_dict, str):
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
file_path_or_dict = torch.load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
self.arch2infos_less = OrderedDict()
self.arch2infos_full = OrderedDict()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
#assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
def __getitem__(self, index: int):
return copy.deepcopy( self.meta_archs[index] )
def __len__(self):
return len(self.meta_archs)
def __repr__(self):
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)
# This function is used to query the index of an architecture in the search space.
# The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
# or an instance that has the 'tostr' function that can generate the architecture string.
# This function will return the index.
# If return -1, it means this architecture is not in the search space.
# Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
def query_index_by_arch(self, arch):
if isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
else : arch_index = -1
elif hasattr(arch, 'tostr'):
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
else : arch_index = -1
else: arch_index = -1
return arch_index
# Overwrite all information of the 'index'-th architecture in the search space.
# It will load its data from 'archive_root'.
def reload(self, archive_root: Text, index: int):
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path)
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
# This function is used to query the information of a specific architecture
# 'arch' can be an architecture index or an architecture string
# When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config'
# When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config'
# The difference between these two configurations is the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config.
def query_by_arch(self, arch, use_12epochs_result=False):
if isinstance(arch, int):
arch_index = arch
else:
arch_index = self.query_index_by_arch(arch)
if arch_index == -1: return None # the following two lines are used to support few training epochs
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
if arch_index in arch2infos:
strings = print_information(arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
return None
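A hedged usage sketch of the two query methods above (assuming the NAS-Bench-201 v1.0 checkpoint mentioned in the header has been downloaded):

```python
api = NASBench201API('NAS-Bench-201-v1_0-e61699.pth')
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
index = api.query_index_by_arch(arch)  # -1 if the string is not in the search space
print(api.query_by_arch(index))        # pretty-printed metrics for that architecture
```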
# This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs.
# ------
# If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config)
# If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config)
# ------
# If dataname is None, return the ArchResults
# else, return a dict with all trials on that dataset (the key is the seed)
# Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
# -- cifar10-valid : training the model on the CIFAR-10 training set.
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
# -- cifar100 : training the model on the CIFAR-100 training set.
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
use_12epochs_result: bool = False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
if dataname is None: return archInfo
else:
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
return archInfo
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes):
info = arch2infos[idx].get_compute_costs(dataset)
flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue
xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set)
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1:
best_index, highest_accuracy = idx, accuracy
elif highest_accuracy < accuracy:
best_index, highest_accuracy = idx, accuracy
return best_index, highest_accuracy
def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
"""
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
Args [seed]:
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
-- an integer : return the weights of a specific trial, whose seed is this integer.
Args [use_12epochs_result]:
-- True : train the model by 12 epochs
-- False : train the model by 200 epochs
"""
if use_12epochs_result: arch2infos = self.arch2infos_less
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_net_param(dataset, seed)
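# Example (illustrative sketch; `api`, the index and the seed are placeholders):
#   all_weights = api.get_net_param(1452, 'cifar10', None)   # dict: {seed: state_dict} over all trials
#   one_weights = api.get_net_param(1452, 'cifar10', 777)    # state_dict of the trial trained with seed 777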
def get_net_config(self, index: int, dataset: Text):
"""
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
This function will return a dict.
========= Some examples for using this function:
config = api.get_net_config(128, 'cifar10')
"""
archresult = self.arch2infos_full[index]
all_results = archresult.query(dataset, None)
if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset))
for seed, result in all_results.items():
return result.get_config(None)
#print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!')
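# Example (illustrative sketch; the index below is a placeholder):
#   config = api.get_net_config(128, 'cifar10')
#   # e.g. {'name': 'infer.tiny', 'C': 16, 'N': 5, 'arch_str': '|...|', 'num_classes': 10}
#   # this dict can then be passed to a network builder that understands the 'infer.tiny' format
#   # (such as a get_cell_based_tiny_net-style helper from the AutoDL-Projects code, assumed here).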
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if use_12epochs_result: arch2infos = self.arch2infos_less
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
"""
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
:param index: the index of the target architecture
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
:return: return a float value in seconds
"""
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
return cost_dict['latency']
# obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
# `iepoch` indicates the index of training epochs from 0 to 11/199.
# When iepoch=None, it will return the metric for the last training epoch
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
# `use_12epochs_result` indicates different hyper-parameters for training
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
# `is_random`
# When is_random=True, the metrics of one randomly selected trial will be returned
# When is_random=False, the metrics of all trials will be averaged.
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
if dataset == 'cifar10-valid':
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
try:
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test__info = None
total = train_info['iepoch'] + 1
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total,
'train-all-time': train_info['all_time'],
'valid-loss' : valid_info['loss'],
'valid-accuracy': valid_info['accuracy'],
'valid-all-time': valid_info['all_time'],
'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
if test__info is not None:
xifo['test-loss'] = test__info['loss']
xifo['test-accuracy'] = test__info['accuracy']
return xifo
else:
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
try:
if dataset == 'cifar10':
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
test__info = None
try:
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
est_valid_info = None
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy']}
if test__info is not None:
xifo['test-loss'] = test__info['loss']
xifo['test-accuracy'] = test__info['accuracy']
if valid_info is not None:
xifo['valid-loss'] = valid_info['loss']
xifo['valid-accuracy'] = valid_info['accuracy']
if est_valid_info is not None:
xifo['est-valid-loss'] = est_valid_info['loss']
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
return xifo
def show(self, index: int = -1):
"""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: the collected information strings when a single evaluated architecture is requested; otherwise nothing.
"""
return_flag = 0
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
strings = print_information(self.arch2infos_full[idx])
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
strings = print_information(self.arch2infos_less[idx])
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
if 0 <= index < len(self.meta_archs):
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
else:
return_flag = 1
out = []
strings = print_information(self.arch2infos_full[index])
out.append(strings)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
strings = print_information(self.arch2infos_less[index])
out.append(strings)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
if return_flag:
return out
@staticmethod
def str2lists(arch_str: Text) -> List[tuple]:
"""
This function shows how to read the string-based architecture encoding.
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
:param
arch_str: the input is a string that indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
:return: a list of tuple, contains multiple (op, input_node_index) pairs.
:usage
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
"""
node_strs = arch_str.split('+')
genotypes = []
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return genotypes
@staticmethod
def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
"""
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
:param
arch_str: the input is a string that indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
search_space: a list of operation strings, the default list is the search space for NAS-Bench-201
the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
:return
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
:usage
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
:(NOTE)
If a node has two input-edges from the same node, this function does not work: one of the edges will be overwritten.
"""
node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs:
op, idx = xi.split('~')
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx
return matrix
class ArchResults(object):
def __init__(self, arch_index, arch_str):
self.arch_index = int(arch_index)
self.arch_str = copy.deepcopy(arch_str)
self.all_results = dict()
self.dataset_seed = dict()
self.clear_net_done = False
def get_compute_costs(self, dataset):
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
flops = [result.flop for result in results]
params = [result.params for result in results]
latencies = [result.get_latency() for result in results]
latencies = [x for x in latencies if x > 0]
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
time_infos = defaultdict(list)
for result in results:
time_info = result.get_times()
for key, value in time_info.items(): time_infos[key].append( value )
info = {'flops' : np.mean(flops),
'params' : np.mean(params),
'latency': mean_latency}
for key, value in time_infos.items():
if len(value) > 0 and value[0] is not None:
info[key] = np.mean(value)
else: info[key] = None
return info
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
"""
This `get_metrics` function is used to obtain the loss, accuracy, etc. information on a specific dataset.
If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
If some args return None or raise an error, then the corresponding information is not available.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all available trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
infos = defaultdict(list)
for result in results:
if setname == 'train':
info = result.get_train(iepoch)
else:
info = result.get_eval(setname, iepoch)
for key, value in info.items(): infos[key].append( value )
return_info = dict()
if isinstance(is_random, bool) and is_random: # randomly select one
index = random.randint(0, len(results)-1)
for key, value in infos.items(): return_info[key] = value[index]
elif isinstance(is_random, bool) and not is_random: # average
for key, value in infos.items():
if len(value) > 0 and value[0] is not None:
return_info[key] = np.mean(value)
else: return_info[key] = None
elif isinstance(is_random, int): # specify the seed
if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
index = x_seeds.index(is_random)
for key, value in infos.items(): return_info[key] = value[index]
else:
raise ValueError('invalid value for is_random: {:}'.format(is_random))
return return_info
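# Example (illustrative sketch of the `is_random` semantics; `archres` and the seed are placeholders):
#   avg  = archres.get_metrics('cifar100', 'x-valid', is_random=False)  # averaged over all trials
#   rnd  = archres.get_metrics('cifar100', 'x-valid', is_random=True)   # one randomly chosen trial
#   s777 = archres.get_metrics('cifar100', 'x-valid', is_random=777)    # the trial whose seed is 777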
def show(self, is_print=False):
return print_information(self, None, is_print)
def get_dataset_names(self):
return list(self.dataset_seed.keys())
def get_dataset_seeds(self, dataset):
return copy.deepcopy( self.dataset_seed[dataset] )
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
"""
This function will return the trained network's weights on the 'dataset'.
:arg
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
seed: an integer indicating the seed value, or None to return the weights of all trials.
"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
else:
return self.all_results[(dataset, seed)].get_net_param()
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].update_latency([latency])
else:
self.all_results[(dataset, seed)].update_latency([latency])
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
def get_latency(self, dataset: Text) -> float:
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
latencies = []
for seed in self.dataset_seed[dataset]:
latency = self.all_results[(dataset, seed)].get_latency()
if not isinstance(latency, float) or latency <= 0:
raise ValueError('invalid latency of {:} for {:} with seed {:}'.format(latency, dataset, seed))
latencies.append(latency)
return sum(latencies) / len(latencies)
def get_total_epoch(self, dataset=None):
"""Return the total number of training epochs."""
if dataset is None:
epochss = []
for xdata, x_seeds in self.dataset_seed.items():
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
elif isinstance(dataset, str):
x_seeds = self.dataset_seed[dataset]
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
else:
raise ValueError('invalid dataset={:}'.format(dataset))
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
return epochss[-1]
def query(self, dataset, seed=None):
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
else:
return self.all_results[(dataset, seed)]
def arch_idx_str(self):
return '{:06d}'.format(self.arch_index)
def update(self, dataset_name, seed, result):
if dataset_name not in self.dataset_seed:
self.dataset_seed[dataset_name] = []
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
self.dataset_seed[ dataset_name ].append( seed )
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
assert (dataset_name, seed) not in self.all_results
self.all_results[ (dataset_name, seed) ] = result
self.clear_net_done = False
def state_dict(self):
state_dict = dict()
for key, value in self.__dict__.items():
if key == 'all_results': # contain the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
xvalue[_k] = _v.state_dict()
else:
xvalue = value
state_dict[key] = xvalue
return state_dict
def load_state_dict(self, state_dict):
new_state_dict = dict()
for key, value in state_dict.items():
if key == 'all_results': # to convert to the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
else: xvalue = value
new_state_dict[key] = xvalue
self.__dict__.update(new_state_dict)
@staticmethod
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
state_dict = torch.load(state_dict_or_file)
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
x.load_state_dict(state_dict)
return x
# This function is used to clear the weights saved in each 'result'
# This can help reduce the memory footprint.
def clear_params(self):
for key, result in self.all_results.items():
result.net_state_dict = None
self.clear_net_done = True
def debug_test(self):
"""This function is used for me to debug and test, which will call most methods."""
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
for dataset in all_dataset:
print('---->>>> {:}'.format(dataset))
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
for seed in self.dataset_seed[dataset]:
result = self.all_results[(dataset, seed)]
print(' ==>> result = {:}'.format(result))
print(' ==>> cost = {:}'.format(result.get_times()))
def __repr__(self):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
"""
This class (ResultsCount) is used to save the information of one trial for a single architecture.
I did not write many comments for this class, because it is the lowest-level class in the NAS-Bench-201 API, which is rarely called directly.
If you have any question regarding this class, please open an issue or email me.
"""
class ResultsCount(object):
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
self.name = name
self.net_state_dict = state_dict
self.train_acc1es = copy.deepcopy(train_accs)
self.train_acc5es = None
self.train_losses = copy.deepcopy(train_losses)
self.train_times = None
self.arch_config = copy.deepcopy(arch_config)
self.params = params
self.flop = flop
self.seed = seed
self.epochs = epochs
self.latency = latency
# evaluation results
self.reset_eval()
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
self.train_acc1es = train_acc1es
self.train_acc5es = train_acc5es
self.train_losses = train_losses
self.train_times = train_times
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
"""Assign the training times."""
train_times = OrderedDict()
for i in range(self.epochs):
train_times[i] = estimated_per_epoch_time
self.train_times = train_times
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
"""Assign the evaluation times."""
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
for i in range(self.epochs):
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
def reset_eval(self):
self.eval_names = []
self.eval_acc1es = {}
self.eval_times = {}
self.eval_losses = {}
def update_latency(self, latency):
self.latency = copy.deepcopy( latency )
def get_latency(self) -> float:
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
if self.latency is None: return -1.0
else: return sum(self.latency) / len(self.latency)
def update_eval(self, accs, losses, times): # new version
data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names:
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
self.eval_names.append( data_name )
for iepoch in range(self.epochs):
xkey = '{:}@{:}'.format(data_name, iepoch)
self.eval_acc1es[ xkey ] = accs[ xkey ]
self.eval_losses[ xkey ] = losses[ xkey ]
self.eval_times [ xkey ] = times[ xkey ]
def update_OLD_eval(self, name, accs, losses): # old version
assert name not in self.eval_names, '{:} has already been added'.format(name)
self.eval_names.append( name )
for iepoch in range(self.epochs):
if iepoch in accs:
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
def __repr__(self):
num_eval = len(self.eval_names)
set_name = '[' + ', '.join(self.eval_names) + ']'
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
def get_total_epoch(self):
return copy.deepcopy(self.epochs)
def get_times(self):
"""Obtain the information regarding both training and evaluation time."""
if self.train_times is not None and isinstance(self.train_times, dict):
train_times = list( self.train_times.values() )
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
else:
time_info = {'T-train@epoch': None, 'T-train@total': None }
for name in self.eval_names:
try:
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
except:
time_info['T-{:}@epoch'.format(name)] = None
time_info['T-{:}@total'.format(name)] = None
return time_info
def get_eval_set(self):
return self.eval_names
# get the training information
def get_train(self, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if self.train_times is not None:
xtime = self.train_times[iepoch]
atime = sum([self.train_times[i] for i in range(iepoch+1)])
else: xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.train_losses[iepoch],
'accuracy': self.train_acc1es[iepoch],
'cur_time': xtime,
'all_time': atime}
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
def get_eval(self, name, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
else: xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)],
'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
'cur_time': xtime,
'all_time': atime}
def get_net_param(self, clone=False):
if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict
# This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure):
if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
def state_dict(self):
_state_dict = {key: value for key, value in self.__dict__.items()}
return _state_dict
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
@staticmethod
def create_from_state_dict(state_dict):
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
x.load_state_dict(state_dict)
return x

274
nas_201_api/api_201.py Normal file
@@ -0,0 +1,274 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
#
# I'm still actively enhancing our benchmark, while for the future benchmark file, please follow news from NATS-Bench (an extended version of NAS-Bench-201).
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']
def print_information(information, extra_info=None, show=False):
dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
def metric2str(loss, acc):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
"""
This is the class for the API of NAS-Bench-201.
"""
class NASBench201API(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
verbose: bool=True):
self.filename = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set(['12', '200'])
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
# self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
# self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'The [{:}]-th arch {:} is already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
"""
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
hp2archres = OrderedDict()
hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
The difference between these two configurations is the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
# `iepoch` indicates the index of training epochs from 0 to 11/199.
# When iepoch=None, it will return the metric for the last training epoch
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
# `use_12epochs_result` indicates different hyper-parameters for training
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
# `is_random`
# When is_random=True, the performance of one randomly selected trial will be returned
# When is_random=False, the performance of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # in case the input is an architecture string or object instead of an index
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
'train-all-time': train_info['all_time']}
# collect the evaluation information
if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
valtest_info = None
else:
try: # collect results on the proposed test set
if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
except:
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
xinfo['valid-accuracy'] = valid_info['accuracy']
xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None
xinfo['valid-all-time'] = valid_info['all_time']
if test_info is not None:
xinfo['test-loss'] = test_info['loss']
xinfo['test-accuracy'] = test_info['accuracy']
xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None
xinfo['test-all-time'] = test_info['all_time']
if valtest_info is not None:
xinfo['valtest-loss'] = valtest_info['loss']
xinfo['valtest-accuracy'] = valtest_info['accuracy']
xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None
xinfo['valtest-all-time'] = valtest_info['all_time']
return xinfo
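# Example (illustrative sketch, assuming `api = NASBench201API(...)` has been created and the
# index below is a placeholder):
#   info = api.get_more_info(123, 'cifar100', iepoch=None, hp='200', is_random=False)
#   # info holds averaged last-epoch metrics, e.g. info['train-loss'], info['valid-accuracy'],
#   # info['test-accuracy'] (when the corresponding evaluation set is available).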
def show(self, index: int = -1) -> None:
"""This function will print the information of a specific (or all) architecture(s)."""
self._show(index, print_information)
@staticmethod
def str2lists(arch_str: Text) -> List[tuple]:
"""
This function shows how to read the string-based architecture encoding.
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
:param
arch_str: the input is a string that indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
:return: a list of tuple, contains multiple (op, input_node_index) pairs.
:usage
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
"""
node_strs = arch_str.split('+')
genotypes = []
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return genotypes
@staticmethod
def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
"""
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
:param
arch_str: the input is a string that indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
search_space: a list of operation strings, the default list is the search space for NAS-Bench-201
the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
:return
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
:usage
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
:(NOTE)
If a node has two input-edges from the same node, this function does not work: one of the edges will be overwritten.
"""
node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs:
op, idx = xi.split('~')
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx
return matrix

750
nas_201_api/api_utils.py Normal file
@@ -0,0 +1,750 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
############################################################################################
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
"""re-map the metric_on_set to internal keys"""
if verbose:
print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
if dataset == 'cifar10' and metric_on_set == 'valid':
dataset, metric_on_set = 'cifar10-valid', 'x-valid'
elif dataset == 'cifar10' and metric_on_set == 'test':
dataset, metric_on_set = 'cifar10', 'ori-test'
elif dataset == 'cifar10' and metric_on_set == 'train':
dataset, metric_on_set = 'cifar10', 'train'
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
metric_on_set = 'x-valid'
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
metric_on_set = 'x-test'
if verbose:
print(' return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
return dataset, metric_on_set
class NASBenchMetaAPI(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
def __getitem__(self, index: int):
return copy.deepcopy(self.meta_archs[index])
def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
if self.verbose:
print('Call the arch function with index={:}'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
def __len__(self):
return len(self.meta_archs)
def __repr__(self):
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
@property
def avaliable_hps(self):
return list(copy.deepcopy(self._avaliable_hps))
@property
def used_time(self):
return self._used_time
def reset_time(self):
self._used_time = 0
def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
if dataset == 'cifar10':
info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
else:
info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
latency = self.get_latency(index, dataset)
if account_time:
self._used_time += time_cost
return valid_acc, latency, time_cost, self._used_time
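# Example (illustrative sketch of a simulated-time random search, assuming `api` is an instance
# of this class; the 12000-second budget is a placeholder value):
#   api.reset_time()
#   best_acc, best_idx = -1.0, -1
#   while api.used_time < 12000:
#     idx = api.random()
#     acc, latency, cost, total = api.simulate_train_eval(idx, 'cifar10', hp='12')
#     if acc > best_acc: best_acc, best_idx = acc, idx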
def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)
def query_index_by_arch(self, arch):
""" This function is used to query the index of an architecture in the search space.
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
or an instance that has the 'tostr' function that can generate the architecture string;
or it is directly an architecture index, in this case, we will check whether it is valid or not.
This function will return the index.
If return -1, it means this architecture is not in the search space.
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
"""
if self.verbose:
print('Call query_index_by_arch with arch={:}'.format(arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
else:
raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
elif isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
else : arch_index = -1
elif hasattr(arch, 'tostr'):
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
else : arch_index = -1
else: arch_index = -1
return arch_index
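# Example (illustrative sketch):
#   idx = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
#   # idx is an int in [0, len(api)) if the string is in the search space, otherwise -1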
def query_by_arch(self, arch, hp):
# This is to make the current version be compatible with the old version.
return self.query_info_str_by_arch(arch, hp)
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
def clear_params(self, index: int, hp: Optional[Text]=None):
"""Remove the architecture's weights to save memory.
:arg
index: the index of the target architecture
hp: a flag to control how to clear the parameters.
-- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
if hp is None:
for key, result in self.arch2infos_dict[index].items():
result.clear_params()
else:
if str(hp) not in self.arch2infos_dict[index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
self.arch2infos_dict[index][str(hp)].clear_params()
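# Example (illustrative sketch; the index is a placeholder): drop stored weights to save memory:
#   api.clear_params(123)          # clear the weights for every hyper-parameter setting
#   api.clear_params(123, '200')   # clear only the 200-epoch weights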
@abc.abstractmethod
def query_info_str_by_arch(self, arch, hp: Text='12'):
"""This function is used to query the information of a specific architecture."""
def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
arch_index = self.query_index_by_arch(arch)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
info = self.arch2infos_dict[arch_index][hp]
strings = print_information(info, 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
print ('Found arch-index {:}, but this architecture has not been evaluated.'.format(arch_index))
return None
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
"""Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
if self.verbose:
print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
info = self.arch2infos_dict[arch_index][hp]
else:
raise ValueError('arch_index [{:}] is not in arch2infos'.format(arch_index))
return copy.deepcopy(info)
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
""" This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
------
If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
------
If dataname is None, return the ArchResults
else, return a dict with all trials on that dataset (the key is the seed)
Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
"""
if self.verbose:
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
if dataname is None: return info
else:
if dataname not in info.get_dataset_names():
raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
return info.query(dataname)
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
for i, arch_index in enumerate(self.evaluated_indexes):
arch_info = self.arch2infos_dict[arch_index][hp]
info = arch_info.get_compute_costs(dataset) # the information of costs
flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue
xinfo = arch_info.get_metrics(dataset, metric_on_set) # the information of loss and accuracy
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1:
best_index, highest_accuracy = arch_index, accuracy
elif highest_accuracy < accuracy:
best_index, highest_accuracy = arch_index, accuracy
if self.verbose:
print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
return best_index, highest_accuracy
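# Example (illustrative sketch; the FLOP budget is a placeholder value in MFLOPs):
#   best_idx, best_acc = api.find_best('cifar100', 'valid', FLOP_max=30, hp='200')
#   # best_idx is the index of the most accurate architecture under the constraint, best_acc its accuracy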
def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
"""
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
Args [seed]:
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
-- an integer : return the weights of a specific trial, whose seed is this integer.
Args [hp]:
-- 01 : train the model by 01 epochs
-- 12 : train the model by 12 epochs
-- 90 : train the model by 90 epochs
-- 200 : train the model by 200 epochs
"""
if self.verbose:
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
def get_net_config(self, index: int, dataset: Text):
"""
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
This function will return a dict.
========= Some examples for using this function:
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
else:
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
info = next(iter(info.values()))
results = info.query(dataset, None)
results = next(iter(results.values()))
return results.get_config(None)
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
"""
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
:param index: the index of the target architecture
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
:return: a float value in seconds
"""
if self.verbose:
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']
@abc.abstractmethod
def show(self, index=-1):
"""This function will print the information of a specific (or all) architecture(s)."""
def _show(self, index=-1, print_information=None) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
for key, result in self.arch2infos_dict[idx].items():
strings = print_information(result)
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
if 0 <= index < len(self.meta_archs):
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or was not saved.'.format(index))
else:
arch_info = self.arch2infos_dict[index]
for key, result in self.arch2infos_dict[index].items():
strings = print_information(result)
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
"""This function will count the number of total trials."""
if self.verbose:
print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
if dataset not in valid_datasets:
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
nums, hp = defaultdict(lambda: 0), str(hp)
for index in range(len(self)):
archInfo = self.arch2infos_dict[index][hp]
dataset_seed = archInfo.dataset_seed
if dataset not in dataset_seed:
nums[0] += 1
else:
nums[len(dataset_seed[dataset])] += 1
return dict(nums)
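# Usage sketch (illustrative): the returned dict maps a number of trials to how
# many architectures have that many trials on the chosen dataset, e.g.
#   counts = api.statistics('cifar10-valid', hp='12')
#   # counts might look like {1: 10000, 3: 5625} (hypothetical values), i.e.
#   # 10000 architectures trained with one seed and 5625 with three seeds.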
class ArchResults(object):
def __init__(self, arch_index, arch_str):
self.arch_index = int(arch_index)
self.arch_str = copy.deepcopy(arch_str)
self.all_results = dict()
self.dataset_seed = dict()
self.clear_net_done = False
def get_compute_costs(self, dataset):
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
flops = [result.flop for result in results]
params = [result.params for result in results]
latencies = [result.get_latency() for result in results]
latencies = [x for x in latencies if x > 0]
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
time_infos = defaultdict(list)
for result in results:
time_info = result.get_times()
for key, value in time_info.items(): time_infos[key].append( value )
info = {'flops' : np.mean(flops),
'params' : np.mean(params),
'latency': mean_latency}
for key, value in time_infos.items():
if len(value) > 0 and value[0] is not None:
info[key] = np.mean(value)
else: info[key] = None
return info
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
"""
This `get_metrics` function is used to obtain the loss, accuracy, etc. on a specific dataset.
If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
If some args return None or raise an error, then the corresponding information is not available.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the number of total training epochs)):
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric over all available trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
infos = defaultdict(list)
for result in results:
if setname == 'train':
info = result.get_train(iepoch)
else:
info = result.get_eval(setname, iepoch)
for key, value in info.items(): infos[key].append( value )
return_info = dict()
if isinstance(is_random, bool) and is_random: # randomly select one
index = random.randint(0, len(results)-1)
for key, value in infos.items(): return_info[key] = value[index]
elif isinstance(is_random, bool) and not is_random: # average
for key, value in infos.items():
if len(value) > 0 and value[0] is not None:
return_info[key] = np.mean(value)
else: return_info[key] = None
elif isinstance(is_random, int): # specify the seed
if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
index = x_seeds.index(is_random)
for key, value in infos.items(): return_info[key] = value[index]
else:
raise ValueError('invalid value for is_random: {:}'.format(is_random))
return return_info
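# Usage sketch (illustrative): assuming `arch_results` is an ArchResults instance,
# query the averaged validation accuracy after the last epoch, and the accuracy of
# a single trial (the seed value below is a placeholder).
#   mean_info = arch_results.get_metrics('cifar10-valid', 'x-valid')
#   one_trial = arch_results.get_metrics('cifar10-valid', 'x-valid', is_random=777)
#   print(mean_info['accuracy'], one_trial['accuracy'])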
def show(self, is_print=False):
return print_information(self, None, is_print)
def get_dataset_names(self):
return list(self.dataset_seed.keys())
def get_dataset_seeds(self, dataset):
return copy.deepcopy( self.dataset_seed[dataset] )
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
"""
This function will return the trained network's weights on the 'dataset'.
:arg
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
seed: an integer indicating the seed value, or None to return all trials.
"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
else:
xkey = (dataset, seed)
if xkey in self.all_results:
return self.all_results[xkey].get_net_param()
else:
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].update_latency([latency])
else:
self.all_results[(dataset, seed)].update_latency([latency])
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
else:
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
def get_latency(self, dataset: Text) -> float:
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
latencies = []
for seed in self.dataset_seed[dataset]:
latency = self.all_results[(dataset, seed)].get_latency()
if not isinstance(latency, float) or latency <= 0:
raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
latencies.append(latency)
return sum(latencies) / len(latencies)
def get_total_epoch(self, dataset=None):
"""Return the total number of training epochs."""
if dataset is None:
epochss = []
for xdata, x_seeds in self.dataset_seed.items():
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
elif isinstance(dataset, str):
x_seeds = self.dataset_seed[dataset]
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
else:
raise ValueError('invalid dataset={:}'.format(dataset))
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
return epochss[-1]
def query(self, dataset, seed=None):
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
if seed is None:
#print(self.dataset_seed.keys())
#print(dataset)
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
else:
return self.all_results[(dataset, seed)]
def arch_idx_str(self):
return '{:06d}'.format(self.arch_index)
def update(self, dataset_name, seed, result):
if dataset_name not in self.dataset_seed:
self.dataset_seed[dataset_name] = []
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
self.dataset_seed[ dataset_name ].append( seed )
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
assert (dataset_name, seed) not in self.all_results
self.all_results[ (dataset_name, seed) ] = result
self.clear_net_done = False
def state_dict(self):
state_dict = dict()
for key, value in self.__dict__.items():
if key == 'all_results': # contain the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
xvalue[_k] = _v.state_dict()
else:
xvalue = value
state_dict[key] = xvalue
return state_dict
def load_state_dict(self, state_dict):
new_state_dict = dict()
for key, value in state_dict.items():
if key == 'all_results': # to convert to the class of ResultsCount
xvalue = dict()
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
for _k, _v in value.items():
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
else: xvalue = value
new_state_dict[key] = xvalue
self.__dict__.update(new_state_dict)
@staticmethod
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
state_dict = torch.load(state_dict_or_file, map_location='cpu')
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
x.load_state_dict(state_dict)
return x
# This function is used to clear the weights saved in each 'result'
# This can help reduce the memory footprint.
def clear_params(self):
for key, result in self.all_results.items():
del result.net_state_dict
result.net_state_dict = None
self.clear_net_done = True
def debug_test(self):
"""This function is used for me to debug and test, which will call most methods."""
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
for dataset in all_dataset:
print('---->>>> {:}'.format(dataset))
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
for seed in self.dataset_seed[dataset]:
result = self.all_results[(dataset, seed)]
print(' ==>> result = {:}'.format(result))
print(' ==>> cost = {:}'.format(result.get_times()))
def __repr__(self):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
"""
This class (ResultsCount) is used to save the information of one trial for a single architecture.
There is not much commentary for this class because it is the lowest-level class in the NAS-Bench-201 API and is rarely called directly.
If you have any question regarding this class, please open an issue or email me.
"""
class ResultsCount(object):
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
self.name = name
self.net_state_dict = state_dict
self.train_acc1es = copy.deepcopy(train_accs)
self.train_acc5es = None
self.train_losses = copy.deepcopy(train_losses)
self.train_times = None
self.arch_config = copy.deepcopy(arch_config)
self.params = params
self.flop = flop
self.seed = seed
self.epochs = epochs
self.latency = latency
# evaluation results
self.reset_eval()
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
self.train_acc1es = train_acc1es
self.train_acc5es = train_acc5es
self.train_losses = train_losses
self.train_times = train_times
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
"""Assign the training times."""
train_times = OrderedDict()
for i in range(self.epochs):
train_times[i] = estimated_per_epoch_time
self.train_times = train_times
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
"""Assign the evaluation times."""
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
for i in range(self.epochs):
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
def reset_eval(self):
self.eval_names = []
self.eval_acc1es = {}
self.eval_times = {}
self.eval_losses = {}
def update_latency(self, latency):
self.latency = copy.deepcopy( latency )
def get_latency(self) -> float:
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
if self.latency is None: return -1.0
else: return sum(self.latency) / len(self.latency)
def update_eval(self, accs, losses, times): # new version
data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names:
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
self.eval_names.append( data_name )
for iepoch in range(self.epochs):
xkey = '{:}@{:}'.format(data_name, iepoch)
self.eval_acc1es[ xkey ] = accs[ xkey ]
self.eval_losses[ xkey ] = losses[ xkey ]
self.eval_times [ xkey ] = times[ xkey ]
def update_OLD_eval(self, name, accs, losses): # old version
assert name not in self.eval_names, '{:} has already been added'.format(name)
self.eval_names.append( name )
for iepoch in range(self.epochs):
if iepoch in accs:
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
def __repr__(self):
num_eval = len(self.eval_names)
set_name = '[' + ', '.join(self.eval_names) + ']'
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
def get_total_epoch(self):
return copy.deepcopy(self.epochs)
def get_times(self):
"""Obtain the information regarding both training and evaluation time."""
if self.train_times is not None and isinstance(self.train_times, dict):
train_times = list( self.train_times.values() )
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
else:
time_info = {'T-train@epoch': None, 'T-train@total': None }
for name in self.eval_names:
try:
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
except Exception:
time_info['T-{:}@epoch'.format(name)] = None
time_info['T-{:}@total'.format(name)] = None
return time_info
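# Illustrative note: the returned keys follow 'T-<set>@epoch' / 'T-<set>@total',
# e.g. result.get_times() may look like (hypothetical values)
#   {'T-train@epoch': 10.2, 'T-train@total': 122.4, 'T-x-valid@epoch': 1.1, 'T-x-valid@total': 13.2}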
def get_eval_set(self):
return self.eval_names
# get the training information
def get_train(self, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if self.train_times is not None:
xtime = self.train_times[iepoch]
atime = sum([self.train_times[i] for i in range(iepoch+1)])
else: xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.train_losses[iepoch],
'accuracy': self.train_acc1es[iepoch],
'cur_time': xtime,
'all_time': atime}
def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
def _internal_query(xname):
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
else:
xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
'cur_time': xtime,
'all_time': atime}
if name == 'valid':
return _internal_query('x-valid')
else:
return _internal_query(name)
def get_net_param(self, clone=False):
if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
# In this case, this is to handle the size search space.
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
# In this case, this is NAS-Bench-201
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else:
# In this case, this is to handle the size search space.
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
# In this case, this is NAS-Bench-201
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
def state_dict(self):
_state_dict = {key: value for key, value in self.__dict__.items()}
return _state_dict
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
@staticmethod
def create_from_state_dict(state_dict):
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
x.load_state_dict(state_dict)
return x

360
nasspace.py Normal file

@@ -0,0 +1,360 @@
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
import itertools
import random
import numpy as np
from models.cell_searchs.genotypes import Structure
from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype
import json
import torch
class Nasbench201:
def __init__(self, dataset, apiloc):
self.dataset = dataset
self.api = API(apiloc, verbose=False)
self.epochs = '12'
def get_network(self, uid):
#config = self.api.get_net_config(uid, self.dataset)
config = self.api.get_net_config(uid, 'cifar10-valid')
config['num_classes'] = 1
network = get_cell_based_tiny_net(config)
return network
def __iter__(self):
for uid in range(len(self)):
network = self.get_network(uid)
yield uid, network
def __getitem__(self, index):
return index
def __len__(self):
return 15625
def num_activations(self):
network = self.get_network(0)
return network.classifier.in_features
#def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
# archinfo = self.api.query_meta_info_by_index(uid)
# if (self.dataset == 'cifar10' or traincifar10) and trainval:
# #return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=12)['accuracy']
# return archinfo.get_metrics('cifar10-valid', 'x-valid', iepoch=12)['accuracy']
# elif traincifar10:
# return archinfo.get_metrics('cifar10', acc_type, iepoch=12)['accuracy']
# else:
# return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
#archinfo = self.api.query_meta_info_by_index(uid)
#if (self.dataset == 'cifar10' and trainval) or traincifar10:
info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
#else:
# info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
return info['valid-accuracy']
def get_final_accuracy(self, uid, acc_type, trainval):
#archinfo = self.api.query_meta_info_by_index(uid)
if self.dataset == 'cifar10' and trainval:
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics('cifar10-valid', 'x-valid')
#info = self.api.query_by_index(uid, 'cifar10-valid', hp='200')
#info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp='200', is_random=True)
else:
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics(self.dataset, acc_type)
#info = self.api.query_by_index(uid, self.dataset, hp='200')
#info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp='200', is_random=True)
return info['accuracy']
#return info['valid-accuracy']
#if self.dataset == 'cifar10' and trainval:
# return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=11)['accuracy']
#else:
# #return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
# return archinfo.get_metrics(self.dataset, 'x-test', iepoch=11)['accuracy']
##dataset = self.dataset
##if self.dataset == 'cifar10' and trainval:
## dataset = 'cifar10-valid'
##archinfo = self.api.get_more_info(uid, dataset, iepoch=None, use_12epochs_result=True, is_random=True)
##return archinfo['valid-accuracy']
def get_accuracy(self, uid, acc_type, trainval=True):
archinfo = self.api.query_meta_info_by_index(uid)
if self.dataset == 'cifar10' and trainval:
return archinfo.get_metrics('cifar10-valid', acc_type)['accuracy']
else:
return archinfo.get_metrics(self.dataset, acc_type)['accuracy']
def get_accuracy_for_all_datasets(self, uid):
archinfo = self.api.query_meta_info_by_index(uid,hp='200')
c10 = archinfo.get_metrics('cifar10', 'ori-test')['accuracy']
c10_val = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
c100 = archinfo.get_metrics('cifar100', 'x-test')['accuracy']
c100_val = archinfo.get_metrics('cifar100', 'x-valid')['accuracy']
imagenet = archinfo.get_metrics('ImageNet16-120', 'x-test')['accuracy']
imagenet_val = archinfo.get_metrics('ImageNet16-120', 'x-valid')['accuracy']
return c10, c10_val, c100, c100_val, imagenet, imagenet_val
#def train_and_eval(self, arch, dataname, acc_type, trainval=True):
# unique_hash = self.__getitem__(arch)
# time = self.get_training_time(unique_hash)
# acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval)
# acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
# return acc12, acc, time
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
unique_hash = self.__getitem__(arch)
time = self.get_training_time(unique_hash)
acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
return acc12, acc, time
def random_arch(self):
return random.randint(0, len(self)-1)
def get_training_time(self, unique_hash):
#info = self.api.get_more_info(unique_hash, 'cifar10-valid' if self.dataset == 'cifar10' else self.dataset, iepoch=None, use_12epochs_result=True, is_random=True)
#info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True)
return info['train-all-time'] + info['valid-per-time']
#if self.dataset == 'cifar10' and trainval:
# info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
#else:
# info = self.api.get_more_info(unique_hash, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
##info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
#return info['train-all-time'] + info['valid-per-time']
def mutate_arch(self, arch):
op_names = get_search_spaces('cell', 'nas-bench-201')
#config = self.api.get_net_config(arch, self.dataset)
config = self.api.get_net_config(arch, 'cifar10-valid')
parent_arch = Structure(self.api.str2lists(config['arch_str']))
child_arch = deepcopy( parent_arch )
node_id = random.randint(0, len(child_arch.nodes)-1)
node_info = list( child_arch.nodes[node_id] )
snode_id = random.randint(0, len(node_info)-1)
xop = random.choice( op_names )
while xop == node_info[snode_id][0]:
xop = random.choice( op_names )
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple( node_info )
arch_index = self.api.query_index_by_arch( child_arch )
return arch_index
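# Usage sketch (illustrative, not part of the original file): sample and evaluate an
# architecture with the wrapper above; the benchmark file path is a placeholder.
#   searchspace = Nasbench201('cifar10', 'NAS-Bench-201-v1_0-e61699.pth')
#   uid = searchspace.random_arch()
#   net = searchspace.get_network(uid)
#   acc12, acc, train_time = searchspace.train_and_eval(uid, 'cifar10', 'x-valid', trainval=True)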
class Nasbench101:
def __init__(self, dataset, apiloc, args):
self.dataset = dataset
self.api = nasbench101api.NASBench(apiloc)
self.args = args
def get_accuracy(self, unique_hash, acc_type, trainval=True):
spec = self.get_spec(unique_hash)
_, stats = self.api.get_metrics_from_spec(spec)
maxacc = 0.
for ep in stats:
for statmap in stats[ep]:
newacc = statmap['final_test_accuracy']
if newacc > maxacc:
maxacc = newacc
return maxacc
def get_final_accuracy(self, uid, acc_type, trainval):
return self.get_accuracy(uid, acc_type, trainval)
def get_training_time(self, unique_hash):
spec = self.get_spec(unique_hash)
_, stats = self.api.get_metrics_from_spec(spec)
maxacc = -1.
maxtime = 0.
for ep in stats:
for statmap in stats[ep]:
newacc = statmap['final_test_accuracy']
if newacc > maxacc:
maxacc = newacc
maxtime = statmap['final_training_time']
return maxtime
def get_network(self, unique_hash):
spec = self.get_spec(unique_hash)
network = Network(spec, self.args)
return network
def get_spec(self, unique_hash):
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
operations = self.api.fixed_statistics[unique_hash]['module_operations']
spec = ModelSpec(matrix, operations)
return spec
def __iter__(self):
for unique_hash in self.api.hash_iterator():
network = self.get_network(unique_hash)
yield unique_hash, network
def __getitem__(self, index):
return next(itertools.islice(self.api.hash_iterator(), index, None))
def __len__(self):
return len(self.api.hash_iterator())
def num_activations(self):
for unique_hash in self.api.hash_iterator():
network = self.get_network(unique_hash)
return network.classifier.in_features
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
unique_hash = self.__getitem__(arch)
# scale the full 108-epoch NAS-Bench-101 training time to a 12-epoch equivalent
time = 12. * self.get_training_time(unique_hash) / 108.
acc = self.get_accuracy(unique_hash, acc_type, trainval)
return acc, acc, time
def random_arch(self):
return random.randint(0, len(self)-1)
def mutate_arch(self, arch):
unique_hash = self.__getitem__(arch)
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
operations = self.api.fixed_statistics[unique_hash]['module_operations']
coords = [ (i, j) for i in range(matrix.shape[0]) for j in range(i+1, matrix.shape[1])]
random.shuffle(coords)
# loop through candidate changes until we find one that's allowed
for i, j in coords:
# try the ops in a particular order
for k in [m for m in np.unique(matrix) if m != matrix[i, j]]:
newmatrix = matrix.copy()
newmatrix[i, j] = k
spec = ModelSpec(newmatrix, operations)
try:
newhash = self.api._hash_spec(spec)
if newhash in self.api.fixed_statistics:
return [n for n, m in enumerate(self.api.fixed_statistics.keys()) if m == newhash][0]
except Exception:
pass
class ReturnFeatureLayer(torch.nn.Module):
def __init__(self, mod):
super(ReturnFeatureLayer, self).__init__()
self.mod = mod
def forward(self, x):
return self.mod(x), x
def return_feature_layer(network, prefix=''):
#for attr_str in dir(network):
# target_attr = getattr(network, attr_str)
# if isinstance(target_attr, torch.nn.Linear):
# setattr(network, attr_str, ReturnFeatureLayer(target_attr))
for n, ch in list(network.named_children()):
if isinstance(ch, torch.nn.Linear):
setattr(network, n, ReturnFeatureLayer(ch))
else:
return_feature_layer(ch, prefix + '\t')
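# Illustrative note: return_feature_layer(network) replaces every torch.nn.Linear
# child with ReturnFeatureLayer, so that layer then returns (logits, input_features)
# instead of just logits. A minimal sketch:
#   layer = ReturnFeatureLayer(torch.nn.Linear(8, 2))
#   logits, feats = layer(torch.randn(4, 8))   # logits: (4, 2), feats: (4, 8)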
class NDS:
def __init__(self, searchspace):
self.searchspace = searchspace
data = json.load(open(f'nds_data/{searchspace}.json', 'r'))
try:
data = data['top'] + data['mid']
except Exception as e:
pass
self.data = data
def __iter__(self):
for unique_hash in range(len(self)):
network = self.get_network(unique_hash)
yield unique_hash, network
def get_network_config(self, uid):
return self.data[uid]['net']
def get_network_optim_config(self, uid):
return self.data[uid]['optim']
def get_network(self, uid):
netinfo = self.data[uid]
config = netinfo['net']
#print(config)
if 'genotype' in config:
#print('geno')
gen = config['genotype']
genotype = Genotype(normal=gen['normal'], normal_concat=gen['normal_concat'], reduce=gen['reduce'], reduce_concat=gen['reduce_concat'])
if '_in' in self.searchspace:
network = NetworkImageNet(config['width'], 1, config['depth'], config['aux'], genotype)
else:
network = NetworkCIFAR(config['width'], 1, config['depth'], config['aux'], genotype)
network.drop_path_prob = 0.
#print(config)
#print('genotype')
L = config['depth']
else:
if 'bot_muls' in config and 'bms' not in config:
config['bms'] = config['bot_muls']
del config['bot_muls']
if 'num_gs' in config and 'gws' not in config:
config['gws'] = config['num_gs']
del config['num_gs']
config['nc'] = 1
config['se_r'] = None
config['stem_w'] = 12
L = sum(config['ds'])
if 'ResN' in self.searchspace:
config['stem_type'] = 'res_stem_in'
else:
config['stem_type'] = 'simple_stem_in'
#"res_stem_cifar": ResStemCifar,
#"res_stem_in": ResStemIN,
#"simple_stem_in": SimpleStemIN,
if config['block_type'] == 'double_plain_block':
config['block_type'] = 'vanilla_block'
network = AnyNet(**config)
return_feature_layer(network)
return network
def __getitem__(self, index):
return index
def __len__(self):
return len(self.data)
def random_arch(self):
return random.randint(0, len(self.data)-1)
def get_final_accuracy(self, uid, acc_type, trainval):
return 100.-self.data[uid]['test_ep_top1'][-1]
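# Usage sketch (illustrative): the NDS wrapper reads nds_data/<searchspace>.json
# (e.g. nds_data/DARTS.json) and exposes the same interface as the wrappers above:
#   space = NDS('DARTS')
#   uid = space.random_arch()
#   net = space.get_network(uid)
#   acc = space.get_final_accuracy(uid, None, False)   # 100 - final top-1 test error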
def get_search_space(args):
if args.nasspace == 'nasbench201':
return Nasbench201(args.dataset, args.api_loc)
elif args.nasspace == 'nasbench101':
return Nasbench101(args.dataset, args.api_loc, args)
elif args.nasspace == 'nds_resnet':
return NDS('ResNet')
elif args.nasspace == 'nds_amoeba':
return NDS('Amoeba')
elif args.nasspace == 'nds_amoeba_in':
return NDS('Amoeba_in')
elif args.nasspace == 'nds_darts_in':
return NDS('DARTS_in')
elif args.nasspace == 'nds_darts':
return NDS('DARTS')
elif args.nasspace == 'nds_darts_fix-w-d':
return NDS('DARTS_fix-w-d')
elif args.nasspace == 'nds_darts_lr-wd':
return NDS('DARTS_lr-wd')
elif args.nasspace == 'nds_enas':
return NDS('ENAS')
elif args.nasspace == 'nds_enas_in':
return NDS('ENAS_in')
elif args.nasspace == 'nds_enas_fix-w-d':
return NDS('ENAS_fix-w-d')
elif args.nasspace == 'nds_pnas':
return NDS('PNAS')
elif args.nasspace == 'nds_pnas_fix-w-d':
return NDS('PNAS_fix-w-d')
elif args.nasspace == 'nds_pnas_in':
return NDS('PNAS_in')
elif args.nasspace == 'nds_nasnet':
return NDS('NASNet')
elif args.nasspace == 'nds_nasnet_in':
return NDS('NASNet_in')
elif args.nasspace == 'nds_resnext-a':
return NDS('ResNeXt-A')
elif args.nasspace == 'nds_resnext-a_in':
return NDS('ResNeXt-A_in')
elif args.nasspace == 'nds_resnext-b':
return NDS('ResNeXt-B')
elif args.nasspace == 'nds_resnext-b_in':
return NDS('ResNeXt-B_in')
elif args.nasspace == 'nds_vanilla':
return NDS('Vanilla')
elif args.nasspace == 'nds_vanilla_lr-wd':
return NDS('Vanilla_lr-wd')
elif args.nasspace == 'nds_vanilla_lr-wd_in':
return NDS('Vanilla_lr-wd_in')
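# Usage sketch (illustrative): get_search_space only needs an argparse-style object
# with `nasspace`, `dataset` and `api_loc` attributes (nasbench101 additionally
# needs its model flags); the values below are placeholders.
#   args = argparse.Namespace(nasspace='nds_darts', dataset='cifar10', api_loc='')
#   searchspace = get_search_space(args)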


@@ -1,144 +0,0 @@
import os
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
from datasets import get_datasets
from config_utils import load_config
from nas_201_api import NASBench201API as API
from models import get_cell_based_tiny_net
import torch
import torch.nn as nn
def get_batch_jacobian(net, data_loader, device):
data_iterator = iter(data_loader)
x, target = next(data_iterator)
x = x.to(device)
net.zero_grad()
x.requires_grad_(True)
_, y = net(x)
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
return jacob, target.detach()
def plot_hist(jacob, ax, colour):
xx = jacob.reshape(jacob.size(0), -1).cpu().numpy()
corrs = np.corrcoef(xx)
ax.hist(corrs.flatten(), bins=100, color=colour)
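# Illustrative note: plot_hist flattens each input's Jacobian to a row vector,
# computes the batch-wise correlation matrix, and histograms its entries; the same
# computation on dummy data would be
#   jacob = torch.randn(32, 3, 32, 32)                    # batch of input gradients
#   xx = jacob.reshape(jacob.size(0), -1).cpu().numpy()   # shape (32, 3072)
#   corrs = np.corrcoef(xx)                               # shape (32, 32)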
def decide_plot(acc, plt_cts, num_rows, boundaries=[60., 70., 80., 90.]):
if acc < boundaries[0]:
plt_col = 0
accrange = f'< {boundaries[0]}%'
elif acc < boundaries[1]:
plt_col = 1
accrange = f'[{boundaries[0]}% , {boundaries[1]}%)'
elif acc < boundaries[2]:
plt_col = 2
accrange = f'[{boundaries[1]}% , {boundaries[2]}%)'
elif acc < boundaries[3]:
accrange = f'[{boundaries[2]}% , {boundaries[3]}%)'
plt_col = 3
else:
accrange = f'>= {boundaries[3]}%'
plt_col = 4
can_plot = False
plt_row = 0
if plt_cts[plt_col] < num_rows:
can_plot = True
plt_row = plt_cts[plt_col]
plt_cts[plt_col] += 1
return can_plot, plt_row, plt_col, accrange
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix')
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
type=str, help='path to API')
parser.add_argument('--arch_start', default=0, type=int)
parser.add_argument('--arch_end', default=15625, type=int)
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--batch_size', default=256, type=int)
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ARCH_START = args.arch_start
ARCH_END = args.arch_end
criterion = nn.CrossEntropyLoss()
train_data, valid_data, xshape, class_num = get_datasets('cifar10', args.data_loc, 0)
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
scores = []
accs = []
plot_shape = (25, 5)
num_plots = plot_shape[0]*plot_shape[1]
fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) )
plt_cts = [0 for i in range(plot_shape[1])]
api = API(args.api_loc)
archs = list(range(ARCH_START, ARCH_END))
colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
strs = []
random.shuffle(archs)
for arch in archs:
try:
config = api.get_net_config(arch, 'cifar10')
archinfo = api.query_meta_info_by_index(arch)
acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
network = get_cell_based_tiny_net(config)
network = network.to(device)
jacobs, labels = get_batch_jacobian(network, train_loader, device)
boundaries = [60., 70., 80., 90.]
can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries)
if not can_plt:
continue
axes[row, col].axis('off')
plot_hist(jacobs, axes[row, col], colours[col])
if row == 0:
axes[row, col].set_title(f'{accrange}')
if row + 1 == plot_shape[0]:
axes[row, col].axis('on')
plt.setp(axes[row, col].get_xticklabels(), fontsize=12)
axes[row, col].spines["top"].set_visible(False)
axes[row, col].spines["right"].set_visible(False)
axes[row, col].spines["left"].set_visible(False)
axes[row, col].set_yticks([])
if sum(plt_cts) == num_plots:
plt.tight_layout()
plt.savefig(f'results/histograms_cifar10val_batch{args.batch_size}.png')
plt.show()
break
except Exception as e:
plt_cts[col] -= 1
continue

285
plot_scores.py Normal file

@@ -0,0 +1,285 @@
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mp
import matplotlib
matplotlib.use('Agg')
from decimal import Decimal
from scipy.special import logit, expit
from scipy import stats
import seaborn as sns
'''
font = {
'size' : 18}
matplotlib.rc('font', **font)
'''
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14
plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
plt.rc('axes', titlesize=BIGGER_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--init', default='', type=str)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dropout', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
args = parser.parse_args()
print(f'{args.batch_size}')
random.seed(args.seed)
np.random.seed(args.seed)
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.npy'
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}.npy'
from matplotlib.colors import hsv_to_rgb
print(filename)
scores = np.load(filename)
accs = np.load(accfilename)
def make_colours_by_hue(h, v=1.):
return [hsv_to_rgb((h1 if h1 < 1. else h1-1., s, v)) for h1, s,v in zip(np.linspace(h, h+0.05, 5), np.linspace(1., .6, 5), np.linspace(0.1, 1., 5))]
print(f'NETWORK accuracy with highest score {accs[np.argmax(scores)]}')
make_colours = lambda cols: [mp.colors.to_rgba(c) for c in cols]
oranges = make_colours(['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B'])
blues = make_colours(['#190C30', '#241147', '#34208C', '#4882FA', '#81BAFC'])
print(blues)
print(make_colours_by_hue(0.9))
if args.nasspace == 'nasbench101':
#colours = blues
colours = make_colours_by_hue(0.9)
elif 'darts' in args.nasspace:
#colours = sns.color_palette("BuGn_r", n_colors=5)
colours = make_colours_by_hue(0.0)
elif 'pnas' in args.nasspace:
#colours = sns.color_palette("PuRd", n_colors=5)
colours = make_colours_by_hue(0.1)
elif args.nasspace == 'nasbench201':
#colours = oranges
colours = make_colours_by_hue(0.3)
elif 'enas' in args.nasspace:
#colours = oranges
colours = make_colours_by_hue(0.4)
elif 'resnet' in args.nasspace:
#colours = sns.color_palette("viridis_r", n_colors=5)
colours = make_colours_by_hue(0.5)
elif 'amoeba' in args.nasspace:
#colours = sns.color_palette("viridis_r", n_colors=5)
colours = make_colours_by_hue(0.6)
elif 'nasnet' in args.nasspace:
#colours = sns.color_palette("viridis_r", n_colors=5)
colours = make_colours_by_hue(0.7)
elif 'resnext-b' in args.nasspace:
#colours = sns.color_palette("viridis_r", n_colors=5)
colours = make_colours_by_hue(0.8)
else:
from zlib import crc32
def bytes_to_float(b):
return float(crc32(b) & 0xffffffff) / 2**32
def str_to_float(s, encoding="utf-8"):
return bytes_to_float(s.encode(encoding))
#colours = sns.color_palette("Purples_r", n_colors=5)
colours = make_colours_by_hue(str_to_float(args.nasspace))
def make_colordict(colours, points):
cdict = {'red': [[pt, colour[0], colour[0]] for pt, colour in zip(points, colours)],
'green':[[pt, colour[1], colour[1]] for pt, colour in zip(points, colours)],
'blue':[[pt, colour[2], colour[2]] for pt, colour in zip(points, colours)]}
return cdict
def make_colormap(dataset, space, colours):
if dataset == 'cifar10' and 'resn' in space:
points = [0., 0.85, 0.9, 0.95, 1.0, 1.0]
colours = [colours[0]] + colours
elif dataset == 'cifar10' and 'nds_darts' in space:
points = [0., 0.8, 0.85, 0.9, 0.95, 1.0]
colours = [colours[0]] + colours
elif dataset == 'cifar10' and 'pnas' in space:
points = [0., 0.875, 0.9, 0.925, 0.95, 1.0]
colours = [colours[0]] + colours
elif dataset == 'cifar10':
points = [0., 0.6, 0.7, 0.8, 0.9, 1.0]
colours = [colours[0]] + colours
#cdict = {'red': [[0., colours[0][0], colours[0][0]]] + [[0.1*i + 0.6, colours[i][0], colours[i][0]] for i in range(len(colours))],
# 'green':[[0., colours[0][1], colours[0][1]]] + [[0.1*i + 0.6, colours[i][1], colours[i][1]] for i in range(len(colours))],
# 'blue':[[0., colours[0][2], colours[0][2]]] + [[0.1*i + 0.6, colours[i][2], colours[i][2]] for i in range(len(colours))]}
elif dataset == 'cifar100':
points = [0., 0.3, 0.4, 0.5, 0.6, 0.7, 1.0]
colours = [colours[0]] + colours + [colours[-1]]
#cdict = {'red': [[0., colours[0][0], colours[0][0]]] + [[0.1*i + 0.3, colours[i][0], colours[i][0]] for i in range(len(colours))] + [[1., colours[-1][0], colours[-1][0]]] ,
# 'green':[[0., colours[0][1], colours[0][1]]] + [[0.1*i + 0.3, colours[i][1], colours[i][1]] for i in range(len(colours))] + [[1., colours[-1][1], colours[-1][1]]] ,
# 'blue':[[0., colours[0][2], colours[0][2]]] + [[0.1*i + 0.3, colours[i][2], colours[i][2]] for i in range(len(colours))] + [[1., colours[-1][2], colours[-1][2]]] }
else:
points = [0., 0.1, 0.2, 0.3, 0.4, 1.0]
colours = colours + [colours[-1]]
#cdict = {'red': [[0.1*i, colours[i][0], colours[i][0]] for i in range(len(colours))] + [[1., colours[-1][0], colours[-1][0]]] ,
# 'green': [[0.1*i, colours[i][1], colours[i][1]] for i in range(len(colours))] + [[1., colours[-1][1], colours[-1][1]]] ,
# 'blue': [[0.1*i, colours[i][2], colours[i][2]] for i in range(len(colours))] + [[1., colours[-1][2], colours[-1][2]]] }
cdict = make_colordict(colours, points)
return cdict
cdict = make_colormap(args.dataset, args.nasspace, colours)
newcmp = mp.colors.LinearSegmentedColormap('testCmap', segmentdata=cdict, N=256)
if args.nasspace == 'nasbench101':
accs = accs[:10000]
scores = scores[:10000]
inds = accs > 0.5
accs = accs[inds]
scores = scores[inds]
print(accs.shape)
elif args.nasspace == 'nds_amoeba' or args.nasspace == 'nds_darts_fix-w-d':
print(accs.shape)
inds = accs > 15.
accs = accs[inds]
scores = scores[inds]
print(accs.shape)
elif args.nasspace == 'nds_darts':
inds = accs > 15.
from nasspace import get_search_space
searchspace = get_search_space(args)
accs = accs[inds]
scores = scores[inds]
print(accs.shape)
else:
print(accs.shape)
inds = accs > 15.
accs = accs[inds]
scores = scores[inds]
print(accs.shape)
inds = scores == 0.
accs = accs[~inds]
scores = scores[~inds]
if accs.size > 1000:
inds = np.random.choice(accs.size, 1000, replace=False)
accs = accs[inds]
scores = scores[inds]
inds = np.isnan(scores)
accs = accs[~inds]
scores = scores[~inds]
tau, p = stats.kendalltau(accs, scores)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
def scale(x):
return 2.**(10*x) - 1.
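# Illustrative note: scale() warps the accuracy axis so the crowded high-accuracy
# region is stretched, e.g. scale(0.5) = 2**5 - 1 = 31 while scale(0.95) = 2**9.5 - 1,
# approximately 723; the tick labels below plot the raw accuracy a at position scale(a/100).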
if args.score == 'svd':
score_scale = lambda x: 10.0**x
else:
score_scale = lambda x: x
if args.nasspace == 'nonetwork':
ax.scatter(scale(accs/100.), score_scale(scores), c=newcmp(accs/100., depths))
else:
ax.scatter(scale(accs/100. if args.nasspace == 'nasbench201' or 'nds' in args.nasspace else accs), score_scale(scores), c=newcmp(accs/100. if args.nasspace == 'nasbench201' or 'nds' in args.nasspace else accs))
if args.dataset == 'cifar100':
ax.set_xticks([scale(float(a)/100.) for a in [40, 60, 70]])
ax.set_xticklabels([f'{a}' for a in [40, 60, 70]])
elif args.dataset == 'imagenette2':
ax.set_xticks([scale(float(a)/100.) for a in [40, 50, 60, 70]])
ax.set_xticklabels([f'{a}' for a in [40, 50, 60, 70]])
elif args.dataset == 'ImageNet16-120':
ax.set_xticks([scale(float(a)/100.) for a in [20, 30, 40, 45]])
ax.set_xticklabels([f'{a}' for a in [20, 30, 40, 45]])
elif args.nasspace == 'nasbench101' and args.dataset == 'cifar10':
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90, 95]])
ax.set_xticklabels([f'{a}' for a in [50, 80, 90, 95]])
elif args.nasspace == 'nasbench201' and args.dataset == 'cifar10' and args.score == 'svd':
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90, 95]])
ax.set_xticklabels([f'{a}' for a in [50, 80, 90, 95]])
elif 'nds_resne' in args.nasspace and args.dataset == 'cifar10':
ax.set_xticks([scale(float(a)/100.) for a in [85, 88, 91, 94]])
ax.set_xticklabels([f'{a}' for a in [85, 88, 91, 94]])
elif args.nasspace == 'nds_darts' and args.dataset == 'cifar10':
ax.set_xticks([scale(float(a)/100.) for a in [80, 85, 90, 95]])
ax.set_xticklabels([f'{a}' for a in [80, 85, 90, 95]])
elif args.nasspace == 'nds_pnas' and args.dataset == 'cifar10':
ax.set_xticks([scale(float(a)/100.) for a in [90., 91.5, 93, 94.5]])
ax.set_xticklabels([f'{a}' for a in [90., 91.5, 93, 94.5]])
else:
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90]])
ax.set_xticklabels([f'{a}' for a in [50, 80, 90]])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
nasspacenames = {
'nds_resnext-a_in': 'NDS-ResNeXt-A(ImageNet)',
'nds_resnext-b_in': 'NDS-ResNeXt-B(ImageNet)',
'nds_resnext-a': 'NDS-ResNeXt-A(CIFAR10)',
'nds_resnext-b': 'NDS-ResNeXt-B(CIFAR10)',
'nds_nasnet': 'NDS-NASNet(CIFAR10)',
'nds_nasnet_in': 'NDS-NASNet(ImageNet)',
'nds_enas': 'NDS-ENAS(CIFAR10)',
'nds_enas_in': 'NDS-ENAS(ImageNet)',
'nds_amoeba': 'NDS-Amoeba(CIFAR10)',
'nds_amoeba_in': 'NDS-Amoeba(ImageNet)',
'nds_resnet': 'NDS-ResNet(CIFAR10)',
'nds_pnas': 'NDS-PNAS(CIFAR10)',
'nds_pnas_in': 'NDS-PNAS(ImageNet)',
'nds_darts': 'NDS-DARTS(CIFAR10)',
'nds_darts_in': 'NDS-DARTS(ImageNet)',
'nds_darts_fix-w-d': 'NDS-DARTS fixed width/depth (CIFAR10)',
'nds_darts_in_fix-w-d': 'NDS-DARTS fixed width/depth (ImageNet)',
'nasbench101': 'NAS-Bench-101',
'nasbench201': 'NAS-Bench-201'
}
ax.set_ylabel('Score')
ax.set_xlabel(f'{"Test" if not args.trainval else "Validation"} accuracy')
ax.set_title(f'{nasspacenames[args.nasspace]} {args.dataset} \n $\\tau=${tau:.3f}')
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}{"_" + args.init + "_" if args.init != "" else args.init}{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}'
print(filename)
plt.tight_layout()
plt.savefig(filename + '.pdf')
plt.savefig(filename + '.png')
plt.show()


@@ -1,87 +0,0 @@
import numpy as np
import argparse
import os
import random
import pandas as pd
from collections import OrderedDict
import tabulate
parser = argparse.ArgumentParser(description='Produce tables')
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--n_runs', default=500, type=int)
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
from statistics import mean, median, stdev as std
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
df = []
datasets = OrderedDict()
datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True)
datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False)
### CIFAR-100
datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False)
datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False)
datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False)
datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False)
dataset_top1s = OrderedDict()
for n_samples in [10, 100]:
method = f"Ours (N={n_samples})"
time = 0.
for dataset, params in datasets.items():
top1s = []
dset = params[0]
acc_type = 'accs' if 'test' in params[1] else 'val_accs'
filename = f"{args.save_loc}/{dset}_{args.n_runs}_{n_samples}_{args.seed}.t7"
full_scores = torch.load(filename)
if dataset == 'CIFAR-10 (test)':
time = median(full_scores['times'])
time = f"{time:.2f}"
accs = []
for n in range(args.n_runs):
acc = full_scores[acc_type][n]
accs.append(acc)
dataset_top1s[dataset] = accs
cifar10_val = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} +- {std(dataset_top1s['CIFAR-10 (val)']):.2f}"
cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} +- {std(dataset_top1s['CIFAR-10 (test)']):.2f}"
cifar100_val = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} +- {std(dataset_top1s['CIFAR-100 (val)']):.2f}"
cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} +- {std(dataset_top1s['CIFAR-100 (test)']):.2f}"
imagenet_val = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (val)']):.2f}"
imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (test)']):.2f}"
df.append([method, time, cifar10_val, cifar10_test, cifar100_val, cifar100_test, imagenet_val, imagenet_test])
df = pd.DataFrame(df, columns=['Method','Search time (s)','CIFAR-10 (val)','CIFAR-10 (test)','CIFAR-100 (val)','CIFAR-100 (test)','ImageNet16-120 (val)','ImageNet16-120 (test)' ])
print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe"))

0
pycls/core/__init__.py Normal file

136
pycls/core/benchmark.py Normal file

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmarking functions."""
import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
@torch.no_grad()
def compute_time_eval(model):
"""Computes precise model forward test time using dummy data."""
# Use eval mode
model.eval()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
# Compute precise forward pass time
timer = Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
# Forward
timer.tic()
model(inputs)
torch.cuda.synchronize()
timer.toc()
return timer.average_time
def compute_time_train(model, loss_fun):
"""Computes precise model forward + backward time using dummy data."""
# Use train mode
model.train()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
if cfg.TASK in ['col', 'seg']:
labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
else:
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
# Cache BatchNorm2D running stats
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
# Compute precise forward backward pass time
fw_timer, bw_timer = Timer(), Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
fw_timer.reset()
bw_timer.reset()
# Forward
fw_timer.tic()
preds = model(inputs)
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
torch.cuda.synchronize()
bw_timer.toc()
# Restore BatchNorm2D running stats
for bn, (mean, var) in zip(bns, bn_stats):
bn.running_mean, bn.running_var = mean, var
return fw_timer.average_time, bw_timer.average_time
def compute_time_loader(data_loader):
"""Computes loader time."""
timer = Timer()
loader.shuffle(data_loader, 0)
data_loader_iterator = iter(data_loader)
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
total_iter = min(total_iter, len(data_loader))
for cur_iter in range(total_iter):
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
timer.tic()
next(data_loader_iterator)
timer.toc()
return timer.average_time
def compute_time_full(model, loss_fun, train_loader, test_loader):
"""Times model and data loader."""
logger.info("Computing model and loader timings...")
# Compute timings
test_fw_time = compute_time_eval(model)
train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
train_fw_bw_time = train_fw_time + train_bw_time
train_loader_time = compute_time_loader(train_loader)
# Output iter timing
iter_times = {
"test_fw_time": test_fw_time,
"train_fw_time": train_fw_time,
"train_bw_time": train_bw_time,
"train_fw_bw_time": train_fw_bw_time,
"train_loader_time": train_loader_time,
}
logger.info(logging.dump_log_data(iter_times, "iter_times"))
# Output epoch timing
epoch_times = {
"test_fw_time": test_fw_time * len(test_loader),
"train_fw_time": train_fw_time * len(train_loader),
"train_bw_time": train_bw_time * len(train_loader),
"train_fw_bw_time": train_fw_bw_time * len(train_loader),
"train_loader_time": train_loader_time * len(train_loader),
}
logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
# Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))

88
pycls/core/builders.py Normal file

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Model and loss construction functions."""
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
"""CrossEntropyLoss with label smoothing."""
def __init__(self):
super(LabelSmoothedCrossEntropyLoss, self).__init__()
self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
self.num_classes = cfg.MODEL.NUM_CLASSES
def forward(self, logits, target):
pred = logits.log_softmax(dim=-1)
with torch.no_grad():
target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
return (-target_dist * pred).sum(dim=-1).mean()
# Supported models
_models = {
"anynet": AnyNet,
"effnet": EffNet,
"resnet": ResNet,
"regnet": RegNet,
"nas": NAS,
"nas_search": NAS_Search,
"nas_bench": NAS_Bench,
}
# Supported loss functions
_loss_funs = {
"cross_entropy": torch.nn.CrossEntropyLoss,
"label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}
def get_model():
"""Gets the model class specified in the config."""
err_str = "Model type '{}' not supported"
assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
return _models[cfg.MODEL.TYPE]
def get_loss_fun():
"""Gets the loss function class specified in the config."""
err_str = "Loss function type '{}' not supported"
assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
return _loss_funs[cfg.MODEL.LOSS_FUN]
def build_model():
"""Builds the model."""
return get_model()()
def build_loss_fun():
"""Build the loss function."""
if cfg.TASK == "seg":
return get_loss_fun()(ignore_index=255)
else:
return get_loss_fun()()
def register_model(name, ctor):
"""Registers a model dynamically."""
_models[name] = ctor
def register_loss_fun(name, ctor):
"""Registers a loss function dynamically."""
_loss_funs[name] = ctor
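The registries above keep the model and loss pluggable. As a sketch of how an external model could be registered and then built purely from the config (the `TinyNet` class below is a made-up stand-in, not part of this commit):

```python
import torch
import pycls.core.builders as builders
from pycls.core.config import cfg

class TinyNet(torch.nn.Module):
    """Toy model used only to illustrate the registry."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(32, cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        return self.fc(x)

builders.register_model("tiny", TinyNet)            # extends _models
cfg.MODEL.TYPE = "tiny"
cfg.MODEL.LOSS_FUN = "label_smoothed_cross_entropy"
cfg.MODEL.LABEL_SMOOTHING_EPS = 0.1

model = builders.build_model()      # -> TinyNet()
loss_fun = builders.build_loss_fun()  # -> LabelSmoothedCrossEntropyLoss()
```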

98
pycls/core/checkpoint.py Normal file

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""
import os
import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"
def get_checkpoint_dir():
"""Retrieves the location for storing checkpoints."""
return os.path.join(cfg.OUT_DIR, _DIR_NAME)
def get_checkpoint(epoch):
"""Retrieves the path to a checkpoint file."""
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
return os.path.join(get_checkpoint_dir(), name)
def get_last_checkpoint():
"""Retrieves the most recent checkpoint (highest epoch number)."""
checkpoint_dir = get_checkpoint_dir()
# Checkpoint file names are in lexicographic order
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
last_checkpoint_name = sorted(checkpoints)[-1]
return os.path.join(checkpoint_dir, last_checkpoint_name)
def has_checkpoint():
"""Determines if there are checkpoints available."""
checkpoint_dir = get_checkpoint_dir()
if not os.path.exists(checkpoint_dir):
return False
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
def save_checkpoint(model, optimizer, epoch):
"""Saves a checkpoint."""
# Save checkpoints only from the master process
if not dist.is_master_proc():
return
# Ensure that the checkpoint dir exists
os.makedirs(get_checkpoint_dir(), exist_ok=True)
# Omit the DDP wrapper in the multi-gpu setting
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state
if isinstance(optimizer, list):
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_w_state": optimizer[0].state_dict(),
"optimizer_a_state": optimizer[1].state_dict(),
"cfg": cfg.dump(),
}
else:
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
torch.save(checkpoint, checkpoint_file)
return checkpoint_file
def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
err_str = "Checkpoint '{}' not found"
assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
# Load the checkpoint on CPU to avoid GPU mem spike
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# Account for the DDP wrapper in the multi-gpu setting
ms = model.module if cfg.NUM_GPUS > 1 else model
ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
if isinstance(optimizer, list):
optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
else:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"]

500
pycls/core/config.py Normal file

@@ -0,0 +1,500 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Configuration file (powered by YACS)."""
import argparse
import os
import sys
from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode
# Global config object
_C = CfgNode()
# Example usage:
# from core.config import cfg
cfg = _C
# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()
# Model type
_C.MODEL.TYPE = ""
# Number of weight layers
_C.MODEL.DEPTH = 0
# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3
# Number of classes
_C.MODEL.NUM_CLASSES = 10
# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"
# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0
# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256
# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]
# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()
# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"
# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True
# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()
# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.ANYNET.STEM_W = 32
# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []
# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []
# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []
# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []
# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False
# SE ratio
_C.ANYNET.SE_R = 0.25
# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()
# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25
# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()
# Stem width
_C.EN.STEM_W = 32
# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []
# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []
# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25
# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []
# Kernel sizes for each stage
_C.EN.KERNELS = []
# Head width
_C.EN.HEAD_W = 1280
# Drop connect ratio
_C.EN.DC_RATIO = 0.0
# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0
# ---------------------------------------------------------------------------- #
# NAS options
# ---------------------------------------------------------------------------- #
_C.NAS = CfgNode()
# Cell genotype
_C.NAS.GENOTYPE = 'nas'
# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []
# Base NAS width
_C.NAS.WIDTH = 16
# Total number of cells
_C.NAS.DEPTH = 20
# Auxiliary heads
_C.NAS.AUX = False
# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4
# Drop path probability
_C.NAS.DROP_PROB = 0.0
# Matrix in NAS Bench
_C.NAS.MATRIX = []
# Operations in NAS Bench
_C.NAS.OPS = []
# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3
# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3
# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()
# BN epsilon
_C.BN.EPS = 1e-5
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1
# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024
# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False
# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()
# Base learning rate
_C.OPTIM.BASE_LR = 0.1
# Learning rate policy (select from 'cos', 'exp', 'steps')
_C.OPTIM.LR_POLICY = "cos"
# Exponential decay factor
_C.OPTIM.GAMMA = 0.1
# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []
# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1
# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200
# Momentum
_C.OPTIM.MOMENTUM = 0.9
# Momentum dampening
_C.OPTIM.DAMPENING = 0.0
# Nesterov momentum
_C.OPTIM.NESTEROV = True
# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0
# Update the learning rate per iter
_C.OPTIM.ITER_LR = False
# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003
# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001
# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = 'adam'
# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0
# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()
# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"
# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128
# Image size
_C.TRAIN.IM_SIZE = 224
# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1
# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1
# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True
# Weights to start training from
_C.TRAIN.WEIGHTS = ""
# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0
# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0
# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()
# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"
# Total mini-batch size
_C.TEST.BATCH_SIZE = 200
# Image size
_C.TEST.IM_SIZE = 256
# Weights to use for testing
_C.TEST.WEIGHTS = ""
# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()
# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8
# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True
# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()
# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True
# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True
# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()
# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3
# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30
# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"
# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3
# Output directory
_C.OUT_DIR = "/tmp"
# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"
# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1
# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"
# Log period in iters
_C.LOG_PERIOD = 10
# Distributed backend
_C.DIST_BACKEND = "nccl"
# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001
# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")
def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
err_str = "The first lr step must start at 0"
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
data_splits = ["train", "val", "test"]
err_str = "Data split '{}' not supported"
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
err_str = "Precise BN stats computation not verified for > 1 GPU"
assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
err_str = "Log destination '{}' not supported"
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
if cache_urls:
cache_cfg_urls()
def cache_cfg_urls():
"""Download URLs in config, cache them, and rewrite cfg to use cached file."""
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
_C.dump(stream=f)
def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
def load_cfg_fom_args(description="Config file options."):
"""Load config from command line arguments and set any specified options."""
parser = argparse.ArgumentParser(description=description)
help_s = "Config file location"
parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
help_s = "See pycls/core/config.py for all options"
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
args = parser.parse_args()
_C.merge_from_file(args.cfg_file)
_C.merge_from_list(args.opts)
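Overrides follow the usual YACS pattern: the YAML file is merged first, then any `KEY VALUE` pairs given after the config path take precedence. A short sketch of the programmatic equivalent (the values are arbitrary examples):

```python
from pycls.core.config import cfg

# Equivalent to appending `OPTIM.BASE_LR 0.05 NAS.AUX True` after --cfg on the
# command line; merge_from_list type-checks each value against the defaults above.
cfg.merge_from_list(["OPTIM.BASE_LR", 0.05, "NAS.AUX", True])
assert cfg.OPTIM.BASE_LR == 0.05 and cfg.NAS.AUX
```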

157
pycls/core/distributed.py Normal file

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed helpers."""
import multiprocessing
import os
import signal
import threading
import traceback
import torch
from pycls.core.config import cfg
def is_master_proc():
"""Determines if the current process is the master process.
Master process is responsible for logging, writing and loading checkpoints. In
the multi GPU setting, we assign the master role to the rank 0 process. When
training using a single GPU, there is a single process which is considered master.
"""
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
def init_process_group(proc_rank, world_size):
"""Initializes the default process group."""
# Set the GPU to use
torch.cuda.set_device(proc_rank)
# Initialize the process group
torch.distributed.init_process_group(
backend=cfg.DIST_BACKEND,
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
world_size=world_size,
rank=proc_rank,
)
def destroy_process_group():
"""Destroys the default process group."""
torch.distributed.destroy_process_group()
def scaled_all_reduce(tensors):
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of the
process group (equivalent to cfg.NUM_GPUS).
"""
# There is no need for reduction in the single-proc case
if cfg.NUM_GPUS == 1:
return tensors
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / cfg.NUM_GPUS)
return tensors
class ChildException(Exception):
"""Wraps an exception from a child process."""
def __init__(self, child_trace):
super(ChildException, self).__init__(child_trace)
class ErrorHandler(object):
"""Multiprocessing error handler (based on fairseq's).
Listens for errors in child processes and propagates the tracebacks to the parent.
"""
def __init__(self, error_queue):
# Shared error queue
self.error_queue = error_queue
# Children processes sharing the error queue
self.children_pids = []
# Start a thread listening to errors
self.error_listener = threading.Thread(target=self.listen, daemon=True)
self.error_listener.start()
# Register the signal handler
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
"""Registers a child process."""
self.children_pids.append(pid)
def listen(self):
"""Listens for errors in the error queue."""
# Wait until there is an error in the queue
child_trace = self.error_queue.get()
# Put the error back for the signal handler
self.error_queue.put(child_trace)
# Invoke the signal handler
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, _sig_num, _stack_frame):
"""Signal handler."""
# Kill children processes
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)
# Propagate the error from the child process
raise ChildException(self.error_queue.get())
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
"""Runs a function from a child process."""
try:
# Initialize the process group
init_process_group(proc_rank, world_size)
# Run the function
fun(*fun_args, **fun_kwargs)
except KeyboardInterrupt:
# Killed by the parent process
pass
except Exception:
# Propagate exception to the parent process
error_queue.put(traceback.format_exc())
finally:
# Destroy the process group
destroy_process_group()
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
"""Runs a function in a multi-proc setting (unless num_proc == 1)."""
# There is no need for multi-proc in the single-proc case
fun_kwargs = fun_kwargs if fun_kwargs else {}
if num_proc == 1:
fun(*fun_args, **fun_kwargs)
return
# Handle errors from training subprocesses
error_queue = multiprocessing.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Run each training subprocess
ps = []
for i in range(num_proc):
p_i = multiprocessing.Process(
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
)
ps.append(p_i)
p_i.start()
error_handler.add_child(p_i.pid)
# Wait for each subprocess to finish
for p in ps:
p.join()
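A minimal sketch of the launch path (assuming a CUDA machine with `cfg.NUM_GPUS` devices; with a single GPU the function is simply called in the current process):

```python
import pycls.core.distributed as dist
from pycls.core.config import cfg

def train_fun():
    # Only the rank-0 process is the master; it alone logs and writes checkpoints.
    print("is master:", dist.is_master_proc())

# Spawns cfg.NUM_GPUS child processes, each of which calls init_process_group
# with its own rank before running train_fun.
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_fun)
```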

77
pycls/core/io.py Normal file

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IO utilities (adapted from Detectron)"""
import logging
import os
import re
import sys
from urllib import request as urlrequest
logger = logging.getLogger(__name__)
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the path to
the cached file. If the argument is not a URL, simply return it as is.
"""
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
if not is_url:
return url_or_file
url = url_or_file
err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
return cache_file_path
cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path
def _progress_bar(count, total):
"""Report download progress. Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
percents = round(100.0 * count / float(total), 1)
bar = "=" * filled_len + "-" * (bar_len - filled_len)
sys.stdout.write(
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
)
sys.stdout.flush()
if count >= total:
sys.stdout.write("\n")
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path. Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
req = urlrequest.Request(url)
response = urlrequest.urlopen(req)
total_size = response.info().get("Content-Length").strip()
total_size = int(total_size)
bytes_so_far = 0
with open(dst_file_path, "wb") as f:
while 1:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)
return bytes_so_far
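`cache_url` is a no-op for local paths and only downloads files hosted under the pycls S3 bucket. A short illustration (the paths here are made up):

```python
from pycls.core.io import cache_url

# A local path is returned unchanged.
path = cache_url("/data/my_weights.pyth", "/tmp/pycls-download-cache")

# A URL under https://dl.fbaipublicfiles.com/pycls would instead be downloaded
# once into the cache dir and the cached file path returned.
```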

138
pycls/core/logging.py Normal file

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""
import builtins
import decimal
import logging
import os
import sys
import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg
# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"
# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "
# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"
def _suppress_print():
"""Suppresses printing from the current process."""
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass
builtins.print = ignore
def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if dist.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()
def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)
def dump_log_data(data, data_type, prec=4):
"""Covert data (a dictionary) into tagged json string for logging."""
data[_TYPE] = data_type
data = float_to_decimal(data, prec)
data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
return "{:s}{:s}".format(_TAG, data_json)
def float_to_decimal(data, prec=4):
"""Convert floats to decimals which allows for fixed width json."""
if isinstance(data, dict):
return {k: float_to_decimal(v, prec) for k, v in data.items()}
if isinstance(data, float):
return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
else:
return data
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, log_file) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps) if f_n_ps else ([], [])
return files, names
def load_log_data(log_file, data_types_to_skip=()):
"""Loads log data into a dictionary of the form data[data_type][metric][index]."""
# Load log_file
assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
with open(log_file, "r") as f:
lines = f.readlines()
# Extract and parse lines that start with _TAG and have a type specified
lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
lines = [simplejson.loads(l) for l in lines]
lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
# Generate data structure accessed by data[data_type][index][metric]
data_types = [l[_TYPE] for l in lines]
data = {t: [] for t in data_types}
for t, line in zip(data_types, lines):
del line[_TYPE]
data[t].append(line)
# Generate data structure accessed by data[data_type][metric][index]
for t in data:
metrics = sorted(data[t][0].keys())
err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
data[t] = {m: [d[m] for d in data[t]] for m in metrics}
return data
def sort_log_data(data):
"""Sort each data[data_type][metric] by epoch or keep only first instance."""
for t in data:
if "epoch" in data[t]:
assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
epoch = data[t]["epoch_ind"]
if "iter" in data[t]:
assert "iter_ind" not in data[t] and "iter_max" not in data[t]
data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
for m in data[t]:
data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
else:
data[t] = {m: d[0] for m, d in data[t].items()}
return data
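Stats flow through `dump_log_data` as tagged JSON lines and can be read back with `load_log_data`/`sort_log_data`. A small sketch (the log path shown in the comment is hypothetical):

```python
import pycls.core.logging as logging

logging.setup_logging()
logger = logging.get_logger(__name__)

# Emits something like:
#   json_stats: {"_type": "test_epoch", "epoch": "1/10", "top1_err": 12.3456}
logger.info(logging.dump_log_data({"epoch": "1/10", "top1_err": 12.3456}, "test_epoch"))

# With cfg.LOG_DEST = "file", those lines can later be parsed back into
# data["test_epoch"]["top1_err"] (a list indexed by entry), e.g.:
# data = logging.sort_log_data(logging.load_log_data("<OUT_DIR>/stdout.log"))
```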

435
pycls/core/meters.py Normal file

@@ -0,0 +1,435 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Meters."""
from collections import deque
import numpy as np
import pycls.core.logging as logging
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
def time_string(seconds):
"""Converts time in seconds to a fixed-width string format."""
days, rem = divmod(int(seconds), 24 * 3600)
hrs, rem = divmod(rem, 3600)
mins, secs = divmod(rem, 60)
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
def inter_union(preds, labels, num_classes):
_, preds = torch.max(preds, 1)
preds = preds.type(torch.uint8) + 1
labels = labels.type(torch.uint8) + 1
preds = preds * (labels > 0).type(torch.uint8)
inter = preds * (preds == labels).type(torch.uint8)
area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_union = area_preds + area_labels - area_inter
return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]
def topk_errors(preds, labels, ks):
"""Computes the top-k error for each k."""
err_str = "Batch dim of predictions and labels must match"
assert preds.size(0) == labels.size(0), err_str
# Find the top max_k predictions for each sample
_top_max_k_vals, top_max_k_inds = torch.topk(
preds, max(ks), dim=1, largest=True, sorted=True
)
# (batch_size, max_k) -> (max_k, batch_size)
top_max_k_inds = top_max_k_inds.t()
# (batch_size, ) -> (max_k, batch_size)
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
# Compute the number of topk correct predictions for each k
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / 1024 / 1024
class ScalarMeter(object):
"""Measures a scalar value (adapted from Detectron)."""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def reset(self):
self.deque.clear()
self.total = 0.0
self.count = 0
def add_value(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def get_win_median(self):
return np.median(self.deque)
def get_win_avg(self):
return np.mean(self.deque)
def get_global_avg(self):
return self.total / self.count
class TrainMeter(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
# Current minibatch stats
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"top1_err": top1_err,
"top5_err": top5_err,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeter(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Min errors (over the full test set)
self.min_top1_err = 100.0
self.min_top5_err = 100.0
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.min_top1_err = 100.0
self.min_top5_err = 100.0
self.iter_timer.reset()
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, mb_size):
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
self.min_top1_err = min(self.min_top1_err, top1_err)
self.min_top5_err = min(self.min_top5_err, top5_err)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"top1_err": top1_err,
"top5_err": top5_err,
"min_top1_err": self.min_top1_err,
"min_top5_err": self.min_top5_err,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
class TrainMeterIoU(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, loss, lr, mb_size):
# Current minibatch stats
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"miou": self.mb_miou.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"miou": miou,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeterIoU(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.max_miou = 0.0
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.max_miou = 0.0
self.iter_timer.reset()
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, mb_size):
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"miou": self.mb_miou.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
self.max_miou = max(self.max_miou, miou)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"miou": miou,
"max_miou": self.max_miou,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))

129
pycls/core/net.py Normal file

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for manipulating networks."""
import itertools
import math
import torch
import torch.nn as nn
from pycls.core.config import cfg
def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()
@torch.no_grad()
def compute_precise_bn_stats(model, loader):
"""Computes precise BN stats on training data."""
# Compute the number of minibatches to use
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize stats storage
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
moms = [bn.momentum for bn in bns]
# Disable momentum
for bn in bns:
bn.momentum = 1.0
# Accumulate the stats across the data samples
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
# Accumulate the stats for each BN layer
for i, bn in enumerate(bns):
m, v = bn.running_mean, bn.running_var
sqs[i] += (v + m * m) / num_iter
mus[i] += m / num_iter
# Set the stats and restore momentum values
for i, bn in enumerate(bns):
bn.running_var = sqs[i] - mus[i] * mus[i]
bn.running_mean = mus[i]
bn.momentum = moms[i]
def reset_bn_stats(model):
"""Resets running BN stats."""
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.reset_running_stats()
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
"""Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
flops += k * k * w_in * w_out * h * w // groups
params += k * k * w_in * w_out // groups
flops += w_out if bias else 0
params += w_out if bias else 0
acts += w_out * h * w
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_batchnorm2d(cx, w_in):
"""Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
params += 2 * w_in
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_maxpool2d(cx, k, stride, padding):
"""Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity(model):
"""Compute model complexity (model can be model instance or model class)."""
size = cfg.TRAIN.IM_SIZE
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
cx = model.complexity(cx)
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
def drop_connect(x, drop_ratio):
"""Drop connect (adapted from DARTS)."""
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x
def get_flat_weights(model):
"""Gets all model weights as a single flat vector."""
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
def set_flat_weights(model, flat_weights):
"""Sets all model weights from a single flat vector."""
k = 0
for p in model.parameters():
n = p.data.numel()
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
k += n
assert k == flat_weights.numel()
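The complexity helpers thread an accumulator dict through the network. A worked example for a single 3x3 convolution followed by a BatchNorm (the numbers follow directly from the formulas above):

```python
from pycls.core.net import complexity_conv2d, complexity_batchnorm2d

cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, w_in=16, w_out=32, k=3, stride=1, padding=1)
# h = w = (32 + 2*1 - 3)//1 + 1 = 32
# flops  = 3*3*16*32 * 32*32 = 4,718,592
# params = 3*3*16*32         = 4,608
# acts   = 32 * 32*32        = 32,768
cx = complexity_batchnorm2d(cx, w_in=32)  # adds 2*32 = 64 params
```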

95
pycls/core/optimizer.py Normal file

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Optimizer."""
import numpy as np
import torch
from pycls.core.config import cfg
def construct_optimizer(model):
"""Constructs the optimizer.
Note that the momentum update in PyTorch differs from the one in Caffe2.
In particular,
Caffe2:
V := mu * V + lr * g
p := p - V
PyTorch:
V := mu * V + g
p := p - lr * V
where V is the velocity, mu is the momentum factor, lr is the learning rate,
g is the gradient and p are the parameters.
Since V is defined independently of the learning rate in PyTorch,
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
p_bn = [p for n, p in model.named_parameters() if "bn" in n]
p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
optim_params = [
{"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
{"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
else:
optim_params = model.parameters()
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)
def lr_fun_steps(cur_epoch):
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
def lr_fun_exp(cur_epoch):
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
def lr_fun_cos(cur_epoch):
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
def get_lr_fun():
"""Retrieves the specified lr policy function"""
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
if lr_fun not in globals():
raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY)
return globals()[lr_fun]
def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch according to the policy."""
lr = get_lr_fun()(cur_epoch)
# Linear warmup
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
lr *= warmup_factor
return lr
def set_lr(optimizer, new_lr):
"""Sets the optimizer lr to the specified value."""
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr

132
pycls/core/plotting.py Normal file

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Plotting functions."""
import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging
def get_plot_colors(max_colors, color_format="pyplot"):
"""Generate colors for plotting."""
colors = cl.scales["11"]["qual"]["Paired"]
if max_colors > len(colors):
colors = cl.to_rgb(cl.interp(colors, max_colors))
if color_format == "pyplot":
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
return colors
def prepare_plot_data(log_files, names, metric="top1_err"):
"""Load logs and extract data for plotting error curves."""
plot_data = []
for file, name in zip(log_files, names):
d, data = {}, logging.sort_log_data(logging.load_log_data(file))
for phase in ["train", "test"]:
x = data[phase + "_epoch"]["epoch_ind"]
y = data[phase + "_epoch"][metric]
d["x_" + phase], d["y_" + phase] = x, y
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
plot_data.append(d)
assert len(plot_data) > 0, "No data to plot"
return plot_data
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
"""Plot error curves using plotly and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(plot_data), "plotly")
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
data = []
for i, d in enumerate(plot_data):
s = str(i)
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=True,
showlegend=False,
)
)
data.append(
go.Scatter(
x=d["x_test"],
y=d["y_test"],
mode="lines",
name=d["test_label"],
line=line_test,
legendgroup=s,
visible=True,
showlegend=True,
)
)
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=False,
showlegend=True,
)
)
# Prepare layout w ability to toggle 'all', 'train', 'test'
titlefont = {"size": 18, "color": "#7f7f7f"}
vis = [[True, True, False], [False, False, True], [False, True, False]]
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
layout = go.Layout(
title=metric + " vs. epoch<br>[dash=train, solid=test]",
xaxis={"title": "epoch", "titlefont": titlefont},
yaxis={"title": metric, "titlefont": titlefont},
showlegend=True,
hoverlabel={"namelength": -1},
updatemenus=[
{
"buttons": buttons,
"direction": "down",
"showactive": True,
"x": 1.02,
"xanchor": "left",
"y": 1.08,
"yanchor": "top",
}
],
)
# Create plotly plot
offline.plot({"data": data, "layout": layout}, filename=filename)
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
"""Plot error curves using matplotlib.pyplot and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(names))
for ind, d in enumerate(plot_data):
c, lbl = colors[ind], d["test_label"]
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
plt.xlabel("epoch", fontsize=14)
plt.ylabel(metric, fontsize=14)
plt.grid(alpha=0.4)
plt.legend()
if filename:
plt.savefig(filename)
plt.clf()
else:
plt.show()
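A usage sketch (the sweep directory is hypothetical, and the optional plotting dependencies colorlover/plotly/matplotlib are assumed to be installed): collect each run's `stdout.log` and plot the train/test error curves.

```python
import pycls.core.logging as logging
import pycls.core.plotting as plotting

# Each subdirectory of the sweep dir is one trained model with a stdout.log inside.
files, names = logging.get_log_files("/tmp/sweep_out")  # hypothetical directory
plotting.plot_error_curves_pyplot(files, names, "curves.png", metric="top1_err")
```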

39
pycls/core/timer.py Normal file

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Timer."""
import time
class Timer(object):
"""A simple timer (adapted from Detectron)."""
def __init__(self):
self.total_time = None
self.calls = None
self.start_time = None
self.diff = None
self.average_time = None
self.reset()
def tic(self):
# using time.time as time.clock does not normalize for multithreading
self.start_time = time.time()
def toc(self):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
def reset(self):
self.total_time = 0.0
self.calls = 0
self.start_time = 0.0
self.diff = 0.0
self.average_time = 0.0
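A tiny usage sketch mirroring how the benchmark module drives `Timer` (the sleep stands in for a forward pass):

```python
import time
from pycls.core.timer import Timer

timer = Timer()
for i in range(5):
    if i == 2:            # discard the first two "warmup" iterations
        timer.reset()
    timer.tic()
    time.sleep(0.01)      # stand-in for a timed operation
    timer.toc()
print(timer.average_time)  # ~0.01s, averaged over the last 3 calls
```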

419
pycls/core/trainer.py Normal file

@@ -0,0 +1,419 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tools for training and testing a model."""
import os
from thop import profile
import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as checkpoint
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
import pycls.core.optimizer as optim
import pycls.datasets.loader as loader
import torch
import torch.nn.functional as F
from pycls.core.config import cfg
logger = logging.get_logger(__name__)
def setup_env():
"""Sets up environment for training or testing."""
if dist.is_master_proc():
# Ensure that the output dir exists
os.makedirs(cfg.OUT_DIR, exist_ok=True)
# Save the config
config.dump_cfg()
# Setup logging
logging.setup_logging()
# Log the config as both human readable and as a json
logger.info("Config:\n{}".format(cfg))
logger.info(logging.dump_log_data(cfg, "cfg"))
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
# Configure the CUDNN backend
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
def setup_model():
"""Sets up a model for training or testing and log the results."""
# Build the model
model = builders.build_model()
logger.info("Model:\n{}".format(model))
# Log model complexity
# logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
h, w = 1025, 2049
else:
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
if cfg.TASK == "jig":
x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
else:
x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
macs, params = profile(model, inputs=(x, ), verbose=False)
logger.info("Params: {:,}".format(params))
logger.info("Flops: {:,}".format(macs))
# Transfer the model to the current GPU device
err_str = "Cannot use more GPU devices than available"
assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
cur_device = torch.cuda.current_device()
model = model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
# Set complexity function to be module's complexity function
# model.complexity = model.module.complexity
return model
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of training."""
# Update drop path prob for NAS
if cfg.MODEL.TYPE == "nas":
m = model.module if cfg.NUM_GPUS > 1 else model
m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
# Shuffle the data
loader.shuffle(train_loader, cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer, lr)
# Enable training mode
model.train()
train_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(train_loader):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
optim.set_lr(optimizer, lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Perform the forward pass
preds = model(inputs)
# Compute the loss
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer.zero_grad()
loss.backward()
# Update the parameters
optimizer.step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of differentiable architecture search."""
m = model.module if cfg.NUM_GPUS > 1 else model
# Shuffle the data
loader.shuffle(train_loader[0], cur_epoch)
loader.shuffle(train_loader[1], cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer[0], lr)
# Enable training mode
model.train()
train_meter.iter_tic()
trainB_iter = iter(train_loader[1])
for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
optim.set_lr(optimizer[0], lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Update architecture
if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
try:
inputsB, labelsB = next(trainB_iter)
except StopIteration:
trainB_iter = iter(train_loader[1])
inputsB, labelsB = next(trainB_iter)
inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
optimizer[1].zero_grad()
loss = m._loss(inputsB, labelsB)
loss.backward()
optimizer[1].step()
# Perform the forward pass
preds = model(inputs)
# Compute the loss
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer[0].zero_grad()
loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
# Update the parameters
optimizer[0].step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
# Log genotype
genotype = m.genotype()
logger.info("genotype = %s", genotype)
logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
"""Evaluates the model on the test set."""
# Enable eval mode
model.eval()
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Compute the predictions
preds = model(inputs)
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the errors across the GPUs (no reduction if 1 GPU used)
top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
# Copy the errors from GPU to CPU (sync point)
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
test_meter.iter_toc()
# Update and log stats
test_meter.update_stats(top1_err, top5_err, mb_size)
test_meter.log_iter_stats(cur_epoch, cur_iter)
test_meter.iter_tic()
# Log epoch stats
test_meter.log_epoch_stats(cur_epoch)
test_meter.reset()
def train_model():
"""Trains the model."""
# Setup training/testing environment
setup_env()
# Construct the model, loss_fun, and optimizer
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
if "search" in cfg.MODEL.TYPE:
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
optimizer_w = torch.optim.SGD(
params=params_w,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
if cfg.OPTIM.ARCH_OPTIM == "adam":
optimizer_a = torch.optim.Adam(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
betas=(0.5, 0.999),
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
)
elif cfg.OPTIM.ARCH_OPTIM == "sgd":
optimizer_a = torch.optim.SGD(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
optimizer = [optimizer_w, optimizer_a]
else:
optimizer = optim.construct_optimizer(model)
# Load checkpoint or initial weights
start_epoch = 0
if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
last_checkpoint = checkpoint.get_last_checkpoint()
checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
start_epoch = checkpoint_epoch + 1
elif cfg.TRAIN.WEIGHTS:
checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
# Create data loaders and meters
if cfg.TRAIN.PORTION < 1:
if "search" in cfg.MODEL.TYPE:
train_loader = [loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
),
loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="r"
)]
else:
train_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
)
test_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=False,
drop_last=False,
portion=cfg.TRAIN.PORTION,
side="r"
)
else:
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
l = train_loader[0] if isinstance(train_loader, list) else train_loader
train_meter = train_meter_type(len(l))
test_meter = test_meter_type(len(test_loader))
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
l = train_loader[0] if isinstance(train_loader, list) else train_loader
benchmark.compute_time_full(model, loss_fun, l, test_loader)
# Perform the training loop
logger.info("Start epoch: {}".format(start_epoch + 1))
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# Train for one epoch
f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
# Compute precise BN stats
if cfg.BN.USE_PRECISE_STATS:
net.compute_precise_bn_stats(model, train_loader)
# Save a checkpoint
if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
# Evaluate the model
next_epoch = cur_epoch + 1
if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
test_epoch(test_loader, model, test_meter, cur_epoch)
def test_model():
"""Evaluates a trained model."""
# Setup training/testing environment
setup_env()
# Construct the model
model = setup_model()
# Load model weights
checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
# Create data loaders and meters
test_loader = loader.construct_test_loader()
test_meter = meters.TestMeter(len(test_loader))
# Evaluate the model
test_epoch(test_loader, model, test_meter, 0)
def time_model():
"""Times model and data loader."""
# Setup training/testing environment
setup_env()
# Construct the model and loss_fun
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
# Create data loaders
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
# Compute model and loader timings
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
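As a rough, self-contained illustration of the weight/architecture parameter split that `train_model` performs when `"search" in cfg.MODEL.TYPE`: parameters whose names contain `alphas` go to the architecture optimizer, everything else to the weight optimizer. The toy module and hyperparameters below are illustrative only:

```python
import torch
import torch.nn as nn

class ToySearchNet(nn.Module):
    """Toy model with ordinary weights plus an 'alphas' architecture parameter."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.alphas_normal = nn.Parameter(1e-3 * torch.randn(4, 8))

model = ToySearchNet()
# Same name-based split as in train_model above
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
optimizer_w = torch.optim.SGD(params_w, lr=0.025, momentum=0.9, weight_decay=3e-4)
optimizer_a = torch.optim.Adam(params_a, lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)
print("{} weight tensors, {} arch tensors".format(len(params_w), len(params_a)))
```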

0
pycls/models/__init__.py Normal file

406
pycls/models/anynet.py Normal file

@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""AnyNet models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
def get_stem_fun(stem_type):
"""Retrieves the stem function by name."""
stem_funs = {
"res_stem_cifar": ResStemCifar,
"res_stem_in": ResStemIN,
"simple_stem_in": SimpleStemIN,
}
err_str = "Stem type '{}' not supported"
assert stem_type in stem_funs.keys(), err_str.format(stem_type)
return stem_funs[stem_type]
def get_block_fun(block_type):
"""Retrieves the block function by name."""
block_funs = {
"vanilla_block": VanillaBlock,
"res_basic_block": ResBasicBlock,
"res_bottleneck_block": ResBottleneckBlock,
}
err_str = "Block type '{}' not supported"
assert block_type in block_funs.keys(), err_str.format(block_type)
return block_funs[block_type]
class AnyHead(nn.Module):
"""AnyNet head: AvgPool, 1x1."""
def __init__(self, w_in, nc):
super(AnyHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, nc):
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
return cx
class VanillaBlock(nn.Module):
"""Vanilla block: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(VanillaBlock, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride):
super(BasicTransform, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBasicBlock(nn.Module):
"""Residual basic block: x + F(x), F = basic transform."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(ResBasicBlock, self).__init__()
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BasicTransform(w_in, w_out, stride)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BasicTransform.complexity(cx, w_in, w_out, stride)
return cx
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
def __init__(self, w_in, w_out, stride, bm, gw, se_r):
super(BottleneckTransform, self).__init__()
w_b = int(round(w_out * bm))
g = w_b // gw
self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
if se_r:
w_se = int(round(w_in * se_r))
self.se = SE(w_b, w_se)
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
w_b = int(round(w_out * bm))
g = w_b // gw
cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
cx = net.complexity_batchnorm2d(cx, w_b)
if se_r:
w_se = int(round(w_in * se_r))
cx = SE.complexity(cx, w_b, w_se)
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform."""
def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
super(ResBottleneckBlock, self).__init__()
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
return cx
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
return cx
class SimpleStemIN(nn.Module):
"""Simple stem for ImageNet: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(SimpleStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
super(AnyStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
return cx
class AnyNet(nn.Module):
"""AnyNet model."""
@staticmethod
def get_args():
return {
"stem_type": cfg.ANYNET.STEM_TYPE,
"stem_w": cfg.ANYNET.STEM_W,
"block_type": cfg.ANYNET.BLOCK_TYPE,
"ds": cfg.ANYNET.DEPTHS,
"ws": cfg.ANYNET.WIDTHS,
"ss": cfg.ANYNET.STRIDES,
"bms": cfg.ANYNET.BOT_MULS,
"gws": cfg.ANYNET.GROUP_WS,
"se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self, **kwargs):
super(AnyNet, self).__init__()
kwargs = self.get_args() if not kwargs else kwargs
#print(kwargs)
self._construct(**kwargs)
self.apply(net.init_weights)
def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
# Generate dummy bot muls and gs for models that do not use them
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
self.stem = stem_fun(3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for i, (d, w, s, bm, gw) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
prev_w = w
self.head = AnyHead(w_in=prev_w, nc=nc)
def forward(self, x, get_ints=False):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx, **kwargs):
"""Computes model complexity. If you alter the model, make sure to update."""
kwargs = AnyNet.get_args() if not kwargs else kwargs
return AnyNet._complexity(cx, **kwargs)
@staticmethod
def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
cx = stem_fun.complexity(cx, 3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for d, w, s, bm, gw in stage_params:
cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
prev_w = w
cx = AnyHead.complexity(cx, prev_w, nc)
return cx
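A hedged sketch of constructing an `AnyNet` directly from keyword arguments rather than from the global cfg; the stage widths and depths below are arbitrary choices for illustration, not a published configuration (BN/ReLU settings still come from the config defaults):

```python
import torch

model = AnyNet(
    stem_type="res_stem_cifar",
    stem_w=16,
    block_type="res_basic_block",
    ds=[2, 2, 2],        # blocks per stage
    ws=[16, 32, 64],     # output width per stage
    ss=[1, 2, 2],        # stride of the first block in each stage
    bms=None, gws=None, se_r=None,
    nc=10,
)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```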

108
pycls/models/common.py Normal file

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from pycls.core.config import cfg
def Preprocess(x):
if cfg.TASK == 'jig':
assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
return x
class Classifier(nn.Module):
def __init__(self, channels, num_classes):
super(Classifier, self).__init__()
if cfg.TASK == 'jig':
self.jig_sq = cfg.JIGSAW_GRID ** 2
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
elif cfg.TASK == 'col':
self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
elif cfg.TASK == 'seg':
self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
else:
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels, num_classes)
def forward(self, x, shape):
if cfg.TASK == 'jig':
x = self.pooling(x)
x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
x = self.classifier(x.view(x.size(0), -1))
elif cfg.TASK in ['col', 'seg']:
x = self.classifier(x)
x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
else:
x = self.pooling(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, num_classes, rates):
super(ASPP, self).__init__()
assert len(rates) in [1, 3]
self.rates = rates
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.aspp1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp2 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
padding=rates[0], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
if len(self.rates) == 3:
self.aspp3 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
padding=rates[1], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp4 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
padding=rates[2], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp5 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.classifier = nn.Sequential(
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, num_classes, 1)
)
def forward(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x5 = self.global_pooling(x)
x5 = self.aspp5(x5)
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
align_corners=True)(x5)
if len(self.rates) == 3:
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x = torch.cat((x1, x2, x3, x4, x5), 1)
else:
x = torch.cat((x1, x2, x5), 1)
x = self.classifier(x)
return x
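To make the jigsaw reshape in `Preprocess` concrete, a small shape sketch with an assumed 3x3 grid: each of the G*G tiles is folded into the batch dimension before the backbone, and the `Classifier` later folds them back:

```python
import torch

N, G, C, H, W = 4, 3, 3, 64, 64           # assumed batch size and 3x3 jigsaw grid
x = torch.randn(N, G * G, C, H, W)        # (4, 9, 3, 64, 64)
x = x.view(N * G * G, C, H, W)            # (36, 3, 64, 64), as in Preprocess
print(x.shape)
```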

232
pycls/models/effnet.py Normal file

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""
import pycls.core.net as net
import torch
import torch.nn as nn
from pycls.core.config import cfg
class EffHead(nn.Module):
"""EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
def __init__(self, w_in, w_out, nc):
super(EffHead, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.conv_swish = Swish()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
if cfg.EN.DROPOUT_RATIO > 0.0:
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
self.fc = nn.Linear(w_out, nc, bias=True)
def forward(self, x):
x = self.conv_swish(self.conv_bn(self.conv(x)))
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x) if hasattr(self, "dropout") else x
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, nc):
cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
return cx
class Swish(nn.Module):
"""Swish activation function: x * sigmoid(x)."""
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
Swish(),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class MBConv(nn.Module):
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
# expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
super(MBConv, self).__init__()
self.exp = None
w_exp = int(w_in * exp_r)
if w_exp != w_in:
self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.exp_swish = Swish()
dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.dwise_swish = Swish()
self.se = SE(w_exp, int(w_in * se_r))
self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
# Skip connection if in and out shapes are the same (MN-V2 style)
self.has_skip = stride == 1 and w_in == w_out
def forward(self, x):
f_x = x
if self.exp:
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
f_x = self.se(f_x)
f_x = self.lin_proj_bn(self.lin_proj(f_x))
if self.has_skip:
if self.training and cfg.EN.DC_RATIO > 0.0:
f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
f_x = x + f_x
return f_x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
w_exp = int(w_in * exp_r)
if w_exp != w_in:
cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_exp)
padding = (kernel - 1) // 2
cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
cx = net.complexity_batchnorm2d(cx, w_exp)
cx = SE.complexity(cx, w_exp, int(w_in * se_r))
cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class EffStage(nn.Module):
"""EfficientNet stage."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
super(EffStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out)
return cx
class StemIN(nn.Module):
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
def __init__(self, w_in, w_out):
super(StemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.swish = Swish()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class EffNet(nn.Module):
"""EfficientNet model."""
@staticmethod
def get_args():
return {
"stem_w": cfg.EN.STEM_W,
"ds": cfg.EN.DEPTHS,
"ws": cfg.EN.WIDTHS,
"exp_rs": cfg.EN.EXP_RATIOS,
"se_r": cfg.EN.SE_R,
"ss": cfg.EN.STRIDES,
"ks": cfg.EN.KERNELS,
"head_w": cfg.EN.HEAD_W,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self):
err_str = "Dataset {} is not supported"
assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
super(EffNet, self).__init__()
self._construct(**EffNet.get_args())
self.apply(net.init_weights)
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
self.stem = StemIN(3, stem_w)
prev_w = stem_w
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
prev_w = w
self.head = EffHead(prev_w, head_w, nc)
def forward(self, x):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx):
"""Computes model complexity. If you alter the model, make sure to update."""
return EffNet._complexity(cx, **EffNet.get_args())
@staticmethod
def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
cx = StemIN.complexity(cx, 3, stem_w)
prev_w = stem_w
for d, w, exp_r, stride, kernel in stage_params:
cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
prev_w = w
cx = EffHead.complexity(cx, prev_w, head_w, nc)
return cx
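A small sketch exercising a single `MBConv` block as defined above; the widths and ratios are illustrative rather than an official EfficientNet stage, and eval mode is used so the drop-connect branch is skipped:

```python
import torch

block = MBConv(w_in=16, exp_r=6.0, kernel=3, stride=1, se_r=0.25, w_out=16)
block.eval()                      # skip drop connect for a deterministic check
x = torch.randn(2, 16, 32, 32)
y = block(x)                      # stride 1 and w_in == w_out, so the skip path is taken
print(y.shape)                    # torch.Size([2, 16, 32, 32])
```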

634
pycls/models/nas/genotypes.py Normal file

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS genotypes (adopted from DARTS)."""
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# NASNet ops
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
# ENAS ops
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
# AmoebaNet ops
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
# NAO ops
NAO_OPS = [
'skip_connect',
'conv_1x1',
'conv_3x3',
'conv_3x1_1x3',
'conv_7x1_1x7',
'max_pool_2x2',
'max_pool_3x3',
'max_pool_5x5',
'avg_pool_2x2',
'avg_pool_3x3',
'avg_pool_5x5',
]
# PNAS ops
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
# DARTS ops
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6],
)
PNASNet = Genotype(
normal=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
reduce_concat=[2, 3, 4, 5, 6],
)
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6]
)
DARTS_V1 = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 0),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('avg_pool_3x3', 0)
],
reduce_concat=[2, 3, 4, 5]
)
DARTS_V2 = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('skip_connect', 0),
('dil_conv_3x3', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 1),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('max_pool_3x3', 1)
],
reduce_concat=[2, 3, 4, 5]
)
PDARTS = Genotype(
normal=[
('skip_connect', 0),
('dil_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_3x3', 1),
('dil_conv_3x3', 1),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
PCDARTS_C10 = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('dil_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('avg_pool_3x3', 0),
('dil_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2)
],
reduce_concat=range(2, 6)
)
PCDARTS_IN1K = Genotype(
normal=[
('skip_connect', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('skip_connect', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('max_pool_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_COL = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 3),
('max_pool_3x3', 0),
('sep_conv_3x3', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2),
('dil_conv_5x5', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('dil_conv_5x5', 2),
('sep_conv_5x5', 0),
('dil_conv_5x5', 3),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_COL = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('skip_connect', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_5x5', 3),
('sep_conv_5x5', 0),
('sep_conv_5x5', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_SEG = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_ROT = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_5x5', 1),
('sep_conv_5x5', 3),
('dil_conv_5x5', 2),
('sep_conv_5x5', 2),
('sep_conv_5x5', 0)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_COL = Genotype(
normal=[
('dil_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_5x5', 2),
('dil_conv_3x3', 3),
('skip_connect', 0),
('skip_connect', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_JIG = Genotype(
normal=[
('dil_conv_5x5', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 0),
('dil_conv_5x5', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 1),
('dil_conv_5x5', 2),
('dil_conv_5x5', 2),
('dil_conv_5x5', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
# Supported genotypes
GENOTYPES = {
'nas': NASNet,
'pnas': PNASNet,
'amoeba': AmoebaNet,
'darts_v1': DARTS_V1,
'darts_v2': DARTS_V2,
'pdarts': PDARTS,
'pcdarts_c10': PCDARTS_C10,
'pcdarts_in1k': PCDARTS_IN1K,
'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
'unnas_imagenet_col': UNNAS_IMAGENET_COL,
'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
'custom': None,
}
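A brief sketch of how these genotypes are read: each `(op_name, input_index)` pair wires one incoming edge of an intermediate node (two edges per node), and the nodes listed in `normal_concat` / `reduce_concat` are concatenated to form the cell output:

```python
from pycls.models.nas.genotypes import GENOTYPES

g = GENOTYPES["darts_v2"]
for i in range(0, len(g.normal), 2):
    node = 2 + i // 2  # states 0 and 1 are the two cell inputs
    print("node {}: {} from state {}, {} from state {}".format(
        node, g.normal[i][0], g.normal[i][1],
        g.normal[i + 1][0], g.normal[i + 1][1]))
print("concat:", list(g.normal_concat))
```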

299
pycls/models/nas/nas.py Normal file

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS network (adopted from DARTS)."""
from torch.autograd import Variable
import torch
import torch.nn as nn
import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity
logger = logging.get_logger(__name__)
def drop_path(x, drop_prob):
"""Drop path (ported from DARTS)."""
if drop_prob > 0.:
keep_prob = 1.-drop_prob
mask = Variable(
torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
)
x.div_(keep_prob)
x.mul_(mask)
return x
class Cell(nn.Module):
"""NAS cell (ported from DARTS)."""
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class NetworkCIFAR(nn.Module):
"""CIFAR network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
class NetworkImageNet(nn.Module):
"""ImageNet network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
self.stem0 = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr = C, C, C
self.cells = nn.ModuleList()
reduction_prev = True
reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
for i in range(layers):
if i in reduction_layers:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2 * self._layers // 3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
class NAS(nn.Module):
"""NAS net wrapper (delegates to nets from DARTS)."""
def __init__(self):
assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Testing on {} is not supported'.format(cfg.TEST.DATASET)
assert cfg.NAS.GENOTYPE in GENOTYPES, \
'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
super(NAS, self).__init__()
logger.info('Constructing NAS: {}'.format(cfg.NAS))
# Use a custom or predefined genotype
if cfg.NAS.GENOTYPE == 'custom':
genotype = Genotype(
normal=cfg.NAS.CUSTOM_GENOTYPE[0],
normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
)
else:
genotype = GENOTYPES[cfg.NAS.GENOTYPE]
# Determine the network constructor for dataset
if 'cifar' in cfg.TRAIN.DATASET:
net_ctor = NetworkCIFAR
else:
net_ctor = NetworkImageNet
# Construct the network
self.net_ = net_ctor(
C=cfg.NAS.WIDTH,
num_classes=cfg.MODEL.NUM_CLASSES,
layers=cfg.NAS.DEPTH,
auxiliary=cfg.NAS.AUX,
genotype=genotype
)
# Drop path probability (set / annealed based on epoch)
self.net_.drop_path_prob = 0.0
def set_drop_path_prob(self, drop_path_prob):
self.net_.drop_path_prob = drop_path_prob
def forward(self, x):
return self.net_.forward(x)
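To clarify the arithmetic in `drop_path` above: with drop probability p, each sample's branch output is zeroed with probability p and the survivors are rescaled by 1/(1-p), so the expected value is unchanged. A CPU-friendly sketch of the same idea (the original uses an in-place update on a CUDA tensor):

```python
import torch

def drop_path_cpu(x, drop_prob):
    """Per-sample stochastic depth, mirroring drop_path above but on CPU."""
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = torch.empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        x = x / keep_prob * mask
    return x

x = torch.ones(8, 4, 2, 2)
y = drop_path_cpu(x, drop_prob=0.25)
print(y[:, 0, 0, 0])  # roughly 3/4 of the samples equal 4/3, the rest are 0
```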

201
pycls/models/nas/operations.py Normal file

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS ops (adopted from DARTS)."""
import torch
import torch.nn as nn
OPS = {
'none': lambda C, stride, affine:
Zero(stride),
'avg_pool_2x2': lambda C, stride, affine:
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
'avg_pool_3x3': lambda C, stride, affine:
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'avg_pool_5x5': lambda C, stride, affine:
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
'max_pool_2x2': lambda C, stride, affine:
nn.MaxPool2d(2, stride=stride, padding=0),
'max_pool_3x3': lambda C, stride, affine:
nn.MaxPool2d(3, stride=stride, padding=1),
'max_pool_5x5': lambda C, stride, affine:
nn.MaxPool2d(5, stride=stride, padding=2),
'max_pool_7x7': lambda C, stride, affine:
nn.MaxPool2d(7, stride=stride, padding=3),
'skip_connect': lambda C, stride, affine:
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'conv_1x1': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_3x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3': lambda C, stride, affine:
SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine:
SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine:
SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine:
DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine:
DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_sep_conv_3x3': lambda C, stride, affine:
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
'conv_3x1_1x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_7x1_1x7': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_out, kernel_size, stride=stride,
padding=padding, bias=False
),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class DilSepConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilSepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
out = self.bn(out)
return out
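
Every entry in the `OPS` table above is a constructor taking `(C, stride, affine)`. A minimal smoke test, assuming the file is importable as `operations` (the actual module path is elided in this diff):

```python
import torch
from operations import OPS  # assumed module name for the file above

x = torch.randn(2, 16, 32, 32)  # (batch, channels, height, width)
for name, ctor in OPS.items():
    for stride in (1, 2):
        op = ctor(16, stride, True)  # (C, stride, affine)
        with torch.no_grad():
            y = op(x)
        # stride-2 ops halve the spatial resolution while keeping C channels
        print(f"{name:20s} stride={stride} -> {tuple(y.shape)}")
```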

pycls/models/regnet.py Normal file

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""RegNet models."""
import numpy as np
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
return int(round(f / q) * q)
def adjust_ws_gs_comp(ws, bms, gs):
"""Adjusts the compatibility of widths and groups."""
ws_bot = [int(w * b) for w, b in zip(ws, bms)]
gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
return ws, gs
def get_stages_from_blocks(ws, rs):
"""Gets ws/ds of network at each stage from per block values."""
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
return s_ws, s_ds
def generate_regnet(w_a, w_0, w_m, d, q=8):
"""Generates per block ws from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ws_cont = np.arange(d) * w_a + w_0
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws = w_0 * np.power(w_m, ks)
ws = np.round(np.divide(ws, q)) * q
num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
return ws, num_stages, max_stage, ws_cont
class RegNet(AnyNet):
"""RegNet model."""
@staticmethod
def get_args():
"""Convert RegNet to AnyNet parameter format."""
# Generate RegNet ws per block
w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
# Convert to per stage format
s_ws, s_ds = get_stages_from_blocks(ws, ws)
# Use the same gw, bm and ss for each stage
s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
# Adjust the compatibility of ws and gws
s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
# Get AnyNet arguments defining the RegNet
return {
"stem_type": cfg.REGNET.STEM_TYPE,
"stem_w": cfg.REGNET.STEM_W,
"block_type": cfg.REGNET.BLOCK_TYPE,
"ds": s_ds,
"ws": s_ws,
"ss": s_ss,
"bms": s_bs,
"gws": s_gs,
"se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self):
kwargs = RegNet.get_args()
super(RegNet, self).__init__(**kwargs)
@staticmethod
def complexity(cx, **kwargs):
"""Computes model complexity. If you alter the model, make sure to update."""
kwargs = RegNet.get_args() if not kwargs else kwargs
return AnyNet.complexity(cx, **kwargs)
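
A small worked example of `generate_regnet` above; the parameter values are illustrative rather than a published RegNet configuration:

```python
from pycls.models.regnet import generate_regnet

# slope w_a=8, initial width w_0=16, multiplier w_m=2, depth d=4 (quantization q=8)
ws, num_stages, max_stage, ws_cont = generate_regnet(8, 16, 2, 4)
print(ws_cont)     # [16, 24, 32, 40] -- continuous widths w_0 + w_a * i
print(ws)          # [16, 32, 32, 32] -- snapped to w_0 * w_m**k, then to multiples of q
print(num_stages)  # 2 -- consecutive blocks sharing a width form a stage
```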

pycls/models/resnet.py Normal file

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ResNe(X)t models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
def get_trans_fun(name):
"""Retrieves the transformation function by name."""
trans_funs = {
"basic_transform": BasicTransform,
"bottleneck_transform": BottleneckTransform,
}
err_str = "Transformation function '{}' not supported"
assert name in trans_funs.keys(), err_str.format(name)
return trans_funs[name]
class ResHead(nn.Module):
"""ResNet head: AvgPool, 1x1."""
def __init__(self, w_in, nc):
super(ResHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, nc):
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
return cx
class BasicTransform(nn.Module):
"""Basic transformation: 3x3, BN, ReLU, 3x3, BN."""
def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
err_str = "Basic transform does not support w_b and num_gs options"
assert w_b is None and num_gs == 1, err_str
super(BasicTransform, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
err_str = "Basic transform does not support w_b and num_gs options"
assert w_b is None and num_gs == 1, err_str
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""
def __init__(self, w_in, w_out, stride, w_b, num_gs):
super(BottleneckTransform, self).__init__()
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, w_b, num_gs):
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBlock(nn.Module):
"""Residual block: x + F(x)."""
def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
super(ResBlock, self).__init__()
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
return cx
class ResStage(nn.Module):
"""Stage of ResNet."""
def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
super(ResStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
self.add_module("b{}".format(i + 1), res_block)
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs)
return cx
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
return cx
class ResNet(nn.Module):
"""ResNet model."""
def __init__(self):
datasets = ["cifar10", "imagenet"]
err_str = "Dataset {} is not supported"
assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
super(ResNet, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar()
else:
self._construct_imagenet()
self.apply(net.init_weights)
def _construct_cifar(self):
err_str = "Model depth should be of the format 6n + 2 for cifar"
assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
d = int((cfg.MODEL.DEPTH - 2) / 6)
self.stem = ResStemCifar(3, 16)
self.s1 = ResStage(16, 16, stride=1, d=d)
self.s2 = ResStage(16, 32, stride=2, d=d)
self.s3 = ResStage(32, 64, stride=2, d=d)
self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)
def _construct_imagenet(self):
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
w_b = gw * g
self.stem = ResStemIN(3, 64)
self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)
def forward(self, x):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx):
"""Computes model complexity. If you alter the model, make sure to update."""
if "cifar" in cfg.TRAIN.DATASET:
d = int((cfg.MODEL.DEPTH - 2) / 6)
cx = ResStemCifar.complexity(cx, 3, 16)
cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
else:
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
w_b = gw * g
cx = ResStemIN.complexity(cx, 3, 64)
cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
return cx
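
The CIFAR branch above requires `cfg.MODEL.DEPTH` to be of the form 6n + 2 (three stages of n basic blocks, each block holding two 3x3 convolutions, plus stem and head). A quick check for the familiar CIFAR ResNet depths, independent of the config system:

```python
for depth in (20, 32, 56, 110):
    assert (depth - 2) % 6 == 0, "CIFAR ResNet depth must be 6n + 2"
    d = (depth - 2) // 6
    print(f"ResNet-{depth}: {d} basic blocks per stage")  # 3, 5, 9, 18
```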


@@ -1,13 +0,0 @@
#!/bin/bash
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 10
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 10
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 100
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 100
python process_results.py --n_runs $1


score_networks.py Normal file

@@ -0,0 +1,164 @@
import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
from scipy import stats
from pycls.models.nas.nas import Cell
from utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image within a batch')
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--init', default='', type=str)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dropout', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
def get_batch_jacobian(net, x, target, device, args=None):
net.zero_grad()
x.requires_grad_(True)
y, out = net(x)
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
return jacob, target.detach(), y.detach(), out.detach()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
savedataset = args.dataset
dataset = 'fake' if 'fake' in args.dataset else args.dataset
args.dataset = args.dataset.replace('fake', '')
if args.dataset == 'cifar10':
args.dataset = args.dataset + '-valid'
searchspace = nasspace.get_search_space(args)
if 'valid' in args.dataset:
args.dataset = args.dataset.replace('-valid', '')
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
os.makedirs(args.save_loc, exist_ok=True)
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{savedataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}'
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{savedataset}_{args.trainval}'
if args.dataset == 'cifar10':
acc_type = 'ori-test'
val_acc_type = 'x-valid'
else:
acc_type = 'x-test'
val_acc_type = 'x-valid'
scores = np.zeros(len(searchspace))
try:
accs = np.load(accfilename + '.npy')
except:
accs = np.zeros(len(searchspace))
for i, (uid, network) in enumerate(searchspace):
# Reproducibility
try:
if args.dropout:
add_dropout(network, args.sigma)
if args.init != '':
init_network(network, args.init)
if 'hook_' in args.score:
network.K = np.zeros((args.batch_size, args.batch_size))
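            # K is the NASWOT kernel: for every pair of inputs in the batch it counts
            # how many ReLU units share the same binary on/off activation pattern.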
def counting_forward_hook(module, inp, out):
try:
if not module.visited_backwards:
return
if isinstance(inp, tuple):
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
except:
pass
def counting_backward_hook(module, inp, out):
module.visited_backwards = True
for name, module in network.named_modules():
if 'ReLU' in str(type(module)):
#hooks[name] = module.register_forward_hook(counting_hook)
module.register_forward_hook(counting_forward_hook)
module.register_backward_hook(counting_backward_hook)
network = network.to(device)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
s = []
for j in range(args.maxofn):
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x2 = torch.clone(x)
x2 = x2.to(device)
x, target = x.to(device), target.to(device)
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
if 'hook_' in args.score:
network(x2.to(device))
s.append(get_score_func(args.score)(network.K, target))
else:
s.append(get_score_func(args.score)(jacobs, labels))
scores[i] = np.mean(s)
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
accs_ = accs[~np.isnan(scores)]
scores_ = scores[~np.isnan(scores)]
numnan = np.isnan(scores).sum()
tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)])
print(f'{tau}')
if i % 1000 == 0:
np.save(filename, scores)
np.save(accfilename, accs)
except Exception as e:
print(e)
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
scores[i] = np.nan
np.save(filename, scores)
np.save(accfilename, accs)
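
A self-contained sketch of the hook-based scoring performed above, on a toy fully connected network rather than a searched architecture (sizes and seed are illustrative; the `visited_backwards` guard and the Jacobian computation are omitted):

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(),
                    nn.Linear(64, 64), nn.ReLU(),
                    nn.Linear(64, 10))
batch_size = 16
net.K = np.zeros((batch_size, batch_size))

def counting_forward_hook(module, inp, out):
    x = inp[0] if isinstance(inp, tuple) else inp
    code = (x.view(x.size(0), -1) > 0).float()  # binary activation pattern per input
    net.K += (code @ code.t()).numpy() + ((1. - code) @ (1. - code).t()).numpy()

for m in net.modules():
    if isinstance(m, nn.ReLU):
        m.register_forward_hook(counting_forward_hook)

with torch.no_grad():
    net(torch.randn(batch_size, 32))

sign, score = np.linalg.slogdet(net.K)  # the 'hook_logdet' score
print(score)
```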

scorehook.sh Normal file

@@ -0,0 +1,32 @@
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar10
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar100 --data_loc ../cifar100/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset ImageNet16-120 --data_loc ../imagenet16/Imagenet16/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_pnas --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_enas --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_fix-w-d --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnet --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-a --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-b --batch_size 128 --GPU 3
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_pnas_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_enas_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-a_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar100 --data_loc ../cifar100/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset ImageNet16-120 --data_loc ../imagenet16/Imagenet16/
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench101 --batch_size 128 --GPU 3 --api_loc ../nasbench_only108.tfrecord

scores.py Normal file

@@ -0,0 +1,21 @@
import numpy as np
import torch
def hooklogdet(K, labels=None):
s, ld = np.linalg.slogdet(K)
return ld
def random_score(jacob, label=None):
return np.random.normal()
_scores = {
'hook_logdet': hooklogdet,
'random': random_score
}
def get_score_func(score_name):
return _scores[score_name]
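
A tiny numeric example for `hooklogdet` above. For a batch of two inputs passing through 16 ReLU units, 4 of which take the same on/off state for both inputs, the kernel and its score are:

```python
import numpy as np
from scores import hooklogdet  # this file, assumed importable from the repo root

K = np.array([[16., 4.],
              [4., 16.]])      # diagonal: total units; off-diagonal: units that agree
print(hooklogdet(K))           # log(16*16 - 4*4) = log(240) ≈ 5.48
```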

search.py

@@ -1,35 +1,49 @@
import os
import time
import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import trange
from statistics import mean
import time
from utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--save_loc', default='results/ICML', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--kernel', action='store_true')
parser.add_argument('--dropout', action='store_true')
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image within a batch')
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--init', default='', type=str)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--activations', action='store_true')
parser.add_argument('--cosine', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.optim as optim
from models import get_cell_based_tiny_net
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@@ -37,120 +51,140 @@ random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
import torchvision.transforms as transforms
from datasets import get_datasets
from config_utils import load_config
from nas_201_api import NASBench201API as API
def get_batch_jacobian(net, x, target, to, device, args=None):
def get_batch_jacobian(net, x, target, device, args=None):
net.zero_grad()
x.requires_grad_(True)
_, y = net(x)
y, ints = net(x)
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
return jacob, target.detach()
def eval_score(jacob, labels=None):
corrs = np.corrcoef(jacob)
v, _ = np.linalg.eig(corrs)
k = 1e-5
return -np.sum(np.log(v + k) + 1./(v + k))
return jacob, target.detach(), y.detach(), ints.detach()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
THE_START = time.time()
api = API(args.api_loc)
searchspace = nasspace.get_search_space(args)
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
os.makedirs(args.save_loc, exist_ok=True)
train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)
if args.dataset == 'cifar10':
acc_type = 'ori-test'
val_acc_type = 'x-valid'
else:
acc_type = 'x-test'
val_acc_type = 'x-valid'
if args.trainval:
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
else:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
num_workers=0, pin_memory=True)
times = []
chosen = []
acc = []
val_acc = []
topscores = []
dset = args.dataset if not args.trainval else 'cifar10-valid'
order_fn = np.nanargmax
if args.dataset == 'cifar10':
acc_type = 'ori-test'
val_acc_type = 'x-valid'
else:
acc_type = 'x-test'
val_acc_type = 'x-valid'
runs = trange(args.n_runs, desc='acc: ')
for N in runs:
start = time.time()
indices = np.random.randint(0,15625,args.n_samples)
indices = np.random.randint(0,len(searchspace),args.n_samples)
scores = []
npstate = np.random.get_state()
ranstate = random.getstate()
torchstate = torch.random.get_rng_state()
for arch in indices:
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x, target = x.to(device), target.to(device)
config = api.get_net_config(arch, args.dataset)
config['num_classes'] = 1
network = get_cell_based_tiny_net(config) # create the network from configuration
network = network.to(device)
jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
try:
s = eval_score(jacobs, labels)
uid = searchspace[arch]
network = searchspace.get_network(uid)
network.to(device)
if args.dropout:
add_dropout(network, args.sigma)
if args.init != '':
init_network(network, args.init)
if 'hook_' in args.score:
network.K = np.zeros((args.batch_size, args.batch_size))
def counting_forward_hook(module, inp, out):
try:
if not module.visited_backwards:
return
if isinstance(inp, tuple):
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
except:
pass
def counting_backward_hook(module, inp, out):
module.visited_backwards = True
for name, module in network.named_modules():
if 'ReLU' in str(type(module)):
#hooks[name] = module.register_forward_hook(counting_hook)
module.register_forward_hook(counting_forward_hook)
module.register_backward_hook(counting_backward_hook)
random.setstate(ranstate)
np.random.set_state(npstate)
torch.set_rng_state(torchstate)
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x2 = torch.clone(x)
x2 = x2.to(device)
x, target = x.to(device), target.to(device)
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
if args.kernel:
s = get_score_func(args.score)(out, labels)
elif 'hook_' in args.score:
network(x2.to(device))
s = get_score_func(args.score)(network.K, target)
elif args.repeat < args.batch_size:
s = get_score_func(args.score)(jacobs, labels, args.repeat)
else:
s = get_score_func(args.score)(jacobs, labels)
except Exception as e:
print(e)
s = np.nan
s = 0.
scores.append(s)
#print(len(scores))
#print(scores)
#print(order_fn(scores))
best_arch = indices[order_fn(scores)]
info = api.query_by_index(best_arch)
uid = searchspace[best_arch]
topscores.append(scores[order_fn(scores)])
chosen.append(best_arch)
acc.append(info.get_metrics(dset, acc_type)['accuracy'])
#acc.append(searchspace.get_accuracy(uid, acc_type, args.trainval))
acc.append(searchspace.get_final_accuracy(uid, acc_type, False))
if not args.dataset == 'cifar10' or args.trainval:
val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
val_acc.append(searchspace.get_final_accuracy(uid, val_acc_type, args.trainval))
# val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
times.append(time.time()-start)
runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")
runs.set_description(f"acc: {mean(acc):.2f}% time:{mean(times):.2f}")
print(f"Final mean test accuracy: {np.mean(acc)}")
if len(val_acc) > 1:
print(f"Final mean validation accuracy: {np.mean(val_acc)}")
#if len(val_acc) > 1:
# print(f"Final mean validation accuracy: {np.mean(val_acc)}")
state = {'accs': acc,
'val_accs': val_acc,
'chosen': chosen,
'times': times,
'topscores': topscores,
}
dset = args.dataset if not args.trainval else 'cifar10-valid'
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
dset = args.dataset if not (args.trainval and args.dataset == 'cifar10') else 'cifar10-valid'
fname = f"{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{dset}_{args.kernel}_{args.dropout}_{args.augtype}_{args.sigma}_{args.repeat}_{args.batch_size}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
torch.save(state, fname)

utils.py Normal file

@@ -0,0 +1,100 @@
import torch
from pycls.models.nas.nas import Cell
class DropChannel(torch.nn.Module):
def __init__(self, p, mod):
super(DropChannel, self).__init__()
self.mod = mod
self.p = p
def forward(self, s0, s1, droppath):
ret = self.mod(s0, s1, droppath)
return ret
class DropConnect(torch.nn.Module):
def __init__(self, p):
super(DropConnect, self).__init__()
self.p = p
def forward(self, inputs):
batch_size = inputs.shape[0]
dim1 = inputs.shape[2]
dim2 = inputs.shape[3]
channel_size = inputs.shape[1]
keep_prob = 1 - self.p
# generate binary_tensor mask according to probability (p for 0, 1-p for 1)
random_tensor = keep_prob
random_tensor += torch.rand([batch_size, channel_size, 1, 1], dtype=inputs.dtype, device=inputs.device)
binary_tensor = torch.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
def add_dropout(network, p, prefix=''):
#p = 0.5
for attr_str in dir(network):
target_attr = getattr(network, attr_str)
if isinstance(target_attr, torch.nn.Conv2d):
setattr(network, attr_str, torch.nn.Sequential(target_attr, DropConnect(p)))
elif isinstance(target_attr, Cell):
setattr(network, attr_str, DropChannel(p, target_attr))
for n, ch in list(network.named_children()):
#print(f'{prefix}add_dropout {n}')
if isinstance(ch, torch.nn.Conv2d):
setattr(network, n, torch.nn.Sequential(ch, DropConnect(p)))
elif isinstance(ch, Cell):
setattr(network, n, DropChannel(p, ch))
else:
add_dropout(ch, p, prefix + '\t')
def orth_init(m):
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.orthogonal_(m.weight)
def uni_init(m):
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.uniform_(m.weight)
def uni2_init(m):
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.uniform_(m.weight, -1., 1.)
def uni3_init(m):
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.uniform_(m.weight, -.5, .5)
def norm_init(m):
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.normal_(m.weight)
def eye_init(m):
if isinstance(m, torch.nn.Linear):
torch.nn.init.eye_(m.weight)
elif isinstance(m, torch.nn.Conv2d):
torch.nn.init.dirac_(m.weight)
def fixup_init(m):
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.zeros_(m.weight)
    elif isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.weight)
        torch.nn.init.zeros_(m.bias)
def init_network(network, init):
if init == 'orthogonal':
network.apply(orth_init)
elif init == 'uniform':
print('uniform')
network.apply(uni_init)
elif init == 'uniform2':
network.apply(uni2_init)
elif init == 'uniform3':
network.apply(uni3_init)
elif init == 'normal':
network.apply(norm_init)
elif init == 'identity':
network.apply(eye_init)
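
A quick numerical check of `DropConnect` above, assuming this file is importable as `utils`: with p = 0.25, roughly a quarter of the per-sample channel slices are zeroed and the survivors are rescaled by 1/(1 - p), so the mean activation is preserved in expectation:

```python
import torch
from utils import DropConnect  # module name as in this commit

torch.manual_seed(0)
x = torch.ones(512, 8, 4, 4)
y = DropConnect(p=0.25)(x)
dropped = (y.view(512 * 8, -1).sum(dim=1) == 0).float().mean()
print(float(dropped))  # ~0.25 of (sample, channel) slices are zeroed
print(float(y.mean())) # ~1.0, since survivors are scaled by 1/0.75
```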


@@ -1,81 +0,0 @@
import re
from graphviz import Digraph
import pandas as pd
import time
import argparse
parser = argparse.ArgumentParser(description='Fast cell visualisation')
parser.add_argument('--arch', default=1, type=int)
parser.add_argument('--save', action='store_true')
args = parser.parse_args()
def set_none(bit):
print(bit)
tmp = bit.split('~')
tmp[0] = 'none'
print('~'.join(tmp))
return '~'.join(tmp)
def remove_pointless_ops(archstr):
old = None
new = archstr
while old != new:
old = new
bits = old.strip('|').split('|')
if 'none~' in bits[0]: # node 1 has no connections to it
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
bits[6] = set_none(bits[6]) # node 1 -> 3 now none
if 'none~' in bits[2] and 'none~' in bits[3]: # node 2 has no connections to it
bits[7] = set_none(bits[7]) # node 2 -> 3 now none
if 'none~' in bits[7]: # doesn't matter what comes through node 2
bits[2] = set_none(bits[2]) # node 0 -> 2 now none
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
if 'none~' in bits[6] and 'none~' in bits[7]: # doesn't matter what comes through node 1
bits[0] = set_none(bits[0]) # node 0 -> 1 now none
new = '|'.join(bits)
print(new)
return new
df = pd.read_pickle('results/arch_score_acc.pd')
nodestr = df.iloc[args.arch]['cellstr']
nodestr = nodestr[1:-1] # remove leading and trailing bars |
nodestr = remove_pointless_ops(nodestr)
nodes = nodestr.split("|+|")
dot = Digraph(
format='pdf',
edge_attr=dict(fontsize='12'),
node_attr=dict(fixedsize='true',shape="circle", height='0.5', width='0.5'),
engine='dot')
dot.body.extend(['rankdir=LR'])
OPS = ['conv_3x3','avg_pool_3x3','skip_connect','conv_1x1','none']
dot.node('0', 'in')
## ops are separated by bars (|) so
for i, node in enumerate(nodes):
# if node 3 then label as output
if (i+1) == 3:
dot.node(str(i+1), 'out')
else:
dot.node(str(i+1))
for op_str in node.split('|'):
op_name = [o for o in OPS if o in op_str][0]
if op_name == 'none':
break
connect = re.findall('~[0-9]', op_str)[0]
connect = connect[1:]
dot.edge(connect,str(i+1), label=op_name)
dot.render( view=True)
if args.save:
dot.render(f'outputs/{args.arch}.gv')