Support accumulate and reset time function for API
This commit is contained in:
		| @@ -55,10 +55,16 @@ def get_cell_based_tiny_net(config): | ||||
|  | ||||
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | ||||
| def get_search_spaces(xtype, name) -> List[Text]: | ||||
|   if xtype == 'cell': | ||||
|   if xtype == 'cell' or xtype == 'tss':  # The topology search space. | ||||
|     from .cell_operations import SearchSpaceNames | ||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) | ||||
|     return SearchSpaceNames[name] | ||||
|   elif xtype == 'sss':  # The size search space. | ||||
|     if name == 'nas-bench-301': | ||||
|       return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64], | ||||
|               'numbers': 5} | ||||
|     else: | ||||
|       raise ValueError('Invalid name : {:}'.format(name)) | ||||
|   else: | ||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) | ||||
|  | ||||
|   | ||||
| @@ -26,6 +26,7 @@ DARTS_SPACE           = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||
|                     'nas-bench-201': NAS_BENCH_201, | ||||
|                     'nas-bench-301': NAS_BENCH_201, | ||||
|                     'darts'        : DARTS_SPACE} | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user