##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
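"""A query API for the NAS-Bench-102 benchmark.

Minimal usage sketch (the benchmark file name below is a placeholder -- point it
at the actual .pth archive of the benchmark):

    api = NASBench102API('NAS-Bench-102.pth', verbose=False)
    print(api)                             # e.g. NASBench102API(#evaluated/#total architectures)
    print(api.query_by_arch(api.arch(0)))  # human-readable summary of the 0-th architecture
    api.show(0)                            # the same information, printed directly
"""
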
import os, sys, copy, random, torch, numpy as np
from collections import OrderedDict, defaultdict

def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    # get_comput_costs / get_metrics return dicts (see ArchResults below)
    costs = information.get_comput_costs(dataset)
    flop, param, latency = costs['flops'], costs['params'], costs['latency']
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency * 1000) if latency is not None and latency > 0 else None)
    train_info = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_info = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
    elif dataset == 'cifar10':
      test_info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test_info['loss'], test_info['accuracy']))
    else:
      valid_info = information.get_metrics(dataset, 'x-valid')
      test_info = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test_info['loss'], test_info['accuracy']))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings
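
# Illustration only: each dataset contributes two lines shaped like the following
# (the numbers are made up, not benchmark results):
#   cifar10        FLOP= 15.65 M, Params=0.129 MB, latency=12.34 ms.
#   cifar10        train : [loss = 0.423, top1 = 85.32%], test : [loss = 0.512, top1 = 84.10%]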

class NASBench102API(object):

  def __init__(self, file_path_or_dict, verbose=True):
    if isinstance(file_path_or_dict, str):
      if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      file_path_or_dict = torch.load(file_path_or_dict)
    else:
      file_path_or_dict = copy.deepcopy(file_path_or_dict)
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
    self.arch2infos = OrderedDict()
    for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
      self.arch2infos[xkey] = ArchResults.create_from_state_dict(file_path_or_dict['arch2infos'][xkey])
    self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      #assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} is already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      self.archstr2index[arch] = idx

  def __getitem__(self, index):
    return copy.deepcopy(self.meta_archs[index])

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))

  def query_index_by_arch(self, arch):
    if isinstance(arch, str):
      if arch in self.archstr2index: arch_index = self.archstr2index[arch]
      else: arch_index = -1
    elif hasattr(arch, 'tostr'):
      if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[arch.tostr()]
      else: arch_index = -1
    else: arch_index = -1
    return arch_index

  def query_by_arch(self, arch):
    arch_index = self.query_index_by_arch(arch)
    if arch_index == -1: return None
    if arch_index in self.arch2infos:
      strings = print_information(self.arch2infos[arch_index], 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      print('Found arch-index {:}, but this architecture is not evaluated.'.format(arch_index))
      return None
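
  # Minimal usage sketch (index 12 is an arbitrary example):
  #   arch = api.arch(12)                 # the 12-th candidate architecture
  #   print(api.query_by_arch(arch))      # None when this architecture has no evaluation info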

  def query_by_index(self, arch_index, dataname):
    assert arch_index in self.arch2infos, 'arch_index [{:}] is not in arch2infos'.format(arch_index)
    archInfo = copy.deepcopy(self.arch2infos[arch_index])
    assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
    info = archInfo.query(dataname)
    return info

  def query_meta_info_by_index(self, arch_index):
    assert arch_index in self.arch2infos, 'arch_index [{:}] is not in arch2infos'.format(arch_index)
    archInfo = copy.deepcopy(self.arch2infos[arch_index])
    return archInfo

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
    best_index, highest_accuracy = -1, None
    for i, idx in enumerate(self.evaluated_indexes):
      costs = self.arch2infos[idx].get_comput_costs(dataset)
      flop, param = costs['flops'], costs['params']
      if FLOP_max is not None and flop > FLOP_max: continue
      if Param_max is not None and param > Param_max: continue
      xinfo = self.arch2infos[idx].get_metrics(dataset, metric_on_set)
      loss, accuracy = xinfo['loss'], xinfo['accuracy']
      if best_index == -1:
        best_index, highest_accuracy = idx, accuracy
      elif highest_accuracy < accuracy:
        best_index, highest_accuracy = idx, accuracy
    return best_index
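
  # Sketch: the best evaluated architecture by 'ori-test' accuracy on cifar10,
  # restricted to models under 50 MFLOPs (the threshold is an arbitrary example):
  #   best_index = api.find_best('cifar10', 'ori-test', FLOP_max=50)
  #   api.show(best_index)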

  def arch(self, index):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def show(self, index=-1):
    if index == -1:  # show all architectures
      print(self)
      for i, idx in enumerate(self.evaluated_indexes):
        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-' * 10)
        print('arch : {:}'.format(self.meta_archs[idx]))
        strings = print_information(self.arch2infos[idx])
        print('>' * 20)
        print('\n'.join(strings))
        print('<' * 20)
    else:
      if 0 <= index < len(self.meta_archs):
        if index not in self.evaluated_indexes:
          print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
        else:
          strings = print_information(self.arch2infos[index])
          print('\n'.join(strings))
      else:
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))


class ArchResults(object):

  def __init__(self, arch_index, arch_str):
    self.arch_index   = int(arch_index)
    self.arch_str     = copy.deepcopy(arch_str)
    self.all_results  = dict()
    self.dataset_seed = dict()
    self.clear_net_done = False

  def get_comput_costs(self, dataset):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[(dataset, seed)] for seed in x_seeds]
    flops     = [result.flop for result in results]
    params    = [result.params for result in results]
    latencies = [result.get_latency() for result in results]
    latencies = [x for x in latencies if x > 0]
    mean_latency = np.mean(latencies) if len(latencies) > 0 else None
    time_infos = defaultdict(list)
    for result in results:
      time_info = result.get_times()
      for key, value in time_info.items(): time_infos[key].append(value)

    info = {'flops'  : np.mean(flops),
            'params' : np.mean(params),
            'latency': mean_latency}
    for key, value in time_infos.items():
      if len(value) > 0 and value[0] is not None:
        info[key] = np.mean(value)
      else:
        info[key] = None
    return info
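
  # Illustrative shape of the returned dict (all numbers are made up; the
  # 'T-...' keys mirror whatever ResultsCount.get_times() reports):
  #   {'flops': 15.65, 'params': 0.129, 'latency': 0.0123,
  #    'T-train@epoch': 11.2, 'T-train@total': 2240.5, ...}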

  def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[(dataset, seed)] for seed in x_seeds]
    infos   = defaultdict(list)
    for result in results:
      if setname == 'train':
        info = result.get_train(iepoch)
      else:
        info = result.get_eval(setname, iepoch)
      for key, value in info.items(): infos[key].append(value)
    return_info = dict()
    if is_random:
      # report the metrics of one randomly chosen run (seed)
      index = random.randint(0, len(results) - 1)
      for key, value in infos.items(): return_info[key] = value[index]
    else:
      # average the metrics over all runs (seeds)
      for key, value in infos.items():
        if len(value) > 0 and value[0] is not None:
          return_info[key] = np.mean(value)
        else:
          return_info[key] = None
    return return_info
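
  # Example shape of the returned dict (values are illustrative only):
  #   {'iepoch': 199, 'loss': 0.42, 'accuracy': 85.3, 'time': 11.2}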

  def show(self, is_print=False):
    return print_information(self, None, is_print)

  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def query(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return [self.all_results[(dataset, xseed)] for xseed in x_seeds]
    else:
      return self.all_results[(dataset, seed)]

  def arch_idx_str(self):
    return '{:06d}'.format(self.arch_index)

  def update(self, dataset_name, seed, result):
    if dataset_name not in self.dataset_seed:
      self.dataset_seed[dataset_name] = []
    assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
    self.dataset_seed[dataset_name].append(seed)
    self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name])
    assert (dataset_name, seed) not in self.all_results
    self.all_results[(dataset_name, seed)] = result
    self.clear_net_done = False

  def state_dict(self):
    state_dict = dict()
    for key, value in self.__dict__.items():
      if key == 'all_results':  # contains instances of ResultsCount
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
          xvalue[_k] = _v.state_dict()
      else:
        xvalue = value
      state_dict[key] = xvalue
    return state_dict

  def load_state_dict(self, state_dict):
    new_state_dict = dict()
    for key, value in state_dict.items():
      if key == 'all_results':  # convert back into ResultsCount instances
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          xvalue[_k] = ResultsCount.create_from_state_dict(_v)
      else:
        xvalue = value
      new_state_dict[key] = xvalue
    self.__dict__.update(new_state_dict)

  @staticmethod
  def create_from_state_dict(state_dict_or_file):
    x = ArchResults(-1, -1)
    if isinstance(state_dict_or_file, str):  # a file path
      state_dict = torch.load(state_dict_or_file)
    elif isinstance(state_dict_or_file, dict):
      state_dict = state_dict_or_file
    else:
      raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
    x.load_state_dict(state_dict)
    return x

  def clear_params(self):
    for key, result in self.all_results.items():
      result.net_state_dict = None
    self.clear_net_done = True

  def __repr__(self):
    return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))


class ResultsCount(object):

  def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
    self.name           = name
    self.net_state_dict = state_dict
    self.train_acc1es   = copy.deepcopy(train_accs)
    self.train_acc5es   = None
    self.train_losses   = copy.deepcopy(train_losses)
    self.train_times    = None
    self.arch_config    = copy.deepcopy(arch_config)
    self.params         = params
    self.flop           = flop
    self.seed           = seed
    self.epochs         = epochs
    self.latency        = latency
    # evaluation results
    self.reset_eval()

  def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times):
    self.train_acc1es = train_acc1es
    self.train_acc5es = train_acc5es
    self.train_losses = train_losses
    self.train_times  = train_times

  def reset_eval(self):
    self.eval_names  = []
    self.eval_acc1es = {}
    self.eval_times  = {}
    self.eval_losses = {}

  def update_latency(self, latency):
    self.latency = copy.deepcopy(latency)

  def update_eval(self, accs, losses, times):  # per-epoch keys are formatted as '{set-name}@{epoch}'
    data_names = set([x.split('@')[0] for x in accs.keys()])
    for data_name in data_names:
      assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
      self.eval_names.append(data_name)
      for iepoch in range(self.epochs):
        xkey = '{:}@{:}'.format(data_name, iepoch)
        self.eval_acc1es[xkey] = accs[xkey]
        self.eval_losses[xkey] = losses[xkey]
        self.eval_times[xkey] = times[xkey]

  def update_OLD_eval(self, name, accs, losses):  # old version
    assert name not in self.eval_names, '{:} has already been added'.format(name)
    self.eval_names.append(name)
    for iepoch in range(self.epochs):
      if iepoch in accs:
        self.eval_acc1es['{:}@{:}'.format(name, iepoch)] = accs[iepoch]
        self.eval_losses['{:}@{:}'.format(name, iepoch)] = losses[iepoch]
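
  # Hypothetical lookup: the top-1 accuracy of the 'ori-test' set after epoch 199
  # would live under self.eval_acc1es['ori-test@199'] (key format '{set-name}@{epoch}').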

  def __repr__(self):
    num_eval = len(self.eval_names)
    set_name = '[' + ', '.join(self.eval_names) + ']'
    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f} M, Param={param:.3f} MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))

  def get_latency(self):
    if self.latency is None: return -1
    else: return sum(self.latency) / len(self.latency)

  def get_times(self):
    if self.train_times is not None and isinstance(self.train_times, dict):
      train_times = list(self.train_times.values())
      time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
      for name in self.eval_names:
        xtimes = [self.eval_times['{:}@{:}'.format(name, i)] for i in range(self.epochs)]
        time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
        time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
    else:
      time_info = {'T-train@epoch': None, 'T-train@total': None}
      for name in self.eval_names:
        time_info['T-{:}@epoch'.format(name)] = None
        time_info['T-{:}@total'.format(name)] = None
    return time_info

  def get_eval_set(self):
    return self.eval_names

  def get_train(self, iepoch=None):
    if iepoch is None: iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. epochs={:}'.format(iepoch, self.epochs)
    if self.train_times is not None:
      xtime = self.train_times[iepoch]
    else:
      xtime = None
    return {'iepoch'  : iepoch,
            'loss'    : self.train_losses[iepoch],
            'accuracy': self.train_acc1es[iepoch],
            'time'    : xtime}

  def get_eval(self, name, iepoch=None):
    if iepoch is None: iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. epochs={:}'.format(iepoch, self.epochs)
    if isinstance(self.eval_times, dict) and len(self.eval_times) > 0:
      xtime = self.eval_times['{:}@{:}'.format(name, iepoch)]
    else:
      xtime = None
    return {'iepoch'  : iepoch,
            'loss'    : self.eval_losses['{:}@{:}'.format(name, iepoch)],
            'accuracy': self.eval_acc1es['{:}@{:}'.format(name, iepoch)],
            'time'    : xtime}

  def get_net_param(self):
    return self.net_state_dict

  def get_config(self, str2structure):
    #return copy.deepcopy(self.arch_config)
    return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
            'N': self.arch_config['num_cells'],
            'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
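
  # Illustrative return value (channel / cell / class numbers depend on arch_config):
  #   {'name': 'infer.tiny', 'C': 16, 'N': 5, 'genotype': <Structure>, 'num_classes': 10}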

  def state_dict(self):
    _state_dict = {key: value for key, value in self.__dict__.items()}
    return _state_dict

  def load_state_dict(self, state_dict):
    self.__dict__.update(state_dict)

  @staticmethod
  def create_from_state_dict(state_dict):
    x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
    x.load_state_dict(state_dict)
    return x