2020-02-23 10:30:37 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  os ,  sys ,  time ,  argparse ,  collections  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  copy  import  deepcopy  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  pathlib  import  Path  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  collections  import  defaultdict  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								lib_dir  =  ( Path ( __file__ ) . parent  /  ' .. '  /  ' .. '  /  ' lib ' ) . resolve ( )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  str ( lib_dir )  not  in  sys . path :  sys . path . insert ( 0 ,  str ( lib_dir ) )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  log_utils     import  AverageMeter ,  time_string ,  convert_secs2time  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  config_utils  import  load_config ,  dict2config  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  datasets      import  get_datasets  
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# NAS-Bench-201 related module or function  
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								from  models        import  CellStructure ,  get_cell_based_tiny_net  
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  nas_201_api   import  ArchResults ,  ResultsCount  
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  procedures    import  bench_pure_evaluate  as  pure_evaluate  
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  create_result_count ( used_seed ,  dataset ,  arch_config ,  results ,  dataloader_dict ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  xresult      =  ResultsCount ( dataset ,  results [ ' net_state_dict ' ] ,  results [ ' train_acc1es ' ] ,  results [ ' train_losses ' ] ,  \
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               results [ ' param ' ] ,  results [ ' flop ' ] ,  arch_config ,  used_seed ,  results [ ' total_epoch ' ] ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  net_config  =  dict2config ( { ' name ' :  ' infer.tiny ' ,  ' C ' :  arch_config [ ' channel ' ] ,  ' N ' :  arch_config [ ' num_cells ' ] ,  ' genotype ' :  CellStructure . str2structure ( arch_config [ ' arch_str ' ] ) ,  ' num_classes ' : arch_config [ ' class_num ' ] } ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  network  =  get_cell_based_tiny_net ( net_config ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  network . load_state_dict ( xresult . get_net_param ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  ' train_times '  in  results :  # new version 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    xresult . update_train_info ( results [ ' train_acc1es ' ] ,  results [ ' train_acc5es ' ] ,  results [ ' train_losses ' ] ,  results [ ' train_times ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    xresult . update_eval ( results [ ' valid_acc1es ' ] ,  results [ ' valid_losses ' ] ,  results [ ' valid_times ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  dataset  ==  ' cifar10-valid ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' x-valid '  ,  results [ ' valid_acc1es ' ] ,  results [ ' valid_losses ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      loss ,  top1 ,  top5 ,  latencies  =  pure_evaluate ( dataloader_dict [ ' {:} @ {:} ' . format ( ' cifar10 ' ,  ' test ' ) ] ,  network . cuda ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' ori-test ' ,  { results [ ' total_epoch ' ] - 1 :  top1 } ,  { results [ ' total_epoch ' ] - 1 :  loss } ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_latency ( latencies ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    elif  dataset  ==  ' cifar10 ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' ori-test ' ,  results [ ' valid_acc1es ' ] ,  results [ ' valid_losses ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      loss ,  top1 ,  top5 ,  latencies  =  pure_evaluate ( dataloader_dict [ ' {:} @ {:} ' . format ( dataset ,  ' test ' ) ] ,  network . cuda ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_latency ( latencies ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    elif  dataset  ==  ' cifar100 '  or  dataset  ==  ' ImageNet16-120 ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' ori-test ' ,  results [ ' valid_acc1es ' ] ,  results [ ' valid_losses ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      loss ,  top1 ,  top5 ,  latencies  =  pure_evaluate ( dataloader_dict [ ' {:} @ {:} ' . format ( dataset ,  ' valid ' ) ] ,  network . cuda ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' x-valid ' ,  { results [ ' total_epoch ' ] - 1 :  top1 } ,  { results [ ' total_epoch ' ] - 1 :  loss } ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      loss ,  top1 ,  top5 ,  latencies  =  pure_evaluate ( dataloader_dict [ ' {:} @ {:} ' . format ( dataset ,   ' test ' ) ] ,  network . cuda ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_OLD_eval ( ' x-test '  ,  { results [ ' total_epoch ' ] - 1 :  top1 } ,  { results [ ' total_epoch ' ] - 1 :  loss } ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult . update_latency ( latencies ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      raise  ValueError ( ' invalid dataset name :  {:} ' . format ( dataset ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  xresult 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  account_one_arch ( arch_index ,  arch_str ,  checkpoints ,  datasets ,  dataloader_dict ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  information  =  ArchResults ( arch_index ,  arch_str ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  checkpoint_path  in  checkpoints : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoint  =  torch . load ( checkpoint_path ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    used_seed   =  checkpoint_path . name . split ( ' - ' ) [ - 1 ] . split ( ' . ' ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  dataset  in  datasets : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  dataset  in  checkpoint ,  ' Can not find  {:}  in arch- {:}  from  {:} ' . format ( dataset ,  arch_index ,  checkpoint_path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      results      =  checkpoint [ dataset ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  results [ ' finish-train ' ] ,  ' This  {:}  arch seed= {:}  does not finish train on  {:}  :::  {:} ' . format ( arch_index ,  used_seed ,  dataset ,  checkpoint_path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_config  =  { ' channel ' :  results [ ' channel ' ] ,  ' num_cells ' :  results [ ' num_cells ' ] ,  ' arch_str ' :  arch_str ,  ' class_num ' :  results [ ' config ' ] [ ' class_num ' ] } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xresult  =  create_result_count ( used_seed ,  dataset ,  arch_config ,  results ,  dataloader_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      information . update ( dataset ,  int ( used_seed ) ,  xresult ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  information 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  GET_DataLoaders ( workers ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  torch . set_num_threads ( workers ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  root_dir   =  ( Path ( __file__ ) . parent  /  ' .. '  /  ' .. ' ) . resolve ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  torch_dir  =  Path ( os . environ [ ' TORCH_HOME ' ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  # cifar 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cifar_config_path  =  root_dir  /  ' configs '  /  ' nas-benchmark '  /  ' CIFAR.config ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cifar_config  =  load_config ( cifar_config_path ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  Create data-loader for all datasets ' . format ( time_string ( ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' - ' * 200 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  TRAIN_CIFAR10 ,  VALID_CIFAR10 ,  xshape ,  class_num  =  get_datasets ( ' cifar10 ' ,  str ( torch_dir / ' cifar.python ' ) ,  - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' original CIFAR-10 :  {:}  training images and  {:}  test images :  {:}  input shape :  {:}  number of classes ' . format ( len ( TRAIN_CIFAR10 ) ,  len ( VALID_CIFAR10 ) ,  xshape ,  class_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cifar10_splits  =  load_config ( root_dir  /  ' configs '  /  ' nas-benchmark '  /  ' cifar-split.txt ' ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  cifar10_splits . train [ : 10 ]  ==  [ 0 ,  5 ,  7 ,  11 ,  13 ,  15 ,  16 ,  17 ,  20 ,  24 ]  and  cifar10_splits . valid [ : 10 ]  ==  [ 1 ,  2 ,  3 ,  4 ,  6 ,  8 ,  9 ,  10 ,  12 ,  14 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  temp_dataset  =  deepcopy ( TRAIN_CIFAR10 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  temp_dataset . transform  =  VALID_CIFAR10 . transform 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  # data loader 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  trainval_cifar10_loader  =  torch . utils . data . DataLoader ( TRAIN_CIFAR10 ,  batch_size = cifar_config . batch_size ,  shuffle = True  ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  train_cifar10_loader     =  torch . utils . data . DataLoader ( TRAIN_CIFAR10 ,  batch_size = cifar_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( cifar10_splits . train ) ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  valid_cifar10_loader     =  torch . utils . data . DataLoader ( temp_dataset  ,  batch_size = cifar_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( cifar10_splits . valid ) ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  test__cifar10_loader     =  torch . utils . data . DataLoader ( VALID_CIFAR10 ,  batch_size = cifar_config . batch_size ,  shuffle = False ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-10  : trval-loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( trainval_cifar10_loader ) ,  cifar_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-10  : train-loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( train_cifar10_loader ) ,  cifar_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-10  : valid-loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( valid_cifar10_loader ) ,  cifar_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-10  : test--loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( test__cifar10_loader ) ,  cifar_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' - ' * 200 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  # CIFAR-100 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  TRAIN_CIFAR100 ,  VALID_CIFAR100 ,  xshape ,  class_num  =  get_datasets ( ' cifar100 ' ,  str ( torch_dir / ' cifar.python ' ) ,  - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' original CIFAR-100:  {:}  training images and  {:}  test images :  {:}  input shape :  {:}  number of classes ' . format ( len ( TRAIN_CIFAR100 ) ,  len ( VALID_CIFAR100 ) ,  xshape ,  class_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cifar100_splits  =  load_config ( root_dir  /  ' configs '  /  ' nas-benchmark '  /  ' cifar100-test-split.txt ' ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  cifar100_splits . xvalid [ : 10 ]  ==  [ 1 ,  3 ,  4 ,  5 ,  8 ,  10 ,  13 ,  14 ,  15 ,  16 ]  and  cifar100_splits . xtest [ : 10 ]  ==  [ 0 ,  2 ,  6 ,  7 ,  9 ,  11 ,  12 ,  17 ,  20 ,  24 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  train_cifar100_loader  =  torch . utils . data . DataLoader ( TRAIN_CIFAR100 ,  batch_size = cifar_config . batch_size ,  shuffle = True ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  valid_cifar100_loader  =  torch . utils . data . DataLoader ( VALID_CIFAR100 ,  batch_size = cifar_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( cifar100_splits . xvalid ) ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  test__cifar100_loader  =  torch . utils . data . DataLoader ( VALID_CIFAR100 ,  batch_size = cifar_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( cifar100_splits . xtest )  ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-100  : train-loader has  {:3d}  batch ' . format ( len ( train_cifar100_loader ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-100  : valid-loader has  {:3d}  batch ' . format ( len ( valid_cifar100_loader ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' CIFAR-100  : test--loader has  {:3d}  batch ' . format ( len ( test__cifar100_loader ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' - ' * 200 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  imagenet16_config_path  =  ' configs/nas-benchmark/ImageNet-16.config ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  imagenet16_config  =  load_config ( imagenet16_config_path ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  TRAIN_ImageNet16_120 ,  VALID_ImageNet16_120 ,  xshape ,  class_num  =  get_datasets ( ' ImageNet16-120 ' ,  str ( torch_dir / ' cifar.python ' / ' ImageNet16 ' ) ,  - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' original TRAIN_ImageNet16_120:  {:}  training images and  {:}  test images :  {:}  input shape :  {:}  number of classes ' . format ( len ( TRAIN_ImageNet16_120 ) ,  len ( VALID_ImageNet16_120 ) ,  xshape ,  class_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  imagenet_splits  =  load_config ( root_dir  /  ' configs '  /  ' nas-benchmark '  /  ' imagenet-16-120-test-split.txt ' ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  imagenet_splits . xvalid [ : 10 ]  ==  [ 1 ,  2 ,  3 ,  6 ,  7 ,  8 ,  9 ,  12 ,  16 ,  18 ]  and  imagenet_splits . xtest [ : 10 ]  ==  [ 0 ,  4 ,  5 ,  10 ,  11 ,  13 ,  14 ,  15 ,  17 ,  20 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  train_imagenet_loader  =  torch . utils . data . DataLoader ( TRAIN_ImageNet16_120 ,  batch_size = imagenet16_config . batch_size ,  shuffle = True ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  valid_imagenet_loader  =  torch . utils . data . DataLoader ( VALID_ImageNet16_120 ,  batch_size = imagenet16_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( imagenet_splits . xvalid ) ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  test__imagenet_loader  =  torch . utils . data . DataLoader ( VALID_ImageNet16_120 ,  batch_size = imagenet16_config . batch_size ,  sampler = torch . utils . data . sampler . SubsetRandomSampler ( imagenet_splits . xtest )  ,  num_workers = workers ,  pin_memory = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' ImageNet-16-120  : train-loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( train_imagenet_loader ) ,  imagenet16_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' ImageNet-16-120  : valid-loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( valid_imagenet_loader ) ,  imagenet16_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' ImageNet-16-120  : test--loader has  {:3d}  batch with  {:}  per batch ' . format ( len ( test__imagenet_loader ) ,  imagenet16_config . batch_size ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  # 'cifar10', 'cifar100', 'ImageNet16-120' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  loaders  =  { ' cifar10@trainval ' :  trainval_cifar10_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar10@train '    :  train_cifar10_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar10@valid '    :  valid_cifar10_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar10@test '     :  test__cifar10_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar100@train '   :  train_cifar100_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar100@valid '   :  valid_cifar100_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' cifar100@test '    :  test__cifar100_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' ImageNet16-120@train ' :  train_imagenet_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' ImageNet16-120@valid ' :  valid_imagenet_loader , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								             ' ImageNet16-120@test '  :  test__imagenet_loader } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  loaders 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  simplify ( save_dir ,  meta_file ,  basestr ,  target_dir ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_infos      =  torch . load ( meta_file ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_archs      =  meta_infos [ ' archs ' ]  # a list of architecture strings 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_num_archs  =  meta_infos [ ' total ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_max_node   =  meta_infos [ ' max_node ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_num_archs  ==  len ( meta_archs ) ,  ' invalid number of archs :  {:}  vs  {:} ' . format ( meta_num_archs ,  len ( meta_archs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sub_model_dirs  =  sorted ( list ( save_dir . glob ( ' *-*- {:} ' . format ( basestr ) ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  find  {:}  directories used to save checkpoints ' . format ( time_string ( ) ,  len ( sub_model_dirs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  subdir2archs ,  num_evaluated_arch  =  collections . OrderedDict ( ) ,  0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  num_seeds  =  defaultdict ( lambda :  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  index ,  sub_dir  in  enumerate ( sub_model_dirs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    xcheckpoints  =  list ( sub_dir . glob ( ' arch-*-seed-*.pth ' ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_indexes  =  set ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  checkpoint  in  xcheckpoints : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      temp_names  =  checkpoint . name . split ( ' - ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  len ( temp_names )  ==  4  and  temp_names [ 0 ]  ==  ' arch '  and  temp_names [ 2 ]  ==  ' seed ' ,  ' invalid checkpoint name :  {:} ' . format ( checkpoint . name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_indexes . add (  temp_names [ 1 ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    subdir2archs [ sub_dir ]  =  sorted ( list ( arch_indexes ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    num_evaluated_arch    + =  len ( arch_indexes ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # count number of seeds for each architecture 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  arch_index  in  arch_indexes : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      num_seeds [  len ( list ( sub_dir . glob ( ' arch- {:} -seed-*.pth ' . format ( arch_index ) ) ) )  ]  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' {:}  There are  {:5d}  architectures that have been evaluated ( {:}  in total). ' . format ( time_string ( ) ,  num_evaluated_arch ,  meta_num_archs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  key  in  sorted (  list (  num_seeds . keys ( )  )  ) :  print  ( ' {:}  There are  {:5d}  architectures that are evaluated  {:}  times. ' . format ( time_string ( ) ,  num_seeds [ key ] ,  key ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  dataloader_dict  =  GET_DataLoaders (  6  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  to_save_simply  =  save_dir  /  ' simplifies ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  to_save_allarc  =  save_dir  /  ' simplifies '  /  ' architectures ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  not  to_save_simply . exists ( ) :  to_save_simply . mkdir ( parents = True ,  exist_ok = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  not  to_save_allarc . exists ( ) :  to_save_allarc . mkdir ( parents = True ,  exist_ok = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  ( save_dir  /  target_dir )  in  subdir2archs ,  ' can not find  {:} ' . format ( target_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  arch2infos ,  datasets  =  { } ,  ( ' cifar10-valid ' ,  ' cifar10 ' ,  ' cifar100 ' ,  ' ImageNet16-120 ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  evaluated_indexes     =  set ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  target_directory      =  save_dir  /  target_dir 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  target_less_dir       =  save_dir  /  ' {:} -LESS ' . format ( target_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  arch_indexes          =  subdir2archs [  target_directory  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  num_seeds             =  defaultdict ( lambda :  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  end_time              =  time . time ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  arch_time             =  AverageMeter ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  idx ,  arch_index  in  enumerate ( arch_indexes ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoints  =  list ( target_directory . glob ( ' arch- {:} -seed-*.pth ' . format ( arch_index ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ckps_less    =  list ( target_less_dir . glob ( ' arch- {:} -seed-*.pth ' . format ( arch_index ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # create the arch info for each architecture 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_info_full  =  account_one_arch ( arch_index ,  meta_archs [ int ( arch_index ) ] ,  checkpoints ,  datasets ,  dataloader_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_info_less  =  account_one_arch ( arch_index ,  meta_archs [ int ( arch_index ) ] ,  ckps_less ,  [ ' cifar10-valid ' ] ,  dataloader_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      num_seeds [  len ( checkpoints )  ]  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    except : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      print ( ' Loading  {:}  failed, :  {:} ' . format ( arch_index ,  checkpoints ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    assert  int ( arch_index )  not  in  evaluated_indexes ,  ' conflict arch-index :  {:} ' . format ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    assert  0  < =  int ( arch_index )  <  len ( meta_archs ) ,  ' invalid arch-index  {:}  (not found in meta_archs) ' . format ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_info  =  { ' full ' :  arch_info_full ,  ' less ' :  arch_info_less } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    evaluated_indexes . add (  int ( arch_index )  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch2infos [ int ( arch_index ) ]  =  arch_info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    torch . save ( { ' full ' :  arch_info_full . state_dict ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                ' less ' :  arch_info_less . state_dict ( ) } ,  to_save_allarc  /  ' {:} -FULL.pth ' . format ( arch_index ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_info [ ' full ' ] . clear_params ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_info [ ' less ' ] . clear_params ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    torch . save ( { ' full ' :  arch_info_full . state_dict ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                ' less ' :  arch_info_less . state_dict ( ) } ,  to_save_allarc  /  ' {:} -SIMPLE.pth ' . format ( arch_index ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # measure elapsed time 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_time . update ( time . time ( )  -  end_time ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    end_time   =  time . time ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    need_time  =  ' {:} ' . format (  convert_secs2time ( arch_time . avg  *  ( len ( arch_indexes ) - idx - 1 ) ,  True )  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( ' {:}   {:}  [ {:03d} / {:03d} ] :  {:}  still need  {:} ' . format ( time_string ( ) ,  target_dir ,  idx ,  len ( arch_indexes ) ,  arch_index ,  need_time ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  # measure time 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  xstrs  =  [ ' {:} : {:03d} ' . format ( key ,  num_seeds [ key ] )  for  key  in  sorted (  list (  num_seeds . keys ( )  )  )  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' {:}   {:}  done :  {:} ' . format ( time_string ( ) ,  target_dir ,  xstrs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  final_infos  =  { ' meta_archs '  :  meta_archs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' total_archs ' :  meta_num_archs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' basestr '     :  basestr , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' arch2infos '  :  arch2infos , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' evaluated_indexes ' :  evaluated_indexes } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  save_file_name  =  to_save_simply  /  ' {:} .pth ' . format ( target_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  torch . save ( final_infos ,  save_file_name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' Save  {:}  /  {:}  architecture results into  {:} . ' . format ( len ( evaluated_indexes ) ,  meta_num_archs ,  save_file_name ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  merge_all ( save_dir ,  meta_file ,  basestr ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_infos      =  torch . load ( meta_file ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_archs      =  meta_infos [ ' archs ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_num_archs  =  meta_infos [ ' total ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_max_node   =  meta_infos [ ' max_node ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_num_archs  ==  len ( meta_archs ) ,  ' invalid number of archs :  {:}  vs  {:} ' . format ( meta_num_archs ,  len ( meta_archs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sub_model_dirs  =  sorted ( list ( save_dir . glob ( ' *-*- {:} ' . format ( basestr ) ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  find  {:}  directories used to save checkpoints ' . format ( time_string ( ) ,  len ( sub_model_dirs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  index ,  sub_dir  in  enumerate ( sub_model_dirs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_info_files  =  sorted (  list ( sub_dir . glob ( ' arch-*-seed-*.pth ' )  )  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print  ( ' The  {:02d} / {:02d} -th directory :  {:}  :  {:}  runs. ' . format ( index ,  len ( sub_model_dirs ) ,  sub_dir ,  len ( arch_info_files ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  arch2infos ,  evaluated_indexes  =  dict ( ) ,  set ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  IDX ,  sub_dir  in  enumerate ( sub_model_dirs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ckp_path  =  sub_dir . parent  /  ' simplifies '  /  ' {:} .pth ' . format ( sub_dir . name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ckp_path . exists ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      sub_ckps  =  torch . load ( ckp_path ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  sub_ckps [ ' total_archs ' ]  ==  meta_num_archs  and  sub_ckps [ ' basestr ' ]  ==  basestr 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xarch2infos  =  sub_ckps [ ' arch2infos ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      xevalindexs  =  sub_ckps [ ' evaluated_indexes ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      for  eval_index  in  xevalindexs : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        assert  eval_index  not  in  evaluated_indexes  and  eval_index  not  in  arch2infos 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #arch2infos[eval_index] = xarch2infos[eval_index].state_dict() 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        arch2infos [ eval_index ]  =  { ' full ' :  xarch2infos [ eval_index ] [ ' full ' ] . state_dict ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                  ' less ' :  xarch2infos [ eval_index ] [ ' less ' ] . state_dict ( ) } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        evaluated_indexes . add (  eval_index  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      print  ( ' {:}  [ {:03d} / {:03d} ] merge data from  {:}  with  {:}  models. ' . format ( time_string ( ) ,  IDX ,  len ( sub_model_dirs ) ,  ckp_path ,  len ( xevalindexs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      raise  ValueError ( ' Can not find  {:} ' . format ( ckp_path ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      #print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  evaluated_indexes  =  sorted (  list (  evaluated_indexes  )  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' Finally, there are  {:}  architectures that have been trained and evaluated. ' . format ( len ( evaluated_indexes ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  to_save_simply  =  save_dir  /  ' simplifies ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  not  to_save_simply . exists ( ) :  to_save_simply . mkdir ( parents = True ,  exist_ok = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  final_infos  =  { ' meta_archs '  :  meta_archs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' total_archs ' :  meta_num_archs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' arch2infos '  :  arch2infos , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                 ' evaluated_indexes ' :  evaluated_indexes } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  save_file_name  =  to_save_simply  /  ' {:} -final-infos.pth ' . format ( basestr ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  torch . save ( final_infos ,  save_file_name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' Save  {:}  /  {:}  architecture results into  {:} . ' . format ( len ( evaluated_indexes ) ,  meta_num_archs ,  save_file_name ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  __name__  ==  ' __main__ ' :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  parser  =  argparse . ArgumentParser ( description = ' NAS-BENCH-201 ' ,  formatter_class = argparse . ArgumentDefaultsHelpFormatter ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --mode '          ,   type = str ,  choices = [ ' cal ' ,  ' merge ' ] ,             help = ' The running mode for this script. ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  parser . add_argument ( ' --base_save_dir ' ,   type = str ,  default = ' ./output/NAS-BENCH-201-4 ' ,   help = ' The base-name of folder to save checkpoints and log. ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --target_dir '    ,   type = str ,                                       help = ' The target directory. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --max_node '      ,   type = int ,  default = 4 ,                            help = ' The maximum node in a cell. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --channel '       ,   type = int ,  default = 16 ,                           help = ' The number of channels. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --num_cells '     ,   type = int ,  default = 5 ,                            help = ' The number of cells in one stage. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  args  =  parser . parse_args ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  save_dir   =  Path ( args . base_save_dir ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  meta_path  =  save_dir  /  ' meta-node- {:} .pth ' . format ( args . max_node ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  save_dir . exists ( ) ,   ' invalid save dir path :  {:} ' . format ( save_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_path . exists ( ) ,  ' invalid saved meta path :  {:} ' . format ( meta_path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' start the statistics of our nas-benchmark from  {:}  using  {:} . ' . format ( save_dir ,  args . target_dir ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  basestr    =  ' C {:} -N {:} ' . format ( args . channel ,  args . num_cells ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  args . mode  ==  ' cal ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    simplify ( save_dir ,  meta_path ,  basestr ,  args . target_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  elif  args . mode  ==  ' merge ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    merge_all ( save_dir ,  meta_path ,  basestr ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  else : 
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    raise  ValueError ( ' invalid mode :  {:} ' . format ( args . mode ) )