# 2020-07-13 12:04:52 +02:00
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
##############################################################################
# python ./exps/algos-v2/random_wo_share.py --dataset cifar10 --search_space tss
# 2020-07-13 13:35:13 +02:00
# python ./exps/algos-v2/random_wo_share.py --dataset cifar100 --search_space tss
# python ./exps/algos-v2/random_wo_share.py --dataset ImageNet16-120 --search_space tss
# 2020-07-13 12:04:52 +02:00
##############################################################################
import os , sys , time , glob , random , argparse
import numpy as np , collections
from copy import deepcopy
import torch
import torch . nn as nn
from pathlib import Path
# Make the project's lib directory (two levels above this script's folder)
# importable, so the project-local imports below can be resolved.
lib_dir = Path(__file__).parent.joinpath(' .. ', ' .. ', ' lib ').resolve()
if str(lib_dir) not in sys.path:
    # Insert at position 0 so this directory is searched before any other entry.
    sys.path.insert(0, str(lib_dir))
from config_utils import load_config , dict2config , configure2str
from datasets import get_datasets , SearchDataset
from procedures import prepare_seed , prepare_logger , save_checkpoint , copy_checkpoint , get_optim_scheduler
from utils import get_model_infos , obtain_accuracy
from log_utils import AverageMeter , time_string , convert_secs2time
from models import get_search_spaces
# 2020-07-30 15:07:11 +02:00
from nats_bench import create
# 2020-07-13 13:35:13 +02:00
from regularized_ea import random_topology_func , random_size_func
# 2020-07-13 12:04:52 +02:00
def main(xargs, api):
    """Run one random-search trial against a benchmark API until the time budget is spent.

    Architectures are drawn from the search-space-specific sampler
    (``random_topology_func`` for the ``tss`` space, ``random_size_func``
    otherwise) and evaluated with ``api.simulate_train_eval``. The loop keeps
    sampling until the ``total_cost`` reported by the API reaches
    ``xargs.time_budget``. The highest-accuracy architecture seen is logged,
    together with its full info string from the API.

    Args:
      xargs: parsed command-line namespace. Reads ``rand_seed``,
        ``search_space``, ``dataset`` and ``time_budget`` here, and is also
        handed to ``prepare_logger``.
      api: benchmark API object (created via ``nats_bench.create`` in this script).

    Returns:
      ``(log_dir, current_best_index, total_time_cost)``:
        - ``log_dir``: the logger's ``log_dir``.
        - ``current_best_index``: one entry per sample, holding
          ``api.query_index_by_arch`` of the best-so-far architecture.
        - ``total_time_cost``: one entry per sample, holding the
          ``total_cost`` returned by ``simulate_train_eval``.
    """
    torch.set_num_threads(4)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in rather than the module-level ``args`` global,
    # so main() is also correct when called from outside this script's __main__.
    logger = prepare_logger(xargs)

    logger.log(' {:} use api : {:} '.format(time_string(), api))
    # NOTE(review): presumably zeroes the API's accumulated simulated time so the
    # budget check below applies to this run only -- confirm against nats_bench.
    api.reset_time()

    search_space = get_search_spaces(xargs.search_space, ' nas-bench-301 ')
    if xargs.search_space == ' tss ':
        random_arch = random_topology_func(search_space)
    else:
        random_arch = random_size_func(search_space)

    best_arch, best_acc, total_time_cost, history = None, -1, [], []
    current_best_index = []
    # Always take at least one sample; the budget is checked only after each
    # sample is recorded, so the final sample may overshoot the budget.
    while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
        arch = random_arch()
        accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp=' 12 ')
        total_time_cost.append(total_cost)
        history.append(arch)
        if best_arch is None or best_acc < accuracy:
            best_acc, best_arch = accuracy, arch
        logger.log(' [ {:03d} ] : {:} : accuracy = {:.2f} % '.format(len(history), arch, accuracy))
        current_best_index.append(api.query_index_by_arch(best_arch))
    logger.log(' {:} best arch is {:} , accuracy = {:.2f} % , visit {:} archs with {:.1f} s. '.format(
        time_string(), best_arch, best_acc, len(history), total_time_cost[-1]))

    # The second argument differs per search space ('200' for tss, '90' otherwise).
    info = api.query_info_str_by_arch(best_arch, ' 200 ' if xargs.search_space == ' tss ' else ' 90 ')
    logger.log(' {:} '.format(info))
    logger.log(' - ' * 100)
    logger.close()
    return logger.log_dir, current_best_index, total_time_cost
if __name__ == ' __main__ ':
    # Script entry point: parse options, create the benchmark API, and run
    # either a single seeded search or many searches with random seeds.
    parser = argparse.ArgumentParser(" Random NAS ")
    parser.add_argument(' --dataset ', type=str, choices=[' cifar10 ', ' cifar100 ', ' ImageNet16-120 '], help=' Choose between Cifar10/100 and ImageNet-16. ')
    parser.add_argument(' --search_space ', type=str, choices=[' tss ', ' sss '], help=' Choose the search space. ')
    parser.add_argument(' --time_budget ', type=int, default=20000, help=' The total time cost budget for searching (in seconds). ')
    parser.add_argument(' --loops_if_rand ', type=int, default=500, help=' The total runs for evaluation. ')
    # log
    parser.add_argument(' --save_dir ', type=str, default=' ./output/search ', help=' Folder to save checkpoints and log. ')
    parser.add_argument(' --rand_seed ', type=int, default=-1, help=' manual seed ')
    args = parser.parse_args()

    api = create(None, args.search_space, verbose=False)
    args.save_dir = os.path.join(' {:} - {:} '.format(args.save_dir, args.search_space), args.dataset, ' RANDOM ')
    print(' save-dir : {:} '.format(args.save_dir))
    if args.rand_seed < 0:
        # No fixed seed: repeat the search ``loops_if_rand`` times, each with a
        # fresh seed from random.randint(1, 100000), and collect every run's
        # trajectory into a single results file.
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print(' {:} : {:03d} / {:03d} '.format(time_string(), i, args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {' all_archs ': all_archs,
                           ' all_total_times ': all_total_times}
        if save_dir is None:
            # loops_if_rand <= 0: no run happened, so there is no log dir to save into.
            print('no search was run (loops_if_rand = {:}); nothing to save.'.format(args.loops_if_rand))
        else:
            # Path() makes the join work whether main() returned a str or a Path.
            save_path = Path(save_dir) / ' results.pth '
            print(' save into {:} '.format(save_path))
            torch.save(all_info, save_path)
    else:
        main(args, api)