update README
This commit is contained in:
		@@ -26,9 +26,10 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default).
 | 
			
		||||
 | 
			
		||||
1. Creating an API instance from a file:
 | 
			
		||||
```
 | 
			
		||||
from nas_102_api import NASBench102API
 | 
			
		||||
api = NASBench102API('$path_to_meta_nas_bench_file')
 | 
			
		||||
api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')
 | 
			
		||||
from nas_102_api import NASBench102API as API
 | 
			
		||||
api = API('$path_to_meta_nas_bench_file')
 | 
			
		||||
api = API('NAS-Bench-102-v1_0-e61699.pth')
 | 
			
		||||
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
 | 
			
		||||
@@ -45,12 +46,12 @@ api.show(1)
 | 
			
		||||
api.show(2)
 | 
			
		||||
 | 
			
		||||
# show the mean loss and accuracy of an architecture
 | 
			
		||||
info = api.query_meta_info_by_index(1)
 | 
			
		||||
res_metrics = info.get_metrics('cifar10', 'train')
 | 
			
		||||
cost_metrics = info.get_comput_costs('cifar100')
 | 
			
		||||
info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults`
 | 
			
		||||
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
 | 
			
		||||
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
 | 
			
		||||
 | 
			
		||||
# get the detailed information
 | 
			
		||||
results = api.query_by_index(1, 'cifar100')
 | 
			
		||||
results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
 | 
			
		||||
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
 | 
			
		||||
print ('Latency : {:}'.format(results[0].get_latency()))
 | 
			
		||||
print ('Train Info : {:}'.format(results[0].get_train()))
 | 
			
		||||
 
 | 
			
		||||
@@ -35,6 +35,8 @@ We build a new benchmark for neural architecture search, please see more details
 | 
			
		||||
The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs).
 | 
			
		||||
 | 
			
		||||
## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
 | 
			
		||||
[Papers with Code: network pruning on CIFAR-100](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable)
 | 
			
		||||
 | 
			
		||||
In this paper, we propose a differentiable search strategy for transformable architectures, i.e., searching for both the depth and the width of a deep neural network.
 | 
			
		||||
You can find the highlights of our Transformable Architecture Search (TAS) on our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@
 | 
			
		||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | 
			
		||||
##################################################
 | 
			
		||||
# required to install hpbandster #################
 | 
			
		||||
# bash ./scripts-search/algos/BOHB.sh -1         #
 | 
			
		||||
##################################################
 | 
			
		||||
import os, sys, time, glob, random, argparse
 | 
			
		||||
import numpy as np, collections
 | 
			
		||||
@@ -19,7 +20,6 @@ from utils        import get_model_infos, obtain_accuracy
 | 
			
		||||
from log_utils    import AverageMeter, time_string, convert_secs2time
 | 
			
		||||
from nas_102_api  import NASBench102API as API
 | 
			
		||||
from models       import CellStructure, get_search_spaces
 | 
			
		||||
from R_EA import train_and_eval
 | 
			
		||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
 | 
			
		||||
import ConfigSpace
 | 
			
		||||
from hpbandster.optimizers.bohb import BOHB
 | 
			
		||||
@@ -53,21 +53,44 @@ def config2structure_func(max_nodes):
 | 
			
		||||
 | 
			
		||||
class MyWorker(Worker):
 | 
			
		||||
 | 
			
		||||
  def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs):
 | 
			
		||||
  def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
 | 
			
		||||
    super().__init__(*args, **kwargs)
 | 
			
		||||
    self.sleep_interval = sleep_interval
 | 
			
		||||
    self.convert_func   = convert_func
 | 
			
		||||
    self.nas_bench      = nas_bench
 | 
			
		||||
    self.test_time      = 0
 | 
			
		||||
    self.time_scale     = time_scale
 | 
			
		||||
    self.seen_arch      = 0
 | 
			
		||||
    self.sim_cost_time  = 0
 | 
			
		||||
    self.real_cost_time = 0
 | 
			
		||||
 | 
			
		||||
  def compute(self, config, budget, **kwargs):
 | 
			
		||||
    structure = self.convert_func( config )
 | 
			
		||||
    reward, time_cost = train_and_eval(structure, self.nas_bench, None)
 | 
			
		||||
    import pdb; pdb.set_trace()
 | 
			
		||||
    self.test_time += 1
 | 
			
		||||
    start_time = time.time()
 | 
			
		||||
    structure  = self.convert_func( config )
 | 
			
		||||
    arch_index = self.nas_bench.query_index_by_arch( structure )
 | 
			
		||||
    iepoch     = 0
 | 
			
		||||
    while iepoch < 12:
 | 
			
		||||
      info     = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
 | 
			
		||||
      cur_time = info['train-all-time'] + info['valid-per-time']
 | 
			
		||||
      cur_vacc = info['valid-accuracy']
 | 
			
		||||
      if time.time() - start_time + cur_time / self.time_scale > budget:
 | 
			
		||||
        break
 | 
			
		||||
      else:
 | 
			
		||||
        iepoch += 1
 | 
			
		||||
    self.sim_cost_time += cur_time
 | 
			
		||||
    self.seen_arch += 1
 | 
			
		||||
    remaining_time = cur_time / self.time_scale - (time.time() - start_time)
 | 
			
		||||
    if remaining_time > 0:
 | 
			
		||||
      time.sleep(remaining_time)
 | 
			
		||||
    else:
 | 
			
		||||
      import pdb; pdb.set_trace()
 | 
			
		||||
    self.real_cost_time += (time.time() - start_time)
 | 
			
		||||
    return ({
 | 
			
		||||
            'loss': float(100-reward),
 | 
			
		||||
            'info': time_cost})
 | 
			
		||||
            'loss': 100 - float(cur_vacc),
 | 
			
		||||
            'info': {'seen-arch'     : self.seen_arch,
 | 
			
		||||
                     'sim-test-time' : self.sim_cost_time,
 | 
			
		||||
                     'real-test-time': self.real_cost_time,
 | 
			
		||||
                     'current-arch'  : arch_index,
 | 
			
		||||
                     'current-budget': budget}
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(xargs, nas_bench):
 | 
			
		||||
@@ -116,26 +139,30 @@ def main(xargs, nas_bench):
 | 
			
		||||
  #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
 | 
			
		||||
  workers = []
 | 
			
		||||
  for i in range(num_workers):
 | 
			
		||||
    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i)
 | 
			
		||||
    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
 | 
			
		||||
    w.run(background=True)
 | 
			
		||||
    workers.append(w)
 | 
			
		||||
 | 
			
		||||
  simulate_time_budge = xargs.time_budget // xargs.time_scale
 | 
			
		||||
  start_time = time.time()
 | 
			
		||||
  logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
 | 
			
		||||
  bohb = BOHB(configspace=cs,
 | 
			
		||||
            run_id=hb_run_id,
 | 
			
		||||
            eta=3, min_budget=3, max_budget=xargs.time_budget,
 | 
			
		||||
            eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
 | 
			
		||||
            nameserver=ns_host,
 | 
			
		||||
            nameserver_port=ns_port,
 | 
			
		||||
            num_samples=xargs.num_samples,
 | 
			
		||||
            random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
 | 
			
		||||
            ping_interval=10, min_bandwidth=xargs.min_bandwidth)
 | 
			
		||||
  #          optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
 | 
			
		||||
  
 | 
			
		||||
  results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
 | 
			
		||||
  import pdb; pdb.set_trace()
 | 
			
		||||
 | 
			
		||||
  bohb.shutdown(shutdown_workers=True)
 | 
			
		||||
  NS.shutdown()
 | 
			
		||||
 | 
			
		||||
  real_cost_time = time.time() - start_time
 | 
			
		||||
  import pdb; pdb.set_trace()
 | 
			
		||||
 | 
			
		||||
  id2config = results.get_id2config_mapping()
 | 
			
		||||
  incumbent = results.get_incumbent_id()
 | 
			
		||||
 | 
			
		||||
@@ -163,6 +190,7 @@ if __name__ == '__main__':
 | 
			
		||||
  parser.add_argument('--channel',            type=int,   help='The number of channels.')
 | 
			
		||||
  parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
 | 
			
		||||
  parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).')
 | 
			
		||||
  parser.add_argument('--time_scale' ,        type=int,   help='The time scale to accelerate the time budget.')
 | 
			
		||||
  # BOHB
 | 
			
		||||
  parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
 | 
			
		||||
  parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
 | 
			
		||||
 
 | 
			
		||||
@@ -59,7 +59,7 @@ def train_and_eval(arch, nas_bench, extra_info):
 | 
			
		||||
  if nas_bench is not None:
 | 
			
		||||
    arch_index = nas_bench.query_index_by_arch( arch )
 | 
			
		||||
    assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
 | 
			
		||||
    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
 | 
			
		||||
    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
 | 
			
		||||
    valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
 | 
			
		||||
    #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
 | 
			
		||||
  else:
 | 
			
		||||
 
 | 
			
		||||
@@ -147,14 +147,14 @@ class NASBench102API(object):
 | 
			
		||||
    archresult = arch2infos[index]
 | 
			
		||||
    return archresult.get_net_param(dataset, seed)
 | 
			
		||||
 | 
			
		||||
  def get_more_info(self, index, dataset, use_12epochs_result=False):
 | 
			
		||||
  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
 | 
			
		||||
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
 | 
			
		||||
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
 | 
			
		||||
    archresult = arch2infos[index]
 | 
			
		||||
    if dataset == 'cifar10-valid':
 | 
			
		||||
      train_info = archresult.get_metrics(dataset, 'train', is_random=True)
 | 
			
		||||
      valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
 | 
			
		||||
      test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
 | 
			
		||||
      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True)
 | 
			
		||||
      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
 | 
			
		||||
      test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
 | 
			
		||||
      total      = train_info['iepoch'] + 1
 | 
			
		||||
      return {'train-loss'    : train_info['loss'],
 | 
			
		||||
              'train-accuracy': train_info['accuracy'],
 | 
			
		||||
 
 | 
			
		||||
@@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 | 
			
		||||
	--dataset ${dataset} --data_path ${data_path} \
 | 
			
		||||
	--search_space_name ${space} \
 | 
			
		||||
	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 | 
			
		||||
	--time_budget 12000 \
 | 
			
		||||
	--n_iters 100 --num_samples 4 --random_fraction 0 \
 | 
			
		||||
	--time_budget 12000 --time_scale 200 \
 | 
			
		||||
	--n_iters 64 --num_samples 4 --random_fraction 0 \
 | 
			
		||||
	--workers 4 --print_freq 200 --rand_seed ${seed}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user