##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp

__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']


# the cell-based NAS models
def get_cell_based_tiny_net(config):
    """Build a cell-based NAS super-network described by ``config``.

    Args:
        config: object exposing the attributes ``name``, ``space``, ``C``,
            ``N``, ``max_nodes``, ``num_classes`` and ``affine``. ``space``
            is either a string key into ``SearchSpaceNames`` or an
            already-resolved search space that is passed through unchanged.

    Returns:
        The object produced by ``nas_super_nets[config.name](C, N, max_nodes,
        num_classes, search_space, affine)``.

    Raises:
        ValueError: if ``config.name`` is not one of the supported names.
        Presumably KeyError (from indexing ``SearchSpaceNames``) if
        ``config.space`` is a string that is not a registered search space.
    """
    group_names = ['GDAS', 'DARTS']
    # Reject unsupported names before touching the model modules.
    if config.name not in group_names:
        raise ValueError('invalid network name : {:}'.format(config.name))

    # Imported lazily, only once a supported name has been confirmed.
    from .cell_searchs import nas_super_nets
    from .cell_operations import SearchSpaceNames

    # ``space`` may be given by name (resolved here) or as the space itself.
    if isinstance(config.space, str):
        search_space = SearchSpaceNames[config.space]
    else:
        search_space = config.space
    return nas_super_nets[config.name](
        config.C, config.N, config.max_nodes, config.num_classes,
        search_space, config.affine)


# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
    """Return the search space registered under ``name`` for family ``xtype``.

    Args:
        xtype: search-space family; only ``'cell'`` is supported.
        name: key into ``SearchSpaceNames``.

    Returns:
        ``SearchSpaceNames[name]``.

    Raises:
        ValueError: if ``xtype`` is not ``'cell'`` or ``name`` is not a key
            of ``SearchSpaceNames``.
    """
    if xtype != 'cell':
        raise ValueError('invalid search-space type is {:}'.format(xtype))
    from .cell_operations import SearchSpaceNames
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop this validation.
    if name not in SearchSpaceNames:
        raise ValueError('invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()))
    return SearchSpaceNames[name]