33 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			33 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | |
| ##################################################
 | |
| import torch
 | |
| from os import path as osp
 | |
| 
 | |
| __all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
 | |
| 
 | |
| 
 | |
| # the cell-based NAS models
 | |
def get_cell_based_tiny_net(config):
  """Construct the cell-based network selected by `config.name`.

  Args:
    config: object exposing `name`, `space`, `C`, `N`, `max_nodes`,
      `num_classes` and `affine`. `space` may be a string key of
      `SearchSpaceNames`, in which case it is resolved through that mapping;
      any other value is forwarded unchanged.

  Returns:
    Whatever the `nas_super_nets[config.name]` entry returns when called with
    `(C, N, max_nodes, num_classes, search_space, affine)`.

  Raises:
    ValueError: if `config.name` is neither 'GDAS' nor 'DARTS'.
  """
  supported_names = ('GDAS', 'DARTS')
  # Guard clause: reject unknown names before touching the model sub-modules.
  if config.name not in supported_names:
    raise ValueError('invalid network name : {:}'.format(config.name))

  # Function-scope imports: only executed for a supported network name.
  from .cell_searchs import nas_super_nets
  from .cell_operations import SearchSpaceNames

  space = config.space
  if isinstance(space, str):
    # A string is treated as the name of a predefined search space.
    space = SearchSpaceNames[space]

  build_network = nas_super_nets[config.name]
  return build_network(
    config.C,
    config.N,
    config.max_nodes,
    config.num_classes,
    space,
    config.affine,
  )
 | |
| 
 | |
| 
 | |
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
 | |
def get_search_spaces(xtype, name):
  """Return the search space registered under `name` for the given space type.

  Args:
    xtype: kind of search space; only 'cell' is handled.
    name: key looked up in `SearchSpaceNames`.

  Returns:
    The `SearchSpaceNames[name]` entry.

  Raises:
    ValueError: if `xtype` is not 'cell'.
    AssertionError: if `name` is not a key of `SearchSpaceNames`.
  """
  if xtype != 'cell':
    raise ValueError('invalid search-space type is {:}'.format(xtype))

  # Function-scope import: only executed on the 'cell' path.
  from .cell_operations import SearchSpaceNames

  if name not in SearchSpaceNames:
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`; the exception type and message match the former assert.
    raise AssertionError('invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()))
  return SearchSpaceNames[name]
 |