2020-02-23 10:30:37 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								########################################################  
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								########################################################  
						 
					
						
							
								
									
										
										
										
											2020-03-13 14:00:54 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  sys ,  argparse  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  numpy  as  np  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  copy  import  deepcopy  
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  tqdm  import  tqdm  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  pathlib  import  Path  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								lib_dir  =  ( Path ( __file__ ) . parent  /  ' .. '  /  ' .. '  /  ' lib ' ) . resolve ( )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  str ( lib_dir )  not  in  sys . path :  sys . path . insert ( 0 ,  str ( lib_dir ) )  
						 
					
						
							
								
									
										
										
										
											2020-03-13 14:00:54 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  log_utils     import  time_string  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  models        import  CellStructure  
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  nas_201_api   import  NASBench201API  as  API  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  check_unique_arch ( meta_file ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  api  =  API ( str ( meta_file ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  arch_strs  =  deepcopy ( api . meta_archs ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  xarchs  =  [ CellStructure . str2structure ( x )  for  x  in  arch_strs ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  def  get_unique_matrix ( archs ,  consider_zero ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    UniquStrs  =  [ arch . to_unique_str ( consider_zero )  for  arch  in  archs ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print  ( ' {:}  create unique-string ( {:} / {:} ) done ' . format ( time_string ( ) ,  len ( set ( UniquStrs ) ) ,  len ( UniquStrs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Unique2Index  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  index ,  xstr  in  enumerate ( UniquStrs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      if  xstr  not  in  Unique2Index :  Unique2Index [ xstr ]  =  list ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      Unique2Index [ xstr ] . append (  index  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sm_matrix  =  torch . eye ( len ( archs ) ) . bool ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  _ ,  xlist  in  Unique2Index . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      for  i  in  xlist : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  j  in  xlist : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          sm_matrix [ i , j ]  =  True 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    unique_ids ,  unique_num  =  [ - 1  for  _  in  archs ] ,  0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  i  in  range ( len ( unique_ids ) ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      if  unique_ids [ i ]  >  - 1 :  continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      neighbours  =  sm_matrix [ i ] . nonzero ( ) . view ( - 1 ) . tolist ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      for  nghb  in  neighbours : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        assert  unique_ids [ nghb ]  ==  - 1 ,  ' impossible ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        unique_ids [ nghb ]  =  unique_num 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      unique_num  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  sm_matrix ,  unique_ids ,  unique_num 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' There are  {:}  valid-archs ' . format (  sum ( arch . check_valid ( )  for  arch  in  xarchs )  ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sm_matrix ,  uniqueIDs ,  unique_num  =  get_unique_matrix ( xarchs ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  There are  {:}  unique architectures (considering nothing). ' . format ( time_string ( ) ,  unique_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sm_matrix ,  uniqueIDs ,  unique_num  =  get_unique_matrix ( xarchs ,  False ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  There are  {:}  unique architectures (not considering zero). ' . format ( time_string ( ) ,  unique_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sm_matrix ,  uniqueIDs ,  unique_num  =  get_unique_matrix ( xarchs ,   True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  There are  {:}  unique architectures (considering zero). ' . format ( time_string ( ) ,  unique_num ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  check_cor_for_bandit ( meta_file ,  test_epoch ,  use_less_or_not ,  is_rand = True ,  need_print = False ) :  
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  if  isinstance ( meta_file ,  API ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    api  =  meta_file 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    api  =  API ( str ( meta_file ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-02 16:49:16 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  cifar10_currs      =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  cifar10_valid      =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cifar10_test       =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  cifar100_valid     =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  cifar100_test      =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  imagenet_test      =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  imagenet_valid     =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  for  idx ,  arch  in  enumerate ( api ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    results  =  api . get_more_info ( idx ,  ' cifar10-valid '  ,  test_epoch - 1 ,  use_less_or_not ,  is_rand ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-02 16:49:16 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    cifar10_currs . append (  results [ ' valid-accuracy ' ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # --->>>>> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    results  =  api . get_more_info ( idx ,  ' cifar10-valid '  ,  None ,  False ,  is_rand ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    cifar10_valid . append (  results [ ' valid-accuracy ' ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    results  =  api . get_more_info ( idx ,  ' cifar10 '        ,  None ,  False ,  is_rand ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cifar10_test . append (  results [ ' test-accuracy ' ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    results  =  api . get_more_info ( idx ,  ' cifar100 '       ,  None ,  False ,  is_rand ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cifar100_test . append (  results [ ' test-accuracy ' ]  ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    cifar100_valid . append (  results [ ' valid-accuracy ' ]  ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    results  =  api . get_more_info ( idx ,  ' ImageNet16-120 ' ,  None ,  False ,  is_rand ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    imagenet_test . append (  results [ ' test-accuracy ' ]  ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    imagenet_valid . append (  results [ ' valid-accuracy ' ]  ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  def  get_cor ( A ,  B ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  float ( np . corrcoef ( A ,  B ) [ 0 , 1 ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cors  =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-02 16:49:16 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  for  basestr ,  xlist  in  zip ( [ ' C-010-V ' ,  ' C-010-T ' ,  ' C-100-V ' ,  ' C-100-T ' ,  ' I16-V ' ,  ' I16-T ' ] ,  [ cifar10_valid ,  cifar10_test ,  cifar100_valid ,  cifar100_test ,  imagenet_valid ,  imagenet_test ] ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    correlation  =  get_cor ( cifar10_currs ,  xlist ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  need_print :  print  ( ' With  {:3d} / {:} -epochs-training, the correlation between cifar10-valid and  {:}  is :  {:} ' . format ( test_epoch ,  ' 012 '  if  use_less_or_not  else  ' 200 ' ,  basestr ,  correlation ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    cors . append (  correlation  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    #print('-'*200) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  #print('*'*230) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  cors 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  check_cor_for_bandit_v2 ( meta_file ,  test_epoch ,  use_less_or_not ,  is_rand ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  corrs  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  i  in  tqdm ( range ( 100 ) ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    x  =  check_cor_for_bandit ( meta_file ,  test_epoch ,  use_less_or_not ,  is_rand ,  False ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    corrs . append (  x  ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-02 16:49:16 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  xstrs  =  [ ' C-010-V ' ,  ' C-010-T ' ,  ' C-100-V ' ,  ' C-100-T ' ,  ' I16-V ' ,  ' I16-T ' ] 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  correlations  =  np . array ( corrs ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' ------>>>>>>>>  {:03d} / {:}  >>>>>>>> ------ ' . format ( test_epoch ,  ' 012 '  if  use_less_or_not  else  ' 200 ' ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  idx ,  xstr  in  enumerate ( xstrs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print  ( ' {:8s}  ::: mean= {:.4f} , std= {:.4f}  ::  {:.4f} \\ pm {:.4f} ' . format ( xstr ,  correlations [ : , idx ] . mean ( ) ,  correlations [ : , idx ] . std ( ) ,  correlations [ : , idx ] . mean ( ) ,  correlations [ : , idx ] . std ( ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								if  __name__  ==  ' __main__ ' :  
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  parser  =  argparse . ArgumentParser ( " Analysis of NAS-Bench-201 " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --save_dir ' ,   type = str ,  default = ' ./output/search-cell-nas-bench-201/visuals ' ,  help = ' The base-name of folder to save checkpoints and log. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --api_path ' ,   type = str ,  default = None ,                                          help = ' The path to the NAS-Bench-201 benchmark file. ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-31 22:02:11 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  args  =  parser . parse_args ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  vis_save_dir  =  Path ( args . save_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  vis_save_dir . mkdir ( parents = True ,  exist_ok = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_file  =  Path ( args . api_path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_file . exists ( ) ,  ' invalid path for api :  {:} ' . format ( meta_file ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  #check_unique_arch(meta_file) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  api  =  API ( str ( meta_file ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  #for iepoch in [11, 25, 50, 100, 150, 175, 200]: 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  #  check_cor_for_bandit(api,  6, iepoch) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  #  check_cor_for_bandit(api, 12, iepoch) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,    6 ,   True ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,   12 ,   True ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,   12 ,  False ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,   24 ,  False ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,  100 ,  False ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,  150 ,  False ,  True ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-02 16:49:16 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,  175 ,  False ,  True ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-01 22:18:42 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  check_cor_for_bandit_v2 ( api ,  200 ,  False ,  True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' ---- ' )