Update find_best API
This commit is contained in:
		| @@ -92,6 +92,10 @@ class ImageNet16(data.Dataset): | ||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|     #print ('Std  : {:}'.format(std_data)) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) | ||||
|  | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     img, target = self.data[index], self.targets[index] - 1 | ||||
|  | ||||
| @@ -114,16 +118,16 @@ class ImageNet16(data.Dataset): | ||||
|         return False | ||||
|     return True | ||||
|  | ||||
| # | ||||
| """ | ||||
| if __name__ == '__main__': | ||||
|   train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)  | ||||
|   valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)  | ||||
|   train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)  | ||||
|   valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)  | ||||
|  | ||||
|   print ( len(train) ) | ||||
|   print ( len(valid) ) | ||||
|   image, label = train[111] | ||||
|   trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) | ||||
|   validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) | ||||
|   trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200) | ||||
|   validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200) | ||||
|   print ( len(trainX) ) | ||||
|   print ( len(validX) ) | ||||
|   #import pdb; pdb.set_trace() | ||||
| """ | ||||
|   | ||||
| @@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     evaluated_indexes = sorted(list(self.evaluated_indexes)) | ||||
|     for arch_index in evaluated_indexes: | ||||
|       self._prepare_info(arch_index) | ||||
|       arch_info = self.arch2infos_dict[arch_index][hp] | ||||
|       info = arch_info.get_compute_costs(dataset)  # the information of costs | ||||
|       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||
| @@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|         print('<' * 40 + '------------' + '<' * 40) | ||||
|     else: | ||||
|       if 0 <= index < len(self.meta_archs): | ||||
|         if index not in self.evaluated_indexes: | ||||
|           self._prepare_info(index) | ||||
|         if index not in self.evaluated_indexes: | ||||
|           print('The {:}-th architecture has not been evaluated ' | ||||
|                 'or not saved.'.format(index)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user