update NAS-Bench-201 API to support str2structure
This commit is contained in:
parent
f49f8c7451
commit
29565cd943
@ -251,6 +251,26 @@ class NASBench201API(object):
|
|||||||
else:
|
else:
|
||||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||||
|
|
||||||
|
# This func shows how to read the string0based architecture encoding
|
||||||
|
# the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
|
||||||
|
# Usage:
|
||||||
|
# arch = api.str2structure( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
|
||||||
|
# print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
|
||||||
|
# for i, node in enumerate(arch):
|
||||||
|
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
|
||||||
|
@staticmethod
|
||||||
|
def str2structure(xstr):
|
||||||
|
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||||
|
nodestrs = xstr.split('+')
|
||||||
|
genotypes = []
|
||||||
|
for i, node_str in enumerate(nodestrs):
|
||||||
|
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||||
|
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||||
|
inputs = ( xi.split('~') for xi in inputs )
|
||||||
|
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||||
|
genotypes.append( input_infos )
|
||||||
|
return genotypes
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ArchResults(object):
|
class ArchResults(object):
|
||||||
|
Loading…
Reference in New Issue
Block a user