##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
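# Random search baseline: repeatedly sample cell architectures from the chosen
# search space and keep the best-performing one found within a given time budget.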
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets     import get_datasets, SearchDataset
from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils        import get_model_infos, obtain_accuracy
from log_utils    import AverageMeter, time_string, convert_secs2time
from models       import get_search_spaces
from nas_201_api  import NASBench201API as API
from R_EA         import train_and_eval, random_architecture_func


def main(xargs, nas_bench):
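  # xargs holds the parsed command-line arguments; nas_bench is a NASBench201API
  # instance, or None when no benchmark file is available (see the __main__ block below).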
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled       = True
  torch.backends.cudnn.benchmark     = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads(xargs.workers)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

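  # NAS-Bench-201 records CIFAR-10 results for the train/validation protocol under
  # the name 'cifar10-valid'; other datasets are queried by their own names.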
  if xargs.dataset == 'cifar10':
    dataname = 'cifar10-valid'
  else:
    dataname = xargs.dataset
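  # With a dataset path, build the train/valid split loaders and pass them to
  # train_and_eval via extra_info; otherwise only the algorithm config is loaded
  # and the loaders stay None.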
  if xargs.data_path is not None:
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    # to split the data: reuse the training images (with validation transforms) as the validation set
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data    = train_data_v2
    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # data loaders over the disjoint train/valid index splits
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
  else:
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, None, logger)
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch  = random_architecture_func(xargs.max_nodes, search_space)
  #x = random_arch() ; y = mutate_arch(x)
  x_start_time = time.time()
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
  #for idx in range(xargs.random_num):
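  # Random search loop: sample architectures uniformly at random, evaluate each one
  # (via the benchmark when available), and stop once the accumulated training-time
  # cost exceeds the budget.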
  while total_time_cost < xargs.time_budget:
    arch = random_arch()
    accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
    if total_time_cost + cost_time > xargs.time_budget: break
    else: total_time_cost += cost_time
    history.append(arch)
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))

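  # Query the benchmark for the recorded performance of the best architecture found.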
  info = nas_bench.query_by_arch(best_arch)
  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
  else           : logger.log('{:}'.format(info))
  logger.log('-' * 100)
  logger.close()
  return logger.log_dir, nas_bench.query_index_by_arch(best_arch)


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
  parser.add_argument('--data_path',          type=str,  help='Path to dataset')
  parser.add_argument('--dataset',            type=str,  choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name',  type=str,  help='The search space name.')
  parser.add_argument('--max_nodes',          type=int,  help='The maximum number of nodes.')
  parser.add_argument('--channel',            type=int,  help='The number of channels.')
  parser.add_argument('--num_cells',          type=int,  help='The number of cells in one stage.')
  #parser.add_argument('--random_num',        type=int,  help='The number of random selected architectures.')
  parser.add_argument('--time_budget',        type=int,  help='The total time cost budget for searching (in seconds).')
  # log
  parser.add_argument('--workers',            type=int,  default=2,  help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir',           type=str,  help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset',   type=str,  help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq',         type=int,  help='print frequency (default: 200)')
  parser.add_argument('--rand_seed',          type=int,  help='manual seed')
  args = parser.parse_args()
  #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
    nas_bench = None
  else:
    print('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
    nas_bench = API(args.arch_nas_dataset)
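  # A negative (or missing) seed triggers 500 independent runs with fresh random seeds;
  # the benchmark index of each run's best architecture is collected and saved
  # alongside the logs of the last run.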
  if args.rand_seed is None or args.rand_seed < 0:
    save_dir, all_indexes, num = None, [], 500
    for i in range(num):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
      args.rand_seed = random.randint(1, 100000)
      save_dir, index = main(args, nas_bench)
      all_indexes.append(index)
    torch.save(all_indexes, save_dir / 'results.pth')
  else:
    main(args, nas_bench)
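
# A hypothetical invocation (flag values below are illustrative placeholders, not
# documented defaults):
#   python <this_script>.py --dataset cifar10 --data_path <path/to/cifar-data> \
#       --search_space_name <search-space> --max_nodes 4 --time_budget 12000 \
#       --arch_nas_dataset <path/to/NAS-Bench-201.pth> --save_dir <output-dir> --rand_seed -1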