diff --git a/lib/nas_201_api/api.py b/lib/nas_201_api/api.py index 15566c4..f7c44cc 100644 --- a/lib/nas_201_api/api.py +++ b/lib/nas_201_api/api.py @@ -251,6 +251,26 @@ class NASBench201API(object): else: print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) + # This func shows how to read the string0based architecture encoding + # the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py` + # Usage: + # arch = api.str2structure( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) + # print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list + # for i, node in enumerate(arch): + # print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) + @staticmethod + def str2structure(xstr): + assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) + nodestrs = xstr.split('+') + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != '', node_str.split('|'))) + for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) + inputs = ( xi.split('~') for xi in inputs ) + input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) + genotypes.append( input_infos ) + return genotypes + class ArchResults(object):