update NAS-Bench-201 API to support str2structure

This commit is contained in:
D-X-Y 2020-01-18 16:37:28 +11:00
parent f49f8c7451
commit 29565cd943

View File

@ -251,6 +251,26 @@ class NASBench201API(object):
else: else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
# This func shows how to read the string0based architecture encoding
# the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
# Usage:
# arch = api.str2structure( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
# print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
# for i, node in enumerate(arch):
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
@staticmethod
def str2structure(xstr):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return genotypes
class ArchResults(object): class ArchResults(object):