update nasbench201 adj_mat and ops mat
This commit is contained in:
		| @@ -41,13 +41,13 @@ op_to_atom = { | |||||||
| } | } | ||||||
|  |  | ||||||
| op_type = { | op_type = { | ||||||
|  |     'input': 0, | ||||||
|     'nor_conv_1x1': 1, |     'nor_conv_1x1': 1, | ||||||
|     'nor_conv_3x3': 2, |     'nor_conv_3x3': 2, | ||||||
|     'avg_pool_3x3': 3, |     'avg_pool_3x3': 3, | ||||||
|     'skip_connect': 4, |     'skip_connect': 4, | ||||||
|     'output': 5, |     'none': 5, | ||||||
|     'none': 6, |     'output': 6, | ||||||
|     'input': 7 |  | ||||||
| } | } | ||||||
| class DataModule(AbstractDataModule): | class DataModule(AbstractDataModule): | ||||||
|     def __init__(self, cfg): |     def __init__(self, cfg): | ||||||
| @@ -130,20 +130,44 @@ class DataModule(AbstractDataModule): | |||||||
|         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) |         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) | ||||||
|         return train_index, val_index, test_index, [] |         return train_index, val_index, test_index, [] | ||||||
|  |  | ||||||
|     def parse_architecture_string(self, arch_str): |     # def parse_architecture_string(self, arch_str): | ||||||
|             stages = arch_str.split('+') |     #         stages = arch_str.split('+') | ||||||
|             nodes = ['input'] |     #         nodes = ['input'] | ||||||
|             edges = [] |     #         edges = [] | ||||||
|              |              | ||||||
|             for stage in stages: |     #         for stage in stages: | ||||||
|                 operations = stage.strip('|').split('|') |     #             operations = stage.strip('|').split('|') | ||||||
|                 for op in operations: |     #             for op in operations: | ||||||
|                     operation, idx = op.split('~') |     #                 operation, idx = op.split('~') | ||||||
|                     idx = int(idx) |     #                 idx = int(idx) | ||||||
|                     edges.append((idx, len(nodes)))  # Add edge from idx to the new node |     #                 edges.append((idx, len(nodes)))  # Add edge from idx to the new node | ||||||
|                     nodes.append(operation) |     #                 nodes.append(operation) | ||||||
|             nodes.append('output')  # Add the output node |     #         nodes.append('output')  # Add the output node | ||||||
|             return nodes, edges |     #         return nodes, edges | ||||||
|  |     def parse_architecture_string(arch_str): | ||||||
|  |         # print(arch_str) | ||||||
|  |         steps = arch_str.split('+') | ||||||
|  |         nodes = ['input']  # Start with input node | ||||||
|  |         adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|  |                             [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                             [0, 0, 0, 0, 0, 0, 0, 0]])  | ||||||
|  |         steps = arch_str.split('+') | ||||||
|  |         steps_coding = ['0', '0', '1', '0', '1', '2'] | ||||||
|  |         cont = 0 | ||||||
|  |         for step in steps: | ||||||
|  |             step =  step.strip('|').split('|') | ||||||
|  |             for node in step: | ||||||
|  |                 n, idx = node.split('~') | ||||||
|  |                 assert idx == steps_coding[cont] | ||||||
|  |                 cont += 1 | ||||||
|  |                 nodes.append(n) | ||||||
|  |         nodes.append('output')  # Add output node | ||||||
|  |         return nodes, adj_mat | ||||||
|  |  | ||||||
|     # def create_molecule_from_graph(nodes, edges): |     # def create_molecule_from_graph(nodes, edges): | ||||||
|     def create_molecule_from_graph(self, graph): |     def create_molecule_from_graph(self, graph): | ||||||
| @@ -182,11 +206,11 @@ class DataModule(AbstractDataModule): | |||||||
|          |          | ||||||
|         return mol |         return mol | ||||||
|  |  | ||||||
|     def arch_str_to_smiles(self, arch_str): |     # def arch_str_to_smiles(self, arch_str): | ||||||
|         nodes, edges = self.parse_architecture_string(arch_str) |     #     nodes, edges = self.parse_architecture_string(arch_str) | ||||||
|         mol = self.create_molecule_from_graph(nodes, edges) |     #     mol = self.create_molecule_from_graph(nodes, edges) | ||||||
|         smiles = Chem.MolToSmiles(mol) |     #     smiles = Chem.MolToSmiles(mol) | ||||||
|         return smiles |     #     return smiles | ||||||
|  |  | ||||||
|     def get_train_graphs(self): |     def get_train_graphs(self): | ||||||
|         train_graphs = [] |         train_graphs = [] | ||||||
| @@ -684,8 +708,9 @@ class Dataset(InMemoryDataset): | |||||||
|             for i in range(len_data): |             for i in range(len_data): | ||||||
|                 arch_info = self.api.query_meta_info_by_index(i) |                 arch_info = self.api.query_meta_info_by_index(i) | ||||||
|                 results = self.api.query_by_index(i, 'cifar100') |                 results = self.api.query_by_index(i, 'cifar100') | ||||||
|                 nodes, edges = parse_architecture_string(arch_info.arch_str) |                 # nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||||
|                 adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) |                 ops, adj_matrix = parse_architecture_string(arch_info.arch_str) | ||||||
|  |                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||||
|                 for op in ops: |                 for op in ops: | ||||||
|                     if op not in active_nodes: |                     if op not in active_nodes: | ||||||
|                         active_nodes.add(op) |                         active_nodes.add(op) | ||||||
| @@ -901,15 +926,26 @@ def parse_architecture_string(arch_str): | |||||||
|     # print(arch_str) |     # print(arch_str) | ||||||
|     steps = arch_str.split('+') |     steps = arch_str.split('+') | ||||||
|     nodes = ['input']  # Start with input node |     nodes = ['input']  # Start with input node | ||||||
|     edges = [] |     adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|     for i, step in enumerate(steps): |                         [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||||
|         step = step.strip('|').split('|') |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 0]])  | ||||||
|  |     steps = arch_str.split('+') | ||||||
|  |     steps_coding = ['0', '0', '1', '0', '1', '2'] | ||||||
|  |     cont = 0 | ||||||
|  |     for step in steps: | ||||||
|  |         step =  step.strip('|').split('|') | ||||||
|         for node in step: |         for node in step: | ||||||
|             op, idx = node.split('~') |             n, idx = node.split('~') | ||||||
|             edges.append((int(idx), i+1))  # i+1 because 0 is input node |             assert idx == steps_coding[cont] | ||||||
|             nodes.append(op) |             cont += 1 | ||||||
|  |             nodes.append(n) | ||||||
|     nodes.append('output')  # Add output node |     nodes.append('output')  # Add output node | ||||||
|     return nodes, edges |     return nodes, adj_mat | ||||||
|  |  | ||||||
| def create_adj_matrix_and_ops(nodes, edges): | def create_adj_matrix_and_ops(nodes, edges): | ||||||
|     num_nodes = len(nodes) |     num_nodes = len(nodes) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user