888 lines
46 KiB
Plaintext
888 lines
46 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from nas_201_api import NASBench201API as API"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"api = API('./NAS-Bench-201-v1_1-096897.pth', verbose=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Query all recorded training runs (one entry per seed) for architecture
# index 1 on CIFAR-100 and print the per-seed metrics.
results = api.query_by_index(1, 'cifar100')
print('There are {:} trials for this architecture [{:}] on CIFAR-100'.format(len(results), api[1]))
# 
for seed, result in results.items():
    print('Latency : {:}'.format(result.get_latency()))
    print('Train Info : {:}'.format(result.get_train()))
    print('Valid Info : {:}'.format(result.get_eval('x-valid')))
    print('Test Info : {:}'.format(result.get_eval('x-test')))
    print('')
    # get_train(10) returns the training info recorded at epoch 10.
    print('Train Info [10-th epoch]: {:}'.format(result.get_train(10)))
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.path.append('../') \n",
|
|
"\n",
|
|
"import os\n",
|
|
"import os.path as osp\n",
|
|
"import pathlib\n",
|
|
"import json\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from rdkit import Chem, RDLogger\n",
|
|
"from rdkit.Chem.rdchem import BondType as BT\n",
|
|
"from tqdm import tqdm\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from torch_geometric.data import Data, InMemoryDataset\n",
|
|
"from torch_geometric.loader import DataLoader\n",
|
|
"from sklearn.model_selection import train_test_split\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def random_data_split(task, dataset):
    """Randomly split the labeled part of `dataset` into 60/20/20 train/val/test.

    Graphs whose first target (`dataset.y[:, 0]`) is NaN are treated as
    unlabeled and returned separately (they occupy the tail of the dataset).

    Args:
        task: task name, used only for the summary printout.
        dataset: object exposing `y` (tensor of targets) and `__len__`.

    Returns:
        (train_index, val_index, test_index, unlabeled_index) lists of indices.
    """
    num_nan = torch.isnan(dataset.y[:, 0]).sum().item()
    num_labeled = len(dataset) - num_nan
    labeled_idx = list(range(num_labeled))
    train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
    # First carve off the test set, then split the remainder into train/val.
    train_index, test_index, _, _ = train_test_split(
        labeled_idx, labeled_idx, test_size=test_ratio, random_state=42)
    train_index, val_index, _, _ = train_test_split(
        train_index, train_index,
        test_size=valid_ratio / (valid_ratio + train_ratio), random_state=42)
    unlabeled_index = list(range(num_labeled, len(dataset)))
    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
    return train_index, val_index, test_index, unlabeled_index
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def parse_architecture_string(arch_str):
    """Parse a NAS-Bench-201 architecture string into a node/edge graph.

    The string has the form '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|': stages are
    separated by '+', and each 'op~idx' entry creates a new node whose single
    incoming edge comes from node `idx` (node 0 is the input).

    Args:
        arch_str: NAS-Bench-201 architecture string.

    Returns:
        (nodes, edges): `nodes` is ['input', op, op, ..., 'output'];
        `edges` is a list of (source_index, destination_index) tuples.
    """
    steps = arch_str.split('+')
    nodes = ['input']  # node 0 is the input
    edges = []
    for step in steps:
        for entry in step.strip('|').split('|'):
            op, idx = entry.split('~')
            # BUG FIX: the destination must be the index of the node about to
            # be appended (len(nodes)), not the stage index: stages after the
            # first contain several ops, each of which creates its own node.
            # (This now matches Dataset.process's parser in this notebook.)
            edges.append((int(idx), len(nodes)))
            nodes.append(op)
    nodes.append('output')  # Add output node
    # (Removed the debug `print(arch_str)` that fired once per architecture.)
    return nodes, edges
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def create_adj_matrix_and_ops(nodes, edges):
    """Build a dense directed adjacency matrix for an architecture graph.

    Args:
        nodes: list of operation names, one per node.
        edges: list of (source, destination) index pairs.

    Returns:
        (adj, nodes): `adj` is an int numpy array with adj[src, dst] == 1 for
        every edge; `nodes` is returned unchanged.
    """
    n = len(nodes)
    adj = np.zeros((n, n), dtype=int)
    if edges:
        rows, cols = zip(*edges)
        adj[list(rows), list(cols)] = 1
    return adj, nodes
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Build (adjacency_matrix, operations) tuples for every NAS-Bench-201
# architecture, while collecting the distinct operation names and the
# distinct graph sizes seen.
graphs = []
length = 15625  # total number of NAS-Bench-201 architectures queried here
ops_type = {}   # operation name -> integer id, in order of first appearance
len_ops = set() # distinct node counts observed across all graphs
for i in range(length):
    arch_info = api.query_meta_info_by_index(i)
    nodes, edges = parse_architecture_string(arch_info.arch_str)
    adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) 
    # Only echo the first few graphs to avoid flooding the output.
    if i < 5:
        print("Adjacency Matrix:")
        print(adj_matrix)
        print("Operations List:")
        print(ops)
    for op in ops:
        if op not in ops_type:
            ops_type[op] = len(ops_type)
    len_ops.add(len(ops))
    graphs.append((adj_matrix, ops))
print(graphs[0])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Summarize the vocabulary collected above: how many distinct operations and
# graph sizes exist, and what they are.
print(len(ops_type))
print(len(len_ops))
print(ops_type)
print(len_ops)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 60,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Map each NAS-Bench-201 operation name to a chemical element symbol so that
# an architecture graph can be encoded as an RDKit "molecule".
op_to_atom = {
    'input': 'Si',         # Silicon for input (original comment said "Hydrogen", but the symbol is Si)
    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution
    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution
    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling
    'skip_connect': 'P',   # Phosphorus for skip connection
    'none': 'S',           # Sulfur for no operation
    'output': 'He'         # Helium for output
}
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 61,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def graphs_to_json(graphs, filename):
    """Compute dataset-level statistics over architecture graphs and dump them.

    Mirrors the molecular-dataset metadata format: atom-type distribution,
    bond-type distribution, valency distribution, nodes-per-graph histogram,
    and the edge-type transition tensor `transition_E` (channel 0 = "no bond").
    Architectures are mapped to "molecules" via `op_to_atom` / `bonds`.

    Args:
        graphs: list of (adjacency_matrix, ops) tuples.
        filename: base name of the output file; writes '<filename>.meta.json'.
            (BUG FIX: the original ignored this parameter and wrote to a
            placeholder path.)

    Returns:
        The metadata dict that was written to disk.
    """
    # Edge-type code taken from the operation of the edge's *destination* node.
    bonds = {
        'nor_conv_1x1': 1,
        'nor_conv_3x3': 2,
        'avg_pool_3x3': 3,
        'skip_connect': 4,
        'input': 7,
        'output': 5,
        'none': 6
    }

    num_graph = len(graphs)
    pt = Chem.GetPeriodicTable()

    # Element symbols for atomic numbers 2..118 (index = atomic number - 2),
    # plus a trailing wildcard slot '*'.
    atom_name_list = [pt.GetElementSymbol(i) for i in range(2, 119)]
    atom_count_list = [0] * len(atom_name_list)
    atom_name_list.append('*')
    atom_count_list.append(0)

    n_atoms_per_mol = [0] * 500           # histogram of nodes per graph
    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]  # index 0 = "no bond"
    valencies = [0] * 500                 # histogram over default valences
    transition_E = np.zeros((118, 118, 5))

    n_atom_list = []
    n_bond_list = []
    # graphs = [(adj_matrix, ops), ...]
    for graph in graphs:
        adj, ops = graph[0], graph[1]
        n_atom = len(ops)
        n_bond = len(ops)
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)

        n_atoms_per_mol[n_atom] += 1
        cur_atom_count_arr = np.zeros(118)

        for op in ops:
            symbol = op_to_atom[op]
            if symbol == 'H':
                continue
            elif symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
            else:
                atom_count_list[pt.GetAtomicNumber(symbol) - 2] += 1
                cur_atom_count_arr[pt.GetAtomicNumber(symbol) - 2] += 1
            # Valency histogram is indexed by the element's default valence.
            # (BUG FIX: bare `except:` narrowed to Exception.)
            try:
                valencies[int(pt.GetDefaultValence(symbol))] += 1
            except Exception:
                print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))

        transition_E_temp = np.zeros((118, 118, 5))
        for i in range(n_atom):
            for j in range(n_atom):
                if i == j or adj[i][j] == 0:
                    continue
                start_atom, end_atom = i, j
                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
                bond_index = bonds[ops[end_atom]]
                # Each directed edge contributes 2 (counted symmetrically).
                bond_count_list[bond_index] += 2
                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        # All remaining ordered node pairs count as "no bond" (channel 0).
        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        # (Removed the per-graph debug `print(bond_count_list)` flood.)
        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
        # Channel 0 gets whatever pair capacity the real bonds did not use.
        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
        # (Removed a vacuous `assert ... >= 0` whose failure message referenced
        # an undefined name `sms` and could never trigger anyway.)

    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]

    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_list))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
    active_atoms = active_atoms.tolist()
    atom_count_list = atom_count_list.tolist()

    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
    bond_count_list = bond_count_list.tolist()
    valencies = np.array(valencies) / np.sum(valencies)
    valencies = valencies.tolist()

    # Atom pairs with no edges at all get probability 1 on the "no bond"
    # channel so the row-normalization below never divides by zero.
    no_edge = np.sum(transition_E, axis=-1) == 0
    for i in range(118):
        for j in range(118):
            if no_edge[i][j] == False:
                print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
    first_elt = transition_E[:, :, 0]
    first_elt[no_edge] = 1
    transition_E[:, :, 0] = first_elt

    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    # Diagnostic: report any "no bond" probability that is not exactly 0 or 1.
    for i in range(118):
        for j in range(118):
            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')

    meta_dict = {
        'source': 'nasbench-201',
        'num_graph': num_graph,
        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
        'max_node': max(n_atom_list),
        'max_bond': max(n_bond_list),
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
        'transition_E': transition_E.tolist(),
    }

    # BUG FIX: write to the path derived from `filename` (the original wrote
    # to a hardcoded placeholder and never used the parameter).
    with open(f'{filename}.meta.json', 'w') as f:
        json.dump(meta_dict, f)
    return meta_dict
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"graphs_to_json(graphs, 'nasbench-201')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 62,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def gen_adj_matrix_and_ops(nasbench):
    """Iterate over benchmark hashes and fetch their metrics.

    NOTE(review): this function looks unfinished — the loop fetches metrics
    but never uses them, and `i` / `epoch` are unused. The `hash_iterator` /
    `get_metrics_from_hash` calls look like a NAS-Bench-101-style API rather
    than the NAS-Bench-201 API used elsewhere in this notebook — confirm
    intent before relying on it.
    """
    i = 0
    epoch = 108

    for unique_hash in nasbench.hash_iterator():
        fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(unique_hash)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 63,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch_geometric.data import InMemoryDataset, Data\n",
|
|
"import os.path as osp\n",
|
|
"import pandas as pd\n",
|
|
"from tqdm import tqdm\n",
|
|
"import networkx as nx\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
class Dataset(InMemoryDataset):
    """PyTorch-Geometric in-memory dataset of NAS-Bench-201 architectures.

    Each architecture is converted to a `Data` graph whose node features are
    integer operation codes (see `bonds` in `process`) and whose `y` row holds
    five random placeholder targets.
    """

    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        # NOTE(review): the `source` argument is immediately overwritten by a
        # hardcoded path, so callers cannot actually choose the benchmark file.
        source = './NAS-Bench-201-v1_1-096897.pth'
        self.source = source
        self.api = API(source) # Initialize NAS-Bench-201 API
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [] # NAS-Bench-201 data is loaded via the API, no raw files needed
    
    @property
    def processed_file_names(self):
        # Cache file name is derived from the (hardcoded) source path.
        return [f'{self.source}.pt']

    def process(self):
        """Convert every NAS-Bench-201 architecture to a `Data` graph and cache them."""
        def parse_architecture_string(arch_str):
            # Split the '|op~idx|' notation: each entry adds one node with a
            # single incoming edge from node `idx`; node 0 is 'input'.
            stages = arch_str.split('+')
            nodes = ['input']
            edges = []
            
            for stage in stages:
                operations = stage.strip('|').split('|')
                for op in operations:
                    operation, idx = op.split('~')
                    idx = int(idx)
                    edges.append((idx, len(nodes))) # Add edge from idx to the new node
                    nodes.append(operation)
            nodes.append('output') # Add the output node
            return nodes, edges

        def create_graph(nodes, edges):
            # NOTE(review): defined but never called in this method.
            G = nx.DiGraph()
            for i, node in enumerate(nodes):
                G.add_node(i, label=node)
            G.add_edges_from(edges)
            return G

        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
            # Build one torch_geometric Data object from an architecture string.
            nodes, edges = parse_architecture_string(arch_str)

            node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
            assert 0 not in node_labels, f'Invalid node label: {node_labels}'
            x = torch.LongTensor(node_labels)

            edges_list = [(start, end) for start, end in edges]
            edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type.view(-1, 1)

            # y packs [sa, sc, target(, target2(, target3))] as a single row.
            if target3 is not None:
                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
            elif target2 is not None:
                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
            else:
                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            return data, nodes

        # Integer code per operation; 0 is reserved (asserted above).
        bonds = {
            'nor_conv_1x1': 1,
            'nor_conv_3x3': 2,
            'avg_pool_3x3': 3,
            'skip_connect': 4,
            'output': 5,
            'none': 6,
            'input': 7
        }

        # Prepare to process NAS-Bench-201 data
        data_list = []
        len_data = len(self.api) # Number of architectures
        with tqdm(total=len_data) as pbar:
            for arch_index in range(len_data):
                arch_info = self.api.query_meta_info_by_index(arch_index)
                arch_str = arch_info.arch_str
                # All five targets are random placeholders, not real metrics.
                sa = np.random.rand() # Placeholder for synthetic accessibility
                sc = np.random.rand() # Placeholder for substructure count
                target = np.random.rand() # Placeholder for target value
                target2 = np.random.rand() # Placeholder for second target value
                target3 = np.random.rand() # Placeholder for third target value

                data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
                data_list.append(data)
                pbar.update(1)

        torch.save(self.collate(data_list), self.processed_paths[0])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 84,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import pathlib\n",
|
|
"import torch\n",
|
|
"from torch_geometric.data import DataLoader\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"import pandas as pd\n",
|
|
"from tqdm import tqdm\n",
|
|
"# import nas_bench_201 as nb201\n",
|
|
"\n",
|
|
"import utils as utils\n",
|
|
"from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule\n",
|
|
"from diffusion.distributions import DistributionNodes\n",
|
|
"\n",
|
|
"from rdkit.Chem import rdchem\n",
|
|
"\n",
|
|
class DataModule(AbstractDataModule):
    """Data module wrapping the NAS-Bench-201 `Dataset`.

    Builds train/val/test splits and dataloaders, and can render architecture
    graphs as SMILES strings via the `op_to_atom` element encoding.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        """Load the dataset, split it, and build the three dataloaders."""
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)
        # TODO(review): hardcoded absolute path — derive from __file__ / cwd
        # (as the original commented-out code attempted) for portability.
        base_path = '/home/stud/hanzhang/Graph-Dit'
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # Load the dataset into memory. NOTE: Dataset.__init__ overwrites
        # `source` internally with its own hardcoded path.
        source = './NAS-Bench-201-v1_1-096897.pth'
        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)

        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
        train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
        if len(unlabeled_index) > 0:
            # Unlabeled graphs are folded into the training set.
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)

        self.training_iterations = len(train_dataset) // batch_size

    def random_data_split(self, dataset):
        """Random 60/20/20 split of labeled graphs; NaN-labeled graphs become `unlabeled`."""
        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
        labeled_len = len(dataset) - nan_count
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        """Fixed test indices for the 'O2-N2' task; the rest is split 80/20 train/val."""
        if self.task == 'O2-N2':
            test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        return train_index, val_index, test_index, []

    @staticmethod
    def parse_architecture_string(arch_str):
        """Parse a NAS-Bench-201 architecture string into (nodes, edges).

        BUG FIX: this was defined without `self` (and without @staticmethod),
        so `self.parse_architecture_string(arch_str)` bound the instance to
        `arch_str` and crashed; @staticmethod makes both call styles work.
        """
        stages = arch_str.split('+')
        nodes = ['input']
        edges = []
        for stage in stages:
            operations = stage.strip('|').split('|')
            for op in operations:
                operation, idx = op.split('~')
                idx = int(idx)
                edges.append((idx, len(nodes)))  # edge from idx to the node about to be added
                nodes.append(operation)
        nodes.append('output')  # Add the output node
        return nodes, edges

    @staticmethod
    def create_molecule_from_graph(nodes, edges):
        """Convert an encoded graph to an RDKit molecule.

        BUG FIX: also defined without `self`/@staticmethod originally, so the
        `self.create_molecule_from_graph(...)` calls below passed the instance
        as `nodes`.

        Args:
            nodes: per-node integer operation codes; elements must support
                `.item()` (assumes a 1-D tensor — TODO confirm).
            edges: 2 x E edge-index structure supporting `.shape` and
                `[row, col].item()` (assumes a COO edge_index tensor).
        """
        mol = Chem.RWMol()  # RWMol allows building the molecule step by step
        atom_indices = {}
        # Inverse of the integer node encoding used by Dataset.process.
        num_to_op = {
            1: 'nor_conv_1x1',
            2: 'nor_conv_3x3',
            3: 'avg_pool_3x3',
            4: 'skip_connect',
            5: 'output',
            6: 'none',
            7: 'input'
        }

        # Add one atom per node; code 0 is skipped (presumably padding — confirm).
        for i, op_tensor in enumerate(nodes):
            op = op_tensor.item()
            if op == 0:
                continue
            atom_symbol = op_to_atom[num_to_op[op]]
            atom_indices[i] = mol.AddAtom(Chem.Atom(atom_symbol))

        # Add a SINGLE bond per edge.
        for k in range(edges.shape[1]):
            start = edges[0, k].item()
            end = edges[1, k].item()
            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)

        return mol

    def arch_str_to_smiles(self, arch_str):
        """Architecture string -> SMILES via the molecule encoding.

        NOTE(review): parse_architecture_string returns string nodes and tuple
        edges, but create_molecule_from_graph expects tensors — this path
        looks broken and appears unused; confirm before relying on it.
        """
        nodes, edges = self.parse_architecture_string(arch_str)
        mol = self.create_molecule_from_graph(nodes, edges)
        smiles = Chem.MolToSmiles(mol)
        return smiles

    def get_train_smiles(self):
        """Render train/test graphs as SMILES strings.

        (Removed the per-graph debug prints and dead commented-out code.)
        NOTE(review): self.train_index / self.test_index are indices into the
        *full* dataset but are used to index self.train_dataset (already a
        subset) — verify this is intended.
        """
        train_smiles = []
        test_smiles = []

        for idx in self.train_index:
            graph = self.train_dataset[idx]
            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
            train_smiles.append(Chem.MolToSmiles(mol))

        for idx in self.test_index:
            graph = self.train_dataset[idx]
            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
            test_smiles.append(Chem.MolToSmiles(mol))

        return train_smiles, test_smiles

    def get_data_split(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def example_batch(self):
        """One batch from the validation loader, e.g. for shape inspection."""
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 85,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from omegaconf import DictConfig, OmegaConf\n",
|
|
"import argparse\n",
|
|
"import hydra\n",
|
|
"\n",
|
|
def parse_arg():
    """Build and parse the command-line arguments for the diffusion run.

    Returns:
        argparse.Namespace with a single `config` attribute (path to the
        YAML config file, default 'config.yaml').
    """
    p = argparse.ArgumentParser(description='Diffusion')
    p.add_argument('--config', type=str, default='config.yaml', help='config file')
    return p.parse_args()
|
|
"\n",
|
|
def task1(cfg: DictConfig):
    """Construct the NAS-Bench-201 DataModule and prepare its data splits.

    Args:
        cfg: full experiment configuration (OmegaConf DictConfig).

    Returns:
        The prepared DataModule instance.
    """
    module = DataModule(cfg=cfg)
    module.prepare_data()
    return module
|
|
"\n",
|
|
# Inline experiment configuration mirroring the project's Hydra/OmegaConf YAML.
# NOTE(review): the *string* 'null' is used in several places where YAML null
# (Python None) is probably intended; OmegaConf.create keeps these as the
# literal string 'null', which is truthy — verify downstream checks.
cfg = {
    'general':{
        'name': 'graph_dit',
        'wandb': 'disabled' ,
        'gpus': 1,
        'resume': 'null',
        'test_only': 'null',
        'sample_every_val': 2500,
        'samples_to_generate': 512,
        'samples_to_save': 3,
        'chains_to_save': 1,
        'log_every_steps': 50,
        'number_chain_steps': 8,
        'final_model_samples_to_generate': 10000,
        'final_model_samples_to_save': 20,
        'final_model_chains_to_save': 1,
        'enable_progress_bar': False,
        'save_model': True,
    },
    'model':{
        'type': 'discrete',
        'transition': 'marginal',
        'model': 'graph_dit',
        'diffusion_steps': 500,
        'diffusion_noise_schedule': 'cosine',
        'guide_scale': 2,
        'hidden_size': 1152,
        'depth': 6,
        'num_heads': 16,
        'mlp_ratio': 4,
        'drop_condition': 0.01,
        'lambda_train': [1, 10], # node and edge training weight 
        'ensure_connected': True,
    },
    'train':{
        'n_epochs': 10000,
        'batch_size': 1200,
        'lr': 0.0002,
        'clip_grad': 'null',
        'num_workers': 0,
        'weight_decay': 0,
        'seed': 0,
        'val_check_interval': 'null',
        'check_val_every_n_epoch': 1,
    },
    'dataset':{
        'datadir': 'data',
        'task_name': 'nasbench-201',
        'guidance_target': 'nasbench-201',
        'pin_memory': False,
    },
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 86,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"DataModule\n",
|
|
"task nasbench-201\n",
|
|
"datadir data\n",
|
|
"target nasbench-201\n",
|
|
"try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# NOTE(review): the name `cfg` is rebound here from a plain dict to an
# OmegaConf DictConfig — consider a distinct name to avoid confusion.
cfg = OmegaConf.create(cfg)
dm = task1(cfg)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
# NOTE(review): module-level duplicate of DataModule.create_molecule_from_graph;
# keep only one copy so the two cannot drift apart.
def create_molecule_from_graph(nodes, edges):
    """Build an RDKit molecule from an encoded architecture graph.

    Args:
        nodes: per-node integer operation codes; elements must support
            `.item()` (assumes a 1-D tensor — TODO confirm).
        edges: 2 x E edge-index structure supporting `.shape` and
            `[row, col].item()` (assumes a COO edge_index tensor).

    Returns:
        Chem.RWMol with one atom per non-zero node and a SINGLE bond per edge.
    """
    mol = Chem.RWMol() # RWMol allows for building the molecule step by step
    atom_indices = {}
    # Inverse of the integer node encoding used when the dataset was built.
    num_to_op = {
        1 :'nor_conv_1x1',
        2 :'nor_conv_3x3',
        3 :'avg_pool_3x3',
        4 :'skip_connect',
        5 :'output',
        6 :'none',
        7 :'input'
    } 

    # Extract node operations from the data object

    # Add atoms to the molecule; code 0 is skipped (presumably padding — confirm)
    for i, op_tensor in enumerate(nodes):
        op = op_tensor.item()
        if op == 0: continue
        op = num_to_op[op]
        atom_symbol = op_to_atom[op]
        atom = Chem.Atom(atom_symbol)
        atom_idx = mol.AddAtom(atom)
        atom_indices[i] = atom_idx
    
    # Add bonds to the molecule
    edge_number = edges.shape[1]
    for i in range(edge_number):
        start = edges[0, i].item()
        end = edges[1, i].item()
        mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
    
    return mol
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcIAAACWCAIAAADCEh9HAAAABmJLR0QA/wD/AP+gvaeTAAAYNUlEQVR4nO3deVxU9d4H8O9szDAzMgz7KoK4ZhRhliKSlkKIqBTmdi1bbTPvvaXeevVkPnVLe26pZa/UMjVTCa+JkIKmV0Ul0SR3BcQFF0AEZphhBmbmnOeP8SIBLjBwzhz4vP+K3zln5tMr+nDO+Z1FxLIsAQBAW4n5DgAAIGyoUQAAh6BGAQAcghoFAHAIahQAwCGoUQAAh0j5DgDQFbFsncGwz2wuZNl6F5dgpfIhF5cQ+yKj8TedLkulitJoxjTfsLJyndlcoNWmuLrex21kuC3UKADXqqs3X7z4itVa3mhM1K3biN69fyUio/G3a9c+9PJ66XY1qtP94uraHzXqPFCjAJwyGvOKi1NY1urunuzuPlYq9ayvL6mp2SkWq/iOBm2EGgXgVHn5Fyxr9fJ6OSRkWcOgt/cMHiOBgzDFBMAps/kMEbm5jeI7CLQb1CgAp8TibkRkNObyHQTaDQ7qATjl5hZnMOSUlf2LYQy+vu/I5T1bXI1hai2Wq83HWbaugwNCq4nwhCcALjGMubg4WafbRkREIrU62svrRa02RSxW2lcoL19UUvLXO39IWFiqVjuhg5PCvcLeKACnxGJFePgvOl1WZeUP1dWbDYZ9BsO+K1fm9uy5SaUa3LCaQtFPrY5pvrlen1Vff4nDvHB3qFEA7ok0mic1midttprq6s2lpZ+YzaeLisbdf39xw2VPavXQxlP5DYqKElGjzgZTTAC8kUi6eXr+pW/fXJks0Got1+ky+U4EbYEaBeCZRKLp1i2GiOrqLvKdBdoCNQrAP7O5gIikUi3fQaAtUKMAnLp2bX5l5VqbTW//kWXrr137sLb2iEjk4uYWx282aBtMMQFwh2FqS0s/ZRiTSCRzcQkRi1X19RdsNh2RODj4CxeX7nwHhLZAjQJwRyxW3nff6crK9Xp9ttlcYLXekEq9NJoxPj6vq1SP2tdRKqN8ff/e8GMT7u7JCkVfhaIPh6nhLnD5PQCAQ7A3CsC1mpo9DGNQq4dKJBq+s0A7wBQTANfKyhYWFSXq9Vl8B4H2gRoF4BhrNP5GRCrVEL6TQPtAjQJwymw+Y7VWymQBLi7BfGeB9oEaBeCUwZBLRGp1NN9BoN2gRgE4ZX9gc+OHOYHQoUYBOGWvUbUaNdp5oEYBuGOzVZtMp0UiuVIZyXcWaDeoUQDuGI2/ETEq1UCRSM53Fmg3qFEA7tjnl3BitJNBjQJwBydGOyXUKABnGKMxj4hUqkf4TgLtCTUKwBGT6aTNpnNx6SGTBfKdBdoTahSAIwbDASJSq3EPaGeDGgXgCC6876xQowAcwfxSZ4UaBeCCzXrDbC4Ui5WurhF8Z4F2hsc2A3BBUpz7wD4P84OjRSIZ31mgnWFvFIATJQek1TfUVszRd0KoUQBOXM4lIgrGidFOCDUK0PEYK109TCSioJbf9wmChhqFdnPy5Mlr167xncIplR2jegN5hJPSm+8o0P5Qo+Co8+fPz5gxw93dfcCAAQEBAaNGjeI7kfMpOUBEFIwL7zsnzNRDW7Ase/jw4Z9//jk9Pf3UqVONF+3YsePZZ59dtWqVSCTiK57TwYnRTg01Cq1gs9lyc3PT0tI2bdp0+fJl+6BWqx0+fHhQUNB77723cOHCJUuWrFmzhmXZlStXSqX4BSMiopJcIqIg1GjnJGJZlu8M4Oxqa2t37tyZlpaWkZFRXV1tH+zevXt8fHxiYmJ8fLxMdutayF27do0bN66mpmbMmDGpqamurq48pXYaxjL6Pz9y6UZzq0gk4TsNtD/UKNxWRUXF1q1bMzMzt27dajQa7YP9+/cfM2ZMYmJidHT07Q7bDx06lJCQUFFRERsbm56ertFoOEztfE5vop+eorCR9JftfEeBDoFjLmjq/PnzW7ZsyczM3L17t9VqJSKxWBwVFZ
WYmDhp0qQ+ffrc9RMefvjhvXv3xsXF7dmzZ8SIEdu2bfPx8en44M4KJ0Y7O9Qo3HTy5Mm0tLTMzMzff//dPqJQKB577LHExMRnnnnGz8+vVZ/Wr1+/nJycuLi4I0eODB48ePv27T179uyA1EKAE6OdHQ7qu7TbTRk98cQTiYmJ48eP79atmyOfX1lZmZCQcPDgQX9//6ysrIiIrvdUDsZCn2jIaqbZFeTqwXca6BCo0a6oVVNGDjIYDMnJyTt27NBqtZmZmUOGdLFrJy//Rt8NJu/+9NpJvqNAR8FBfRfS5ikjR6jV6oyMjKlTp27cuHHUqFEbN26Mj49v929xXjdPjHaxPx5dDPZGO78Wp4wiIyPvfcrIcTab7dVXX12xYoWLi8vq1asnTpzIwZc6hbQJdCqNkr6jyOf5jgIdBXujndYdpowmTJjg7+/PZRiJRLJs2TIPD48FCxZMnjy5tLR01qxZXAbgTXgcMVbqPpTvHNCRWOhErFZrTk7OzJkzg4KCGv4Ta7XalJSU1atX6/V6vgOyixYtsp86mDNnTts+4csvv4yNjf3xxx9bXPrcc8/FxsaeOXPGgYwArYO90c7gzlNGcXFxLi4u/CZs8NZbb2m12hdeeGHBggUGg2HJkiViceuej1NYWLhnz564uLgWl+bl5Z06dUqv17dH2Nb7z/9QxZlbP7p6kGcv6juOtF31Yq+uATUqYLxMGTlu2rRp7u7uEydOXLp0aVVV1apVq9rxwgCeXfgPXdrXdHDHHBrxEQ2dy0cg4AJqVJDWrl37zTff5ObmMgxDRBKJZNiwYWPHjh03blxYWBjf6e4uKSlp27ZtSUlJ69atKy8v//nnn9VqNd+h2s+Ef1PocCIiQxkd/5Fy/kk7/0G+EdQrge9k0CFQo4JUWFi4f/9+hUIxdOhQXqaMHBcbG7tr166EhIRff/318ccf37p1q6enJ9+h2omLmhRaIiKFlob/L9VW0OFvKH8larSzQo0K0rRp0x544IG4uDiVSsV3lraLiorKzc0dNWpUXl7esGHDtm/fHhjYGd/41uMxOvwNVZ3jOwd0FNSoIPXs2bNz3KIeFhaWk5MTHx9/7NixoUOHbt++vVevXvey4dGjR9etW9d8XKfTtXdGh5mqiIhkAv6DB3eGGm271atXX716derUqcHBwU0W1dXVff755xKJZPbs2bxkExB/f//du3cnJiYeOHAgJiYmKyvrwQcfvOtWqampqampHMRzVH0N5X1FRNQjlu8o0FFQo2331VdfHT58eMiQIc1r1Gw2v/vuu1KpFDV6L7Ra7fbt259++umsrKzhw4dv2bIlJibmzptMnDgxKSmp+fjs2bMbnrHCm5M/UWk+MVbSXaKzW8hQSpoQevSvPKeCDoMaBaegUqnS09OnTZuWmpo6cuTIdevWJScn32H9iIiISZMmNR//6KOPOizjPcv/7tY/S+Q0YBKNXEhKL/4CQcdCjYKzcHFxWb9+vb+//6JFiyZMmLB8+fLnn3f6+9Cvn6SCTDqbQYPeoAH/fVBAwlc330cvU5E2jCTOcu8DdBDUKKeqqqoyMzPPnTsnEon69euXmJioVCr5DuVERCLRF1984efnN3fu3BdffLGysvLtt9/mO1QzjJUu5dCZzXQmnXQXbw5qw27VqEcv8o/iKx1wDzXKnY0bN7700ksNN2sSUUhIyI4dO+5xbrrrmDNnjlqtnjlz5jvvvFNRUfHJJ584xe1YVhMV/0oFmXQmnYxlNweV3hQeT/elUM+Wb06FrgA16ii9Xn/jxo0mgzU1NU1G8vLyJk+eLJFIFi9ePH78eKvVOm/evDVr1kydOvXgwYNchRWM119/XavVPvfccwsWLCgrK1uxYgVv72o2VVJBJhVkUtE2qjfcHNSGUe9E6p9CwdHkDBUPvEKNOqrF+eLm5s+fb7FY5s+fP3PmTPvI8uXLd+7cmZeXl5+fHxkZ2ZEZBWny5MkajWbChAmrVq2qrq5ev369Qq
Hg7ut1F6komwoyqCibGAsRkUhM/lHUO5Hum0De/dv4sdXn6fA39Pg/8ablzgQ16qj4+HhfX98mgxaLpfHF4RaLZdeuXUQUExNTXFzcMP7AAw9cuXLl4MGDqNEWjR49eteuXaNHj968eXNCQsLmzZvd3NxefvnlkSNH9uvXr8VNvv7665qamt69e7ftG/Pz8wMrdvqcW0tlR28OSeQU/iT1HUd9kkjduvf6NcUytGE8lR2lyiJ6ah1J5A59GjgPvp/UJ2ADBw4kot27dzdfZD8BKpVK7T/e+UrGDz74gNPcQnP06FH7EwNWrlzZEZ9vf0jrnDlz7OX779mx7DxiP1ay6xLZP1az5upWfNaZdPbwMlZXctsVruSxC73YecSueow16xwPD84Ae6NcsD+HSaFQLF68uPnSe7lppyuLiIjYt29fWlra9OnT7SMsy27bti0zM7O4uFgsFgcGBg4ZMiQpKcnT09NsNkdHR3t4eOzYsePOH1tbW5udnZ2enp6ZmdlwdjsgIOCCahBNeZd6PNaWC5X63O0MT8DD9NweWhtHF3bT6hE0ZSupfFr9LeBs+O5xAbv3vVGTyWR/pGZlZSW3GTuhurq6sWPHNvwCN0zix8bGsixrMBiIyNvb+3ab37hxY/Xq1SkpKY3fHR0WFjZz5sycnBybzcbFv0P1BfbL3uw8YheHsZVFXHwjdCTsjXJBoVA88sgj+/bt++mnn1555RW+4wjbwoUL09PTfX19ly5dGhcXp1AoSkpKsrOz7c/ZU6lUer2++RP1L168mJ2dnZGRkZ2dbbFYiEgsFkdFRdkfM9i/f1unjNpGE0LT99LaJ6k0n1bG0F+yyed+TgNA++K7xwXs3vdGWZZNT08nIi8vr4b1rVZrVlbWt99+y1HczsI+ubRhw4Z7WfnEiROffvpp4xcBSKXS6OjoRYsWXb58uaOj3kVdDbvmCXYesZ9q2Uv7eQ4DDkCNtl2rapRl2fnz59v/Z3Z3dw8NDbW/H6lHjx4Wi4WryJ2Bh4cHEeXm5ra41Gw2JycnT58+nWXZjz/+uGF3QaPRTJw4MTU1Vadzpokdq5lNfYqdR+zHKrZwG99poI3wnvq2++yzzy5evDhr1qzw8PAmi8xm89tvv22/2L7x+B9//LF27dqzZ8+KxeKQkJCYmJixY8c6z/vmBCEiIuL48eNTpkz5/vvvm7/EyWg0qtVqb2/v8vLyvLy8xMTE+Pj4lJSUUaNGyeVOeYERa6PMGXTkW5K40Pg1dN8zfAeCVkONgsAsWbLkrbfeIqLIyMjXXnstJSVFo9E0LG1co/bfbae4kfTOWJZ2zqX9C0kkoYSvaOAMvgNB67Tu3bYAvHvzzTfff/99hUKRn5//0ksvBQQEPPvss0eOHGm+pkgkEkCHEpFIRE8soPhFxDL0y6v0K94hKjDYGwVBqq6uTktLW7t27b59+xiGkUqlP/zww8SJExvvjfKdsfWOrqEtLxBjpUFvUPxiEmEvRxhQoyBsJSUls2bN2rRpk0ajKSkpEYvFAq5RIjq9if492RQYfP2x4cE9vhaJcEmiAODPHQhbcHDwhg0bAgMDdTrd/v37+Y7jsH7J7NRtxQPpeuWKc+fGM4yJ70Bwd6hREDyZTNa9e3ciqqys5DtLOxD1GB7ac4NU6q3TZRYUDLdamz6GEZwNahQExmazNRkpKys7fvw4EYWGhvKRqP0plVF9++bK5WFG48GCgliL5QrfieBOUKMgJAzDDBgw4Pnnn8/MzCwoKCgpKcnIyIiLizMYDBEREYMGDeI7YLuRy3v26bPP1fV+k+nk2bMxdXWFfCeC28IUEwhJSUnJiBEjioqKmowPGDAgPT09LCxM2DP1zdhsVYWFo43GXJnMLzx8m1KJh4E5I9QoCM+hQ4d27txZXFxsMpl8fX2HDRuWkJBgf8sIwzB79+51cXEZMmQI3zHbB8MYz517Sq/Plkjcw8Mz1OqhfCeCplCjAM6OZevPn59WVZUqEsnDwt
a7u4/nOxH8Cc6NglClpKS8//77JlPnvyRIJHIJDf3R23sGy9YVF6dUVKzkOxH8CfZGQZAuXLgQGhrq6el5/fp1Ydzx2R5KSxdcuTKXSBQU9Jmv79/5jgM3YW8UBCk3N5eIBg8e3HU6lIj8/OYEB39JJLp8+e0rV3DrvbPArWYgSA01yncQrvn4vCGVai9cmF5ausBm03XvvtS+M8SyVr0+y2g8aLGUS6WecnmYRvOkTBZo3+r8+aksa+3R43ux2LXJB5rNp69e/VAu7xkY+HHTL4N7gxoFQTpw4AARdZrp+Fbx8JgikWiKiyfU1uYzjFksVtbVFRUVJZnNp/+8orhnz5/d3ZOIqKoqlWWtISHLiJrWqMVSXlWVqlQORI22GWoUhMdkMh07dkwikdhfQNAFaTSJvXvvkst7icVKlrWdOzfebD6tUj3i5/euQtGHYWpNpmPV1ZvU6kf5TtoloEZBePLy8iwWy0MPPaRWq/nOwhuV6mZFGo25JtMJqdSrV68dEsnN150qlZGens/yl65rwRQTCE+XPTHaovr6C0Qkl4c3dChwDDUKwoMabUwq9SIik+mY2VzAd5YuCgf1IDAsy9prtGvOLzWnVsfKZL4WS9mZM4P8/N719JzSMEHfhMGwVyxWNRk0mY51fMZODpffg8AUFhb27t3bx8enrKyM7yzOwmg8WFz8dH39ZSISiSRubnFeXi+6u49tONw8ckTGstY7fIJSObBfv0NcZO2MsDcKAmPfFY2OjuY7iBNRqR4ZMKCoujrjxo01en22TrdVp9vq5jYyPDxDJLr1WunAwH+KRIom29bVnbt+fSm3eTsb1CgIDE6Mtkgkkmu1T2u1T1utlTdurL569T29fkdp6af+/h80rOPt/ZpEommyYU3NHtSogzDFBAKDGr0zqdTD1/evQUGfE1FVVRrfcboE1CgISU1NzYkTJ2QyWVRUFN9ZnJpK9TARWSylfAfpElCjICQHDx602WyRkZGurk1vaoTGjMZDRCSTBfAdpEtAjYKQmM2nHnooAEf0jen1v547N66qKs1mqyIihjFXVf105co/iMjDYxLf6boETDGBkPTtm718+dWQEFwxeovBkFNdnV5dnU5EYrGSYUxELBG5uyfjmaTcQI2CgLBG429E5OaGvdFbAgI+9PCYXFW1sbb2kMVSLha7yOW93N2TNZonG9bp3n0ZESMWK5tvrlD0CQlZIZV6cxi5s8Hl9yAYZvPpkyf7y2QBERF4bzs4EZwbBcEwGA4QEV6NCc4GNQqCYTTmEpFajSN6cC6oURAMgyGXiFQq1Cg4F9QoCIPNVm02nxGJ5Erlg3xnAfgT1CgIg8GQS8SoVAMbP2sDwBmgRkEY7CdGVSpcMQpOBzUKwoD5JXBaqFEQBMZozKNG73EDcB6oURAAk+mEzaaXy0NlMn++swA0hRoFAbBfeI8To+CcUKMgADgxCs4MNQoCgAvvwZmhRsHZWa0VdXVFYrHK1TWC7ywALcCD8sDZSSTdevXaZrFcFYnw6wrOCA/KAwBwCA7qwRmxrK2iYmVR0egTJ8KPHw85ezb60qU3amr+Y19aUbEiP1996dKMFrc9f35Kfr66qiqVw7zQpeEoCZwOy1qKipL0+iwikkp9pFKP2tpjBsMBvX77gAEF9hUYxsgw5hY3ZxgTwxhZ1sJpaOjCUKPgdK5fX6bXZ8lkvmFhaWp1DBERsUbjIau1nOdkAC1BjYLT0em2EJGf33v/7VAiEqlUg3iMBHAHODcKTsdm0xORRKLmOwjAPUGNgtORy8OIqKzsX/X1F/nOAnB3OKgHp+PjM7OqaqPJdPLUqQgPj0keHtPU6hbupjcaD12+/Lfm4ybT8Y7PCHALahScjkr1aHj4L5cv/81kOnH9+rLr15cpFP19fN7w9n6l8fGT2XzKbD7FY04AO9QoOCM3t5H9+x+vrc2vrFxbWZlqNp+6dOm12trfQ0K+bVjH3T05KOjT5tteuvSGXr+dw7DQ1a
FGwXkplZFKZWRQ0Gfl5UtLSt6qqPjO2/s1pfIh+1KJpJtc3qv5VmKxituY0NVhigmcn9jH5003t1FEZDDs5TsMQFOoURAGqVRLRDabke8gAE2hRsHpWK0VzUfsN9S7uvbjIxHAneDcKDidwsKRLMtotU+5uj4okajM5oLy8kUWS5lC0VejGc13OoCmUKPgXBjGLJf30ekyrl79oPG4Wj0kNPRHkUjOVzCA28HzRsEZMYyhpmav2XyWZeslErVK9ahSGdVoaa3NphOLlRKJpvm2Nls1w5gkEq1YrOAwMnRdqFEAAIdgigkAwCGoUQAAh6BGAQAcghoFAHAIahQAwCGoUQAAh/w/5kM91lORMsgAAACyelRYdHJka2l0UEtMIHJka2l0IDIwMjMuMDkuNAAAeJx7v2/tPQYg4GdAAA4gZgPiBkY+hgyQADMjP4MGmMEGEWBiEoAwGBkxGVA1jMxMYJMYuBkYgTqAfAYGFgZGVgYmNgYRkLh4FkgV3NKrdkkOcie794E4N97Pc+BebWQHYt9eIO2wzcbOHsS+anfIoaksfD+I/dXa12Gy1z0w26NwpcPatRftIUY12J+Mmm7HgAbEANcxIdFj5zaeAAABGXpUWHRNT0wgcmRraXQgMjAyMy4wOS40AAB4nH2SSW7DMAxF9z7Fv0AEUgMlLbrwELRFG7uondygi+x7f5RKodhBjVJacHgEqQ81KPY5vF2/cTc7NA1A/9ycMy6OiJoTioPu+Pw6ol/armb66TwuMxJEO/Q8ku0ynWqGMV/hTHCBUsCBjI2OOIMM3Uy7L+37k6+8xQe8ySIxJoU4ep94Q1fOoYc13rFkp+WoHvsdzmOGGIrEkcr45FmC3QGDgs5Ykpg8DmyiEwl7k0XBYGwOpBuyEZeY3Q4XdUO+Z0uD7G6Y8PK1Sf91bgIduPLHcXhQ+FfzbhqHVfNy7CppCd2qnNbgV4FKNawylKqsj9VOxO3s7aQS13+hfvMDBK93Gkl4MxsAAACQelRYdFNNSUxFUyByZGtpdCAyMDIzLjA5LjQAAHicFY07CsQwEEOvsmUCk2Hk+dm4TJMykDKk3GLvkMOvXQjE4wnt+339jme51nPk4vv4Pp93AQmX5gGjXtgU0QbJUSZRdnVptA0pVdAmKhJZaQOnRjh14xaTCCPNKqgHSwpyzqohvFD3+SKVwKEVwxESWt8/z+gf06IW8IIAAAAASUVORK5CYII=",
|
|
"text/plain": [
|
|
"<rdkit.Chem.rdchem.RWMol at 0x7eff28309400>"
|
|
]
|
|
},
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"nodes = torch.tensor([7,4,1,6,6,6,1,5])\n",
|
|
"edges = torch.tensor([[0,0,1,0,1,2],[1,2,3,4,5,6]])\n",
|
|
"create_molecule_from_graph(nodes, edges)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 83,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "TypeError",
|
|
"evalue": "create_molecule_from_graph() takes 2 positional arguments but 3 were given",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[83], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m smiles \u001b[38;5;241m=\u001b[39m \u001b[43mdm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_train_smiles\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"Cell \u001b[0;32mIn[75], line 167\u001b[0m, in \u001b[0;36mDataModule.get_train_smiles\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 165\u001b[0m graph \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_dataset[idx]\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# print(graph.x, graph.edge_index)\u001b[39;00m\n\u001b[0;32m--> 167\u001b[0m mol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_molecule_from_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m train_smiles\u001b[38;5;241m.\u001b[39mappend(Chem\u001b[38;5;241m.\u001b[39mMolToSmiles(mol))\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_index:\n",
|
|
"\u001b[0;31mTypeError\u001b[0m: create_molecule_from_graph() takes 2 positional arguments but 3 were given"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"smiles = dm.get_train_smiles()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "graphdit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.19"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|