update the new graph to json function
This commit is contained in:
		| @@ -39,6 +39,16 @@ op_to_atom = { | |||||||
|     'none': 'S',           # Sulfur for no operation |     'none': 'S',           # Sulfur for no operation | ||||||
|     'output': 'He'         # Helium for output |     'output': 'He'         # Helium for output | ||||||
| } | } | ||||||
|  |  | ||||||
|  | op_type = { | ||||||
|  |     'nor_conv_1x1': 1, | ||||||
|  |     'nor_conv_3x3': 2, | ||||||
|  |     'avg_pool_3x3': 3, | ||||||
|  |     'skip_connect': 4, | ||||||
|  |     'output': 5, | ||||||
|  |     'none': 6, | ||||||
|  |     'input': 7 | ||||||
|  | } | ||||||
| class DataModule(AbstractDataModule): | class DataModule(AbstractDataModule): | ||||||
|     def __init__(self, cfg): |     def __init__(self, cfg): | ||||||
|         self.datadir = cfg.dataset.datadir |         self.datadir = cfg.dataset.datadir | ||||||
| @@ -343,6 +353,121 @@ class DataModule_original(AbstractDataModule): | |||||||
|     def test_dataloader(self): |     def test_dataloader(self): | ||||||
|         return self.test_loader |         return self.test_loader | ||||||
|  |  | ||||||
|  | def new_graphs_to_json(graphs, filename): | ||||||
|  |     source_name = "nasbench-201" | ||||||
|  |     num_graph = len(graphs) | ||||||
|  |  | ||||||
|  |     node_name_list = [] | ||||||
|  |     node_count_list = [] | ||||||
|  |      | ||||||
|  |     for op_name in op_type: | ||||||
|  |         node_name_list.append(op_name) | ||||||
|  |         node_count_list.append(0)  | ||||||
|  |      | ||||||
|  |     node_name_list.append('*') | ||||||
|  |     node_count_list.append(0) | ||||||
|  |     n_nodes_per_graph = [0] * num_graph | ||||||
|  |     edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | ||||||
|  |     valencies = [0] * (len(op_type) + 1) | ||||||
|  |     transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||||
|  |  | ||||||
|  |     n_node_list = [] | ||||||
|  |     n_edge_list = [] | ||||||
|  |  | ||||||
|  |     for graph in graphs: | ||||||
|  |         ops = graph[1] | ||||||
|  |         adj = graph[0] | ||||||
|  |  | ||||||
|  |         n_node = len(ops) | ||||||
|  |         n_edge = len(ops) | ||||||
|  |         n_node_list.append(n_node) | ||||||
|  |         n_edge_list.append(n_edge) | ||||||
|  |  | ||||||
|  |         n_nodes_per_graph[n_node] += 1 | ||||||
|  |         cur_node_count_arr = np.zeros(len(op_type) + 1) | ||||||
|  |  | ||||||
|  |         for op in ops: | ||||||
|  |             node = op | ||||||
|  |             if node == '*': | ||||||
|  |                 node_count_list[-1] += 1 | ||||||
|  |                 cur_node_count_arr[-1] += 1 | ||||||
|  |             else: | ||||||
|  |                 node_count_list[op_type[node]] += 1 | ||||||
|  |                 cur_node_count_arr[op_type[node]] += 1 | ||||||
|  |                 try: | ||||||
|  |                     valencies[int(op_type[node])] += 1 | ||||||
|  |                 except: | ||||||
|  |                     print('int(op_type[node])', int(op_type[node])) | ||||||
|  |          | ||||||
|  |         transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||||
|  |         for i in range(n_node): | ||||||
|  |             for j in range(n_node): | ||||||
|  |                 if i == j or adj[i][j] == 0: | ||||||
|  |                     continue | ||||||
|  |                 start_node, end_node = i, j | ||||||
|  |                  | ||||||
|  |                 start_index = op_type[ops[start_node]] | ||||||
|  |                 end_index = op_type[ops[end_node]] | ||||||
|  |                 bond_index = 1 | ||||||
|  |                 edge_count_list[bond_index] += 2 | ||||||
|  |                  | ||||||
|  |                 transition_E[start_index, end_index, bond_index] += 2 | ||||||
|  |                 transition_E[end_index, start_index, bond_index] += 2 | ||||||
|  |                 transition_E_temp[start_index, end_index, bond_index] += 2 | ||||||
|  |                 transition_E_temp[end_index, start_index, bond_index] += 2 | ||||||
|  |  | ||||||
|  |         edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2 | ||||||
|  |         cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2 | ||||||
|  |         print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") | ||||||
|  |         cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2 | ||||||
|  |         transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1) | ||||||
|  |         assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0 | ||||||
|  |      | ||||||
|  |     n_nodes_per_graph = np.array(n_nodes_per_graph) / np.sum(n_nodes_per_graph) | ||||||
|  |     n_nodes_per_graph = n_nodes_per_graph.tolist()[:51] | ||||||
|  |  | ||||||
|  |     node_count_list = np.array(node_count_list) / np.sum(node_count_list) | ||||||
|  |     print('processed meta info: ------', filename, '------') | ||||||
|  |     print('len node_count_list', len(node_count_list)) | ||||||
|  |     print('len node_name_list', len(node_name_list)) | ||||||
|  |     active_nodes = np.array(node_name_list)[node_count_list > 0] | ||||||
|  |     active_nodes = active_nodes.tolist() | ||||||
|  |     node_count_list = node_count_list.tolist() | ||||||
|  |  | ||||||
|  |     edge_count_list = np.array(edge_count_list) / np.sum(edge_count_list) | ||||||
|  |     edge_count_list = edge_count_list.tolist() | ||||||
|  |     valencies = np.array(valencies) / np.sum(valencies) | ||||||
|  |     valencies = valencies.tolist() | ||||||
|  |  | ||||||
|  |     no_edge = np.sum(transition_E, axis=-1) == 0 | ||||||
|  |     first_elt = transition_E[:, :, 0] | ||||||
|  |     first_elt[no_edge] = 1 | ||||||
|  |     transition_E[:, :, 0] = first_elt | ||||||
|  |  | ||||||
|  |     transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True) | ||||||
|  |  | ||||||
|  |     meta_dict = { | ||||||
|  |         'source': source_name, | ||||||
|  |         'num_graph': num_graph, | ||||||
|  |         'n_nodes_per_graph': n_nodes_per_graph, | ||||||
|  |         'max_n_nodes': max(n_node_list), | ||||||
|  |         'max_n_edges': max(n_edge_list), | ||||||
|  |         'node_type_list': node_count_list, | ||||||
|  |         'edge_type_list': edge_count_list, | ||||||
|  |         'valencies': valencies, | ||||||
|  |         'active_nodes': active_nodes, | ||||||
|  |         'num_active_nodes': len(active_nodes), | ||||||
|  |         'transition_E': transition_E.tolist(), | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     with open(f'{filename}.meta.json', 'w') as f: | ||||||
|  |         json.dump(meta_dict, f) | ||||||
|  |      | ||||||
|  |     return meta_dict | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def graphs_to_json(graphs, filename): | def graphs_to_json(graphs, filename): | ||||||
|     bonds = { |     bonds = { | ||||||
|         'nor_conv_1x1': 1, |         'nor_conv_1x1': 1, | ||||||
| @@ -490,7 +615,7 @@ def graphs_to_json(graphs, filename): | |||||||
|         'atom_type_dist': atom_count_list, |         'atom_type_dist': atom_count_list, | ||||||
|         'bond_type_dist': bond_count_list, |         'bond_type_dist': bond_count_list, | ||||||
|         'valencies': valencies, |         'valencies': valencies, | ||||||
|         'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0], |         'active_nodes': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0], | ||||||
|         'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]), |         'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]), | ||||||
|         'transition_E': transition_E.tolist(), |         'transition_E': transition_E.tolist(), | ||||||
|     } |     } | ||||||
| @@ -503,10 +628,10 @@ class Dataset(InMemoryDataset): | |||||||
|         self.target_prop = target_prop |         self.target_prop = target_prop | ||||||
|         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|         self.source = source |         self.source = source | ||||||
|         super().__init__(root, transform, pre_transform, pre_filter) |  | ||||||
|         print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt |  | ||||||
|         self.api = API(source)  # Initialize NAS-Bench-201 API |         self.api = API(source)  # Initialize NAS-Bench-201 API | ||||||
|         print('API loaded') |         print('API loaded') | ||||||
|  |         super().__init__(root, transform, pre_transform, pre_filter) | ||||||
|  |         print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt | ||||||
|         print('Dataset initialized') |         print('Dataset initialized') | ||||||
|         self.data, self.slices = torch.load(self.processed_paths[0]) |         self.data, self.slices = torch.load(self.processed_paths[0]) | ||||||
|         self.data.edge_attr = self.data.edge_attr.squeeze() |         self.data.edge_attr = self.data.edge_attr.squeeze() | ||||||
| @@ -732,30 +857,35 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|             arch_info = self.api.query_meta_info_by_index(i) |             arch_info = self.api.query_meta_info_by_index(i) | ||||||
|             nodes, edges = parse_architecture_string(arch_info.arch_str) |             nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||||
|             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     |             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     | ||||||
|             if i < 5: |             # if i < 5: | ||||||
|                 print("Adjacency Matrix:") |             #     print("Adjacency Matrix:") | ||||||
|                 print(adj_matrix) |             #     print(adj_matrix) | ||||||
|                 print("Operations List:") |             #     print("Operations List:") | ||||||
|                 print(ops) |             #     print(ops) | ||||||
|             for op in ops: |             for op in ops: | ||||||
|                 if op not in ops_type: |                 if op not in ops_type: | ||||||
|                     ops_type[op] = len(ops_type) |                     ops_type[op] = len(ops_type) | ||||||
|             len_ops.add(len(ops)) |             len_ops.add(len(ops)) | ||||||
|             graphs.append((adj_matrix, ops)) |             graphs.append((adj_matrix, ops)) | ||||||
|  |  | ||||||
|         meta_dict = graphs_to_json(graphs, 'nasbench-201') |         # check first five graphs | ||||||
|  |         for i in range(5): | ||||||
|  |             print(f'graph {i} : {graphs[i]}') | ||||||
|  |         print(f'ops_type: {ops_type}') | ||||||
|  |  | ||||||
|  |         meta_dict = new_graphs_to_json(graphs, 'nasbench-201') | ||||||
|         self.base_path = base_path |         self.base_path = base_path | ||||||
|         self.active_atoms = meta_dict['active_atoms'] |         self.active_nodes = meta_dict['active_nodes'] | ||||||
|         self.max_n_nodes = meta_dict['max_node'] |         self.max_n_nodes = meta_dict['max_n_nodes'] | ||||||
|         self.original_max_n_nodes = meta_dict['max_node'] |         self.original_max_n_nodes = meta_dict['max_n_nodes'] | ||||||
|         self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist']) |         self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph']) | ||||||
|         self.edge_types = torch.Tensor(meta_dict['bond_type_dist']) |         self.edge_types = torch.Tensor(meta_dict['edge_type_dist']) | ||||||
|         self.transition_E = torch.Tensor(meta_dict['transition_E']) |         self.transition_E = torch.Tensor(meta_dict['transition_E']) | ||||||
|  |  | ||||||
|         self.atom_decoder = meta_dict['active_atoms'] |         self.node_decoder = meta_dict['active_nodes'] | ||||||
|         node_types = torch.Tensor(meta_dict['atom_type_dist']) |         node_types = torch.Tensor(meta_dict['node_type_dist']) | ||||||
|         active_index = (node_types > 0).nonzero().squeeze() |         active_index = (node_types > 0).nonzero().squeeze() | ||||||
|         self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index] |         self.node_types = torch.Tensor(meta_dict['node_type_dist'])[active_index] | ||||||
|         self.nodes_dist = DistributionNodes(self.n_nodes) |         self.nodes_dist = DistributionNodes(self.n_nodes) | ||||||
|         self.active_index = active_index |         self.active_index = active_index | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user