2019-11-15 17:15:07 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								################################################## 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								################################################## 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-19 11:58:04 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								import  os ,  sys ,  copy ,  random ,  torch ,  numpy  as  np 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								from  collections  import  OrderedDict ,  defaultdict 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								def  print_information ( information ,  extra_info = None ,  show = False ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  dataset_names  =  information . get_dataset_names ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  strings  =  [ information . arch_str ,  ' datasets :  {:} , extra-info :  {:} ' . format ( dataset_names ,  extra_info ) ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  metric2str ( loss ,  acc ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  ' loss =  {:.3f} , top1 =  {:.2f} % ' . format ( loss ,  acc ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  for  ida ,  dataset  in  enumerate ( dataset_names ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    flop ,  param ,  latency  =  information . get_comput_costs ( dataset ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    str1  =  ' {:14s}  FLOP= {:6.2f}  M, Params= {:.3f}  MB, latency= {:}  ms. ' . format ( dataset ,  flop ,  param ,  ' {:.2f} ' . format ( latency * 1000 )  if  latency  >  0  else  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    train_loss ,  train_acc  =  information . get_metrics ( dataset ,  ' train ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  dataset  ==  ' cifar10-valid ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      valid_loss ,  valid_acc  =  information . get_metrics ( dataset ,  ' x-valid ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      str2  =  ' {:14s}  train : [ {:} ], valid : [ {:} ] ' . format ( dataset ,  metric2str ( train_loss ,  train_acc ) ,  metric2str ( valid_loss ,  valid_acc ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    elif  dataset  ==  ' cifar10 ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      test__loss ,  test__acc  =  information . get_metrics ( dataset ,  ' ori-test ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      str2  =  ' {:14s}  train : [ {:} ], test  : [ {:} ] ' . format ( dataset ,  metric2str ( train_loss ,  train_acc ) ,  metric2str ( test__loss ,  test__acc ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      valid_loss ,  valid_acc  =  information . get_metrics ( dataset ,  ' x-valid ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      test__loss ,  test__acc  =  information . get_metrics ( dataset ,  ' x-test ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      str2  =  ' {:14s}  train : [ {:} ], valid : [ {:} ], test : [ {:} ] ' . format ( dataset ,  metric2str ( train_loss ,  train_acc ) ,  metric2str ( valid_loss ,  valid_acc ) ,  metric2str ( test__loss ,  test__acc ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    strings  + =  [ str1 ,  str2 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  if  show :  print ( ' \n ' . join ( strings ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  return  strings 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								class  NASBench102API ( object ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  __init__ ( self ,  file_path_or_dict ,  verbose = True ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    if  isinstance ( file_path_or_dict ,  str ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								      if  verbose :  print ( ' try to create NAS-Bench-102 api from  {:} ' . format ( file_path_or_dict ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								      assert  os . path . isfile ( file_path_or_dict ) ,  ' invalid path :  {:} ' . format ( file_path_or_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      file_path_or_dict  =  torch . load ( file_path_or_dict ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      file_path_or_dict  =  copy . deepcopy (  file_path_or_dict  ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    assert  isinstance ( file_path_or_dict ,  dict ) ,  ' It should be a dict instead of  {:} ' . format ( type ( file_path_or_dict ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    import  pdb ;  pdb . set_trace ( )  # we will update this api soon 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    keys  =  ( ' meta_archs ' ,  ' arch2infos ' ,  ' evaluated_indexes ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  key  in  keys :  assert  key  in  file_path_or_dict ,  ' Can not find key[ {:} ] in the dict ' . format ( key ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . meta_archs  =  copy . deepcopy (  file_path_or_dict [ ' meta_archs ' ]  ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    self . arch2infos  =  OrderedDict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  xkey  in  sorted ( list ( file_path_or_dict [ ' arch2infos ' ] . keys ( ) ) ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      self . arch2infos [ xkey ]  =  ArchResults . create_from_state_dict (  file_path_or_dict [ ' arch2infos ' ] [ xkey ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . evaluated_indexes  =  sorted ( list ( file_path_or_dict [ ' evaluated_indexes ' ] ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    self . archstr2index  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  idx ,  arch  in  enumerate ( self . meta_archs ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								      #assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()]) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      assert  arch  not  in  self . archstr2index ,  ' This [ {:} ]-th arch  {:}  already in the dict ( {:} ). ' . format ( idx ,  arch ,  self . archstr2index [ arch ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      self . archstr2index [  arch  ]  =  idx 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __getitem__ ( self ,  index ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  copy . deepcopy (  self . meta_archs [ index ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __len__ ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  len ( self . meta_archs ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __repr__ ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  ( ' {name} ( {num} / {total}  architectures) ' . format ( name = self . __class__ . __name__ ,  num = len ( self . evaluated_indexes ) ,  total = len ( self . meta_archs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  query_index_by_arch ( self ,  arch ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    if  isinstance ( arch ,  str ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  arch  in  self . archstr2index :  arch_index  =  self . archstr2index [  arch  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else                          :  arch_index  =  - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    elif  hasattr ( arch ,  ' tostr ' ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  arch . tostr ( )  in  self . archstr2index :  arch_index  =  self . archstr2index [  arch . tostr ( )  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else                                  :  arch_index  =  - 1 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    else :  arch_index  =  - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  arch_index 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  query_by_arch ( self ,  arch ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    arch_index  =  self . query_index_by_arch ( arch ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  arch_index  ==  - 1 :  return  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  arch_index  in  self . arch2infos : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      strings  =  print_information ( self . arch2infos [  arch_index  ] ,  ' arch-index= {:} ' . format ( arch_index ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      return  ' \n ' . join ( strings ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      print  ( ' Find this arch-index :  {:} , but this arch is not evaluated. ' . format ( arch_index ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      return  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  query_by_index ( self ,  arch_index ,  dataname ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  arch_index  in  self . arch2infos ,  ' arch_index [ {:} ] does not in arch2info ' . format ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    archInfo  =  copy . deepcopy (  self . arch2infos [  arch_index  ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  dataname  in  archInfo . get_dataset_names ( ) ,  ' invalid dataset-name :  {:} ' . format ( dataname ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    info  =  archInfo . query ( dataname ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-14 13:55:42 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  query_meta_info_by_index ( self ,  arch_index ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  arch_index  in  self . arch2infos ,  ' arch_index [ {:} ] does not in arch2info ' . format ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    archInfo  =  copy . deepcopy (  self . arch2infos [  arch_index  ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  archInfo 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								  def  find_best ( self ,  dataset ,  metric_on_set ,  FLOP_max = None ,  Param_max = None ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    best_index ,  highest_accuracy  =  - 1 ,  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  i ,  idx  in  enumerate ( self . evaluated_indexes ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      flop ,  param ,  latency  =  self . arch2infos [ idx ] . get_comput_costs ( dataset ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  FLOP_max   is  not  None  and  flop   >  FLOP_max  :  continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  Param_max  is  not  None  and  param  >  Param_max :  continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      loss ,  accuracy  =  self . arch2infos [ idx ] . get_metrics ( dataset ,  metric_on_set ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  best_index  ==  - 1 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        best_index ,  highest_accuracy  =  idx ,  accuracy 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      elif  highest_accuracy  <  accuracy : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        best_index ,  highest_accuracy  =  idx ,  accuracy 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  best_index 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  arch ( self ,  index ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  0  < =  index  <  len ( self . meta_archs ) ,  ' invalid index :  {:}  vs.  {:} . ' . format ( index ,  len ( self . meta_archs ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  copy . deepcopy ( self . meta_archs [ index ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  show ( self ,  index = - 1 ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  index  ==  - 1 :  # show all architectures 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      print ( self ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  i ,  idx  in  enumerate ( self . evaluated_indexes ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' \n '  +  ' - '  *  10  +  '  The ( {:5d} / {:5d} )  {:06d} -th architecture!  ' . format ( i ,  len ( self . evaluated_indexes ) ,  idx )  +  ' - ' * 10 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' arch :  {:} ' . format ( self . meta_archs [ idx ] ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        strings  =  print_information ( self . arch2infos [ idx ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' > '  *  20 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' \n ' . join ( strings ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' < '  *  20 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  0  < =  index  <  len ( self . meta_archs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        if  index  not  in  self . evaluated_indexes :  print ( ' The  {:} -th architecture has not been evaluated or not saved. ' . format ( index ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          strings  =  print_information ( self . arch2infos [ index ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          print ( ' \n ' . join ( strings ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        print ( ' This index ( {:} ) is out of range (0~ {:} ). ' . format ( index ,  len ( self . meta_archs ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								class  ArchResults ( object ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __init__ ( self ,  arch_index ,  arch_str ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . arch_index    =  int ( arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . arch_str      =  copy . deepcopy ( arch_str ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . all_results   =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . dataset_seed  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . clear_net_done  =  False 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_comput_costs ( self ,  dataset ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    x_seeds  =  self . dataset_seed [ dataset ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    results  =  [ self . all_results [  ( dataset ,  seed )  ]  for  seed  in  x_seeds ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    flops       =  [ result . flop  for  result  in  results ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    params      =  [ result . params  for  result  in  results ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    lantencies  =  [ result . get_latency ( )  for  result  in  results ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    lantencies  =  [ x  for  x  in  lantencies  if  x  >  0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    mean_latency  =  np . mean ( lantencies )  if  len ( lantencies )  >  0  else  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    time_infos  =  defaultdict ( list ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  result  in  results : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      time_info  =  result . get_times ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  key ,  value  in  time_info . items ( ) :  time_infos [ key ] . append (  value  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								     
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    info  =  { ' flops '   :  np . mean ( flops ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' params '  :  np . mean ( params ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' latency ' :  mean_latency } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  key ,  value  in  time_infos . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  len ( value )  >  0  and  value [ 0 ]  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        info [ key ]  =  np . mean ( value ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else :  info [ key ]  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  info 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-19 11:58:04 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  get_metrics ( self ,  dataset ,  setname ,  iepoch = None ,  is_random = False ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    x_seeds  =  self . dataset_seed [ dataset ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    results  =  [ self . all_results [  ( dataset ,  seed )  ]  for  seed  in  x_seeds ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    infos    =  defaultdict ( list ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    for  result  in  results : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  setname  ==  ' train ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        info  =  result . get_train ( iepoch ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        info  =  result . get_eval ( setname ,  iepoch ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								      for  key ,  value  in  info . items ( ) :  infos [ key ] . append (  value  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return_info  =  dict ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-19 11:58:04 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    if  is_random : 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								      index  =  random . randint ( 0 ,  len ( results ) - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  key ,  value  in  infos . items ( ) :  return_info [ key ]  =  value [ index ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-19 11:58:04 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								      for  key ,  value  in  infos . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        if  len ( value )  >  0  and  value [ 0 ]  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          return_info [ key ]  =  np . mean ( value ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        else :  return_info [ key ]  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  return_info 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  show ( self ,  is_print = False ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  print_information ( self ,  None ,  is_print ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_dataset_names ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  list ( self . dataset_seed . keys ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  query ( self ,  dataset ,  seed = None ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  seed  is  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      x_seeds  =  self . dataset_seed [ dataset ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      return  [ self . all_results [  ( dataset ,  seed )  ]  for  seed  in  x_seeds ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      return  self . all_results [  ( dataset ,  seed )  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  arch_idx_str ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  ' {:06d} ' . format ( self . arch_index ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  update ( self ,  dataset_name ,  seed ,  result ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  dataset_name  not  in  self . dataset_seed : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      self . dataset_seed [ dataset_name ]  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  seed  not  in  self . dataset_seed [ dataset_name ] ,  ' {:} -th arch alreadly has this seed ( {:} ) on  {:} ' . format ( self . arch_index ,  seed ,  dataset_name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . dataset_seed [  dataset_name  ] . append (  seed  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . dataset_seed [  dataset_name  ]  =  sorted (  self . dataset_seed [  dataset_name  ]  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  ( dataset_name ,  seed )  not  in  self . all_results 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . all_results [  ( dataset_name ,  seed )  ]  =  result 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . clear_net_done  =  False 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  state_dict ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    state_dict  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  key ,  value  in  self . __dict__ . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  key  ==  ' all_results ' :  # contain the class of ResultsCount 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        xvalue  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        assert  isinstance ( value ,  dict ) ,  ' invalid type of value for  {:}  :  {:} ' . format ( key ,  type ( value ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        for  _k ,  _v  in  value . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          assert  isinstance ( _v ,  ResultsCount ) ,  ' invalid type of value for  {:} / {:}  :  {:} ' . format ( key ,  _k ,  type ( _v ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          xvalue [ _k ]  =  _v . state_dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        xvalue  =  value 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      state_dict [ key ]  =  xvalue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  state_dict 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  load_state_dict ( self ,  state_dict ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    new_state_dict  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  key ,  value  in  state_dict . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  key  ==  ' all_results ' :  # to convert to the class of ResultsCount 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        xvalue  =  dict ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        assert  isinstance ( value ,  dict ) ,  ' invalid type of value for  {:}  :  {:} ' . format ( key ,  type ( value ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        for  _k ,  _v  in  value . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          xvalue [ _k ]  =  ResultsCount . create_from_state_dict ( _v ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      else :  xvalue  =  value 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      new_state_dict [ key ]  =  xvalue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . __dict__ . update ( new_state_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  @staticmethod 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  create_from_state_dict ( state_dict_or_file ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    x  =  ArchResults ( - 1 ,  - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  isinstance ( state_dict_or_file ,  str ) :  # a file path 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      state_dict  =  torch . load ( state_dict_or_file ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    elif  isinstance ( state_dict_or_file ,  dict ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      state_dict  =  state_dict_or_file 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      raise  ValueError ( ' invalid type of state_dict_or_file :  {:} ' . format ( type ( state_dict_or_file ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    x . load_state_dict ( state_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  x 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  clear_params ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  key ,  result  in  self . all_results . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      result . net_state_dict  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . clear_net_done  =  True  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __repr__ ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  ( ' {name} (arch-index= {index} , arch= {arch} ,  {num}  runs, clear= {clear} ) ' . format ( name = self . __class__ . __name__ ,  index = self . arch_index ,  arch = self . arch_str ,  num = len ( self . all_results ) ,  clear = self . clear_net_done ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								class  ResultsCount ( object ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __init__ ( self ,  name ,  state_dict ,  train_accs ,  train_losses ,  params ,  flop ,  arch_config ,  seed ,  epochs ,  latency ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . name            =  name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . net_state_dict  =  state_dict 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    self . train_acc1es  =  copy . deepcopy ( train_accs ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . train_acc5es  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    self . train_losses  =  copy . deepcopy ( train_losses ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    self . train_times   =  None 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    self . arch_config   =  copy . deepcopy ( arch_config ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . params      =  params 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . flop        =  flop 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . seed        =  seed 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . epochs      =  epochs 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . latency     =  latency 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    # evaluation results 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . reset_eval ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  update_train_info ( self ,  train_acc1es ,  train_acc5es ,  train_losses ,  train_times ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . train_acc1es  =  train_acc1es 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . train_acc5es  =  train_acc5es 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . train_losses  =  train_losses 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . train_times   =  train_times 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								  def  reset_eval ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . eval_names   =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    self . eval_acc1es  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . eval_times   =  { } 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    self . eval_losses  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  update_latency ( self ,  latency ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . latency  =  copy . deepcopy (  latency  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  update_eval ( self ,  accs ,  losses ,  times ) :  # old version 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    data_names  =  set ( [ x . split ( ' @ ' ) [ 0 ]  for  x  in  accs . keys ( ) ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    for  data_name  in  data_names : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      assert  data_name  not  in  self . eval_names ,  ' {:}  has already been added into eval-names ' . format ( data_name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      self . eval_names . append (  data_name  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  iepoch  in  range ( self . epochs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        xkey  =  ' {:} @ {:} ' . format ( data_name ,  iepoch ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        self . eval_acc1es [  xkey  ]  =  accs [  xkey  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        self . eval_losses [  xkey  ]  =  losses [  xkey  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        self . eval_times  [  xkey  ]  =  times [  xkey  ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  update_OLD_eval ( self ,  name ,  accs ,  losses ) :  # old version 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    assert  name  not  in  self . eval_names ,  ' {:}  has already added ' . format ( name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . eval_names . append (  name  ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    for  iepoch  in  range ( self . epochs ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      if  iepoch  in  accs : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        self . eval_acc1es [ ' {:} @ {:} ' . format ( name , iepoch ) ]  =  accs [ iepoch ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        self . eval_losses [ ' {:} @ {:} ' . format ( name , iepoch ) ]  =  losses [ iepoch ] 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  __repr__ ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    num_eval  =  len ( self . eval_names ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    set_name  =  ' [ '  +  ' ,  ' . join ( self . eval_names )  +  ' ] ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  ( ' {name} ( {xname} , arch= {arch} , FLOP= {flop:.2f} M, Param= {param:.3f} MB, seed= {seed} ,  {num_eval}  eval-sets:  {set_name} ) ' . format ( name = self . __class__ . __name__ ,  xname = self . name ,  arch = self . arch_config [ ' arch_str ' ] ,  flop = self . flop ,  param = self . params ,  seed = self . seed ,  num_eval = num_eval ,  set_name = set_name ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  get_latency ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  self . latency  is  None :  return  - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else :  return  sum ( self . latency )  /  len ( self . latency ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_times ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  self . train_times  is  not  None  and  isinstance ( self . train_times ,  dict ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      train_times  =  list (  self . train_times . values ( )  ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      time_info  =  { ' T-train@epoch ' :  np . mean ( train_times ) ,  ' T-train@total ' :  np . sum ( train_times ) } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  name  in  self . eval_names : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        xtimes  =  [ self . eval_times [ ' {:} @ {:} ' . format ( name , i ) ]  for  i  in  range ( self . epochs ) ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        time_info [ ' T- {:} @epoch ' . format ( name ) ]  =  np . mean ( xtimes ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        time_info [ ' T- {:} @total ' . format ( name ) ]  =  np . sum ( xtimes ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      time_info  =  { ' T-train@epoch ' :                  None ,  ' T-train@total ' :                None  } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      for  name  in  self . eval_names : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        time_info [ ' T- {:} @epoch ' . format ( name ) ]  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        time_info [ ' T- {:} @total ' . format ( name ) ]  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  time_info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_eval_set ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    return  self . eval_names 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_train ( self ,  iepoch = None ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  iepoch  is  None :  iepoch  =  self . epochs - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  0  < =  iepoch  <  self . epochs ,  ' invalid iepoch= {:}  <  {:} ' . format ( iepoch ,  self . epochs ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    if  self . train_times  is  not  None :  xtime  =  self . train_times [ iepoch ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else                            :  xtime  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  { ' iepoch '   :  iepoch , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' loss '     :  self . train_losses [ iepoch ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' accuracy ' :  self . train_acc1es [ iepoch ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' time '     :  xtime } 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_eval ( self ,  name ,  iepoch = None ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if  iepoch  is  None :  iepoch  =  self . epochs - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    assert  0  < =  iepoch  <  self . epochs ,  ' invalid iepoch= {:}  <  {:} ' . format ( iepoch ,  self . epochs ) 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    if  isinstance ( self . eval_times , dict )  and  len ( self . eval_times )  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      xtime  =  self . eval_times [ ' {:} @ {:} ' . format ( name , iepoch ) ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    else :  xtime  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  { ' iepoch '   :  iepoch , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' loss '     :  self . eval_losses [ ' {:} @ {:} ' . format ( name , iepoch ) ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' accuracy ' :  self . eval_acc1es [ ' {:} @ {:} ' . format ( name , iepoch ) ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' time '     :  xtime } 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  get_net_param ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  self . net_state_dict 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 20:41:49 +11:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								  def  get_config ( self ,  str2structure ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    #return copy.deepcopy(self.arch_config) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  { ' name ' :  ' infer.tiny ' ,  ' C ' :  self . arch_config [ ' channel ' ] ,  \
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' N '    :  self . arch_config [ ' num_cells ' ] ,  \
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            ' genotype ' :  str2structure ( self . arch_config [ ' arch_str ' ] ) ,  ' num_classes ' :  self . arch_config [ ' class_num ' ] } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-11-11 00:46:02 +11:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								  def  state_dict ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    _state_dict  =  { key :  value  for  key ,  value  in  self . __dict__ . items ( ) } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  _state_dict 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  load_state_dict ( self ,  state_dict ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    self . __dict__ . update ( state_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  @staticmethod 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  def  create_from_state_dict ( state_dict ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    x  =  ResultsCount ( None ,  None ,  None ,  None ,  None ,  None ,  None ,  None ,  None ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    x . load_state_dict ( state_dict ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    return  x