2020-02-23 10:30:37 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS  
						 
					
						
							
								
									
										
										
										
											2020-02-23 10:30:37 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#####################################################  
						 
					
						
							
								
									
										
										
										
											2020-03-09 19:38:00 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  sys ,  time ,  argparse ,  collections  
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  pathlib  import  Path  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  collections  import  defaultdict  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								lib_dir  =  ( Path ( __file__ ) . parent  /  ' .. '  /  ' .. '  /  ' lib ' ) . resolve ( )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  str ( lib_dir )  not  in  sys . path :  sys . path . insert ( 0 ,  str ( lib_dir ) )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  log_utils     import  AverageMeter ,  time_string ,  convert_secs2time  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  check_files ( save_dir ,  meta_file ,  basestr ) :  
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  meta_infos  =  torch . load ( meta_file ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_archs  =  meta_infos [ ' archs ' ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  meta_num_archs  =  meta_infos [ ' total ' ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_num_archs  ==  len ( meta_archs ) ,  ' invalid number of archs :  {:}  vs  {:} ' . format ( meta_num_archs ,  len ( meta_archs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  sub_model_dirs  =  sorted ( list ( save_dir . glob ( ' *-*- {:} ' . format ( basestr ) ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print  ( ' {:}  find  {:}  directories used to save checkpoints ' . format ( time_string ( ) ,  len ( sub_model_dirs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  subdir2archs ,  num_evaluated_arch  =  collections . OrderedDict ( ) ,  0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  num_seeds  =  defaultdict ( lambda :  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  index ,  sub_dir  in  enumerate ( sub_model_dirs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    xcheckpoints  =  list ( sub_dir . glob ( ' arch-*-seed-*.pth ' ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    arch_indexes  =  set ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  checkpoint  in  xcheckpoints : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      temp_names  =  checkpoint . name . split ( ' - ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  len ( temp_names )  ==  4  and  temp_names [ 0 ]  ==  ' arch '  and  temp_names [ 2 ]  ==  ' seed ' ,  ' invalid checkpoint name :  {:} ' . format ( checkpoint . name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_indexes . add (  temp_names [ 1 ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    subdir2archs [ sub_dir ]  =  sorted ( list ( arch_indexes ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    num_evaluated_arch    + =  len ( arch_indexes ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # count number of seeds for each architecture 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  arch_index  in  arch_indexes : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      num_seeds [  len ( list ( sub_dir . glob ( ' arch- {:} -seed-*.pth ' . format ( arch_index ) ) ) )  ]  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  print ( ' There are  {:5d}  architectures that have been evaluated ( {:}  in total,  {:}  ckps in total). ' . format ( num_evaluated_arch ,  meta_num_archs ,  sum ( k * v  for  k ,  v  in  num_seeds . items ( ) ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  key  in  sorted (  list (  num_seeds . keys ( )  )  ) :  print  ( ' There are  {:5d}  architectures that are evaluated  {:}  times. ' . format ( num_seeds [ key ] ,  key ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  dir2ckps ,  dir2ckp_exists  =  dict ( ) ,  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  start_time ,  epoch_time  =  time . time ( ) ,  AverageMeter ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  IDX ,  ( sub_dir ,  arch_indexes )  in  enumerate ( subdir2archs . items ( ) ) : 
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  basestr  ==  ' C16-N5 ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      seeds  =  [ 777 ,  888 ,  999 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    elif  basestr  ==  ' C16-N5-LESS ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      seeds  =  [ 111 ,  777 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      raise  ValueError ( ' Invalid base str :  {:} ' . format ( basestr ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    numrs  =  defaultdict ( lambda :  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    all_checkpoints ,  all_ckp_exists  =  [ ] ,  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  arch_index  in  arch_indexes : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      checkpoints  =  [ ' arch- {:} -seed- {:04d} .pth ' . format ( arch_index ,  seed )  for  seed  in  seeds ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ckp_exists   =  [ ( sub_dir / x ) . exists ( )  for  x  in  checkpoints ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      arch_index   =  int ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert  0  < =  arch_index  <  len ( meta_archs ) ,  ' invalid arch-index  {:}  (not found in meta_archs) ' . format ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      all_checkpoints  + =  checkpoints 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      all_ckp_exists   + =  ckp_exists 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      numrs [ sum ( ckp_exists ) ]  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dir2ckps [  str ( sub_dir )  ]        =  all_checkpoints 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dir2ckp_exists [  str ( sub_dir )  ]  =  all_ckp_exists 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # measure time 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    epoch_time . update ( time . time ( )  -  start_time ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    start_time  =  time . time ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    numrstr  =  ' ,  ' . join (  [ ' {:} :  {:03d} ' . format ( x ,  numrs [ x ] )  for  x  in  sorted ( numrs . keys ( ) ) ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( ' {:}  load [ {:2d} / {:2d} ] [ {:03d}  archs] [ {:04d} -> {:04d}  ckps]  {:}  done, need  {:} .  {:} ' . format ( time_string ( ) ,  IDX + 1 ,  len ( subdir2archs ) ,  len ( arch_indexes ) ,  len ( all_checkpoints ) ,  sum ( all_ckp_exists ) ,  sub_dir ,  convert_secs2time ( epoch_time . avg  *  ( len ( subdir2archs ) - IDX - 1 ) ,  True ) ,  numrstr ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  __name__  ==  ' __main__ ' :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  parser  =  argparse . ArgumentParser ( description = ' NAS Benchmark 201 ' ,  formatter_class = argparse . ArgumentDefaultsHelpFormatter ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  parser . add_argument ( ' --base_save_dir ' ,  type = str ,  default = ' ./output/NAS-BENCH-201-4 ' ,  help = ' The base-name of folder to save checkpoints and log. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --meta_path ' ,      type = str ,  default = ' ./output/NAS-BENCH-201-4/meta-node-4.pth ' ,  help = ' The meta file path. ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  parser . add_argument ( ' --base_str ' ,       type = str ,  default = ' C16-N5 ' ,                    help = ' The basic string. ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  args  =  parser . parse_args ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  save_dir  =  Path ( args . base_save_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  meta_path  =  Path ( args . meta_path ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  assert  save_dir . exists ( ) ,   ' invalid save dir path :  {:} ' . format ( save_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  assert  meta_path . exists ( ) ,  ' invalid saved meta path :  {:} ' . format ( meta_path ) 
							 
						 
					
						
							
								
									
										
										
										
											2020-01-15 00:52:06 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  print  ( ' check NAS-Bench-201 in  {:} ' . format ( save_dir ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-26 23:29:36 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-10 19:08:56 +11:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  check_files ( save_dir ,  meta_path ,  args . base_str )