{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "from nas_201_api import NASBench201API as API"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "api = API('./NAS-Bench-201-v1_1-096897.pth', verbose=False)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Sanity check: print every trial the API returns for architecture #1 on CIFAR-100\n",
  "results = api.query_by_index(1, 'cifar100')\n",
  "print('There are {:} trials for this architecture [{:}] on CIFAR-100'.format(len(results), api[1]))\n",
  "for seed, result in results.items():\n",
  "    print('Latency : {:}'.format(result.get_latency()))\n",
  "    print('Train Info : {:}'.format(result.get_train()))\n",
  "    print('Valid Info : {:}'.format(result.get_eval('x-valid')))\n",
  "    print('Test Info : {:}'.format(result.get_eval('x-test')))\n",
  "    print('')\n",
  "    print('Train Info [10-th epoch]: {:}'.format(result.get_train(10)))"
 ] },
 { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [
  "import sys\n",
  "sys.path.append('../')\n",
  "\n",
  "import os\n",
  "import os.path as osp\n",
  "import pathlib\n",
  "import json\n",
  "\n",
  "import torch\n",
  "import torch.nn.functional as F\n",
  "from rdkit import Chem, RDLogger\n",
  "from rdkit.Chem.rdchem import BondType as BT\n",
  "from tqdm import tqdm\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "from torch_geometric.data import Data, InMemoryDataset\n",
  "from torch_geometric.loader import DataLoader\n",
  "from sklearn.model_selection import train_test_split\n"
 ] },
 { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [
  "def random_data_split(task, dataset):\n",
  "    \"\"\"Split the labeled prefix of `dataset` 60/20/20 into train/val/test (random_state=42).\n",
  "\n",
  "    Graphs whose first target `y[:, 0]` is NaN count as unlabeled; the code assumes they\n",
  "    all sit at the end of the dataset and returns their indices as a separate list.\n",
  "    \"\"\"\n",
  "    nan_count = torch.isnan(dataset.y[:, 0]).sum().item()\n",
  "    labeled_len = len(dataset) - nan_count\n",
  "    full_idx = list(range(labeled_len))\n",
  "    train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2\n",
  "    train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)\n",
  "    train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)\n",
  "    unlabeled_index = list(range(labeled_len, len(dataset)))\n",
  "    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))\n",
  "    return train_index, val_index, test_index, unlabeled_index"
 ] },
 { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [
  "def parse_architecture_string(arch_str):\n",
  "    \"\"\"Convert a NAS-Bench-201 arch string into an operation-as-node graph.\n",
  "\n",
  "    Every `op~k` token becomes its own node, with an edge from node index `k` to it;\n",
  "    'input' is node 0 and 'output' is appended last. This matches the encoding used by\n",
  "    `Dataset.process`, so statistics computed from `graphs` describe the training graphs.\n",
  "    NOTE(review): 'output' receives no incoming edge in this encoding - confirm intended.\n",
  "\n",
  "    Returns:\n",
  "        nodes: list of op names, starting with 'input' and ending with 'output'.\n",
  "        edges: list of (src, dst) node-index pairs.\n",
  "    \"\"\"\n",
  "    nodes = ['input']\n",
  "    edges = []\n",
  "    for stage in arch_str.split('+'):\n",
  "        for token in stage.strip('|').split('|'):\n",
  "            op, idx = token.split('~')\n",
  "            # the op is appended next, so its node index is the current len(nodes)\n",
  "            edges.append((int(idx), len(nodes)))\n",
  "            nodes.append(op)\n",
  "    nodes.append('output')\n",
  "    return nodes, edges"
 ] },
 { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [
  "def create_adj_matrix_and_ops(nodes, edges):\n",
  "    \"\"\"Return a dense 0/1 adjacency matrix (adj[src][dst] = 1) and the unchanged `nodes`.\"\"\"\n",
  "    num_nodes = len(nodes)\n",
  "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
  "    for (src, dst) in edges:\n",
  "        adj_matrix[src][dst] = 1\n",
  "    return adj_matrix, nodes"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "graphs = []\n",
  "length = 15625\n",
  "ops_type = {}\n",
  "len_ops = set()\n",
  "for i in range(length):\n",
  "    arch_info = api.query_meta_info_by_index(i)\n",
  "    nodes, edges = parse_architecture_string(arch_info.arch_str)\n",
  "    adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)\n",
  "    if i < 5:\n",
  "        print(\"Adjacency Matrix:\")\n",
  "        print(adj_matrix)\n",
  "        print(\"Operations List:\")\n",
  "        print(ops)\n",
  "    for op in ops:\n",
  "        if op not in ops_type:\n",
  "            ops_type[op] = len(ops_type)\n",
  "    len_ops.add(len(ops))\n",
  "    graphs.append((adj_matrix, ops))\n",
  "print(graphs[0])\n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "print(len(ops_type))\n",
  "print(len(len_ops))\n",
  "print(ops_type)\n",
  "print(len_ops)"
 ] },
 { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [
  "# Each NAS-Bench-201 op is encoded as a pseudo-atom so RDKit/molecule tooling can be reused\n",
  "op_to_atom = {\n",
  "    'input': 'Si',         # Silicon for input\n",
  "    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution\n",
  "    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution\n",
  "    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling\n",
  "    'skip_connect': 'P',   # Phosphorus for skip connection\n",
  "    'none': 'S',           # Sulfur for no operation\n",
  "    'output': 'He'         # Helium for output\n",
  "}\n",
  "\n"
 ] },
 { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [
  "def graphs_to_json(graphs, filename):\n",
  "    \"\"\"Compute marginal statistics of `graphs` and write them to `{filename}.meta.json`.\n",
  "\n",
  "    `graphs` is a list of `(adj_matrix, ops)` pairs. Each op becomes a pseudo-atom via\n",
  "    `op_to_atom`, and each edge is typed by the op of its destination node (`bonds`).\n",
  "    Returns the dictionary that is written to disk.\n",
  "    \"\"\"\n",
  "    bonds = {\n",
  "        'nor_conv_1x1': 1,\n",
  "        'nor_conv_3x3': 2,\n",
  "        'avg_pool_3x3': 3,\n",
  "        'skip_connect': 4,\n",
  "        'input': 7,\n",
  "        'output': 5,\n",
  "        'none': 6\n",
  "    }\n",
  "\n",
  "    num_graph = len(graphs)\n",
  "    pt = Chem.GetPeriodicTable()\n",
  "    # slot k <-> atomic number k + 2 (He..Og); the final slot is the wildcard '*'\n",
  "    atom_name_list = [pt.GetElementSymbol(z) for z in range(2, 119)] + ['*']\n",
  "    atom_count_list = [0] * len(atom_name_list)\n",
  "    n_atoms_per_mol = [0] * 500\n",
  "    # slot 0 counts node pairs without an edge; slots 1..7 are the edge types in `bonds`\n",
  "    bond_count_list = [0] * (max(bonds.values()) + 1)\n",
  "    valencies = [0] * 500\n",
  "    # the last axis must hold every bond index, including 'none' (6) and 'input' (7)\n",
  "    transition_E = np.zeros((118, 118, len(bond_count_list)))\n",
  "\n",
  "    n_atom_list = []\n",
  "    n_bond_list = []\n",
  "    for graph_idx, (adj, ops) in enumerate(graphs):\n",
  "        n_atom = len(ops)\n",
  "        n_bond = int(np.count_nonzero(adj))  # one bond per edge, not one per node\n",
  "        n_atom_list.append(n_atom)\n",
  "        n_bond_list.append(n_bond)\n",
  "        n_atoms_per_mol[n_atom] += 1\n",
  "\n",
  "        cur_atom_count_arr = np.zeros(118)\n",
  "        for op in ops:\n",
  "            symbol = op_to_atom[op]\n",
  "            if symbol == 'H':\n",
  "                continue\n",
  "            elif symbol == '*':\n",
  "                atom_count_list[-1] += 1\n",
  "                cur_atom_count_arr[-1] += 1\n",
  "            else:\n",
  "                atom_count_list[pt.GetAtomicNumber(symbol) - 2] += 1\n",
  "                cur_atom_count_arr[pt.GetAtomicNumber(symbol) - 2] += 1\n",
  "            valence = int(pt.GetDefaultValence(symbol))\n",
  "            try:\n",
  "                valencies[valence] += 1\n",
  "            except IndexError:\n",
  "                print('unexpected valence', valence, 'for', symbol)\n",
  "\n",
  "        transition_E_temp = np.zeros_like(transition_E)\n",
  "        for i in range(n_atom):\n",
  "            for j in range(n_atom):\n",
  "                if i == j or adj[i][j] == 0:\n",
  "                    continue\n",
  "                start_index = pt.GetAtomicNumber(op_to_atom[ops[i]]) - 2\n",
  "                end_index = pt.GetAtomicNumber(op_to_atom[ops[j]]) - 2\n",
  "                bond_index = bonds[ops[j]]  # edge type = op of the destination node\n",
  "                bond_count_list[bond_index] += 2\n",
  "                transition_E[start_index, end_index, bond_index] += 2\n",
  "                transition_E[end_index, start_index, bond_index] += 2\n",
  "                transition_E_temp[start_index, end_index, bond_index] += 2\n",
  "                transition_E_temp[end_index, start_index, bond_index] += 2\n",
  "\n",
  "        # every ordered node pair that is not an edge is counted as 'no edge'\n",
  "        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2\n",
  "        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2\n",
  "        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2\n",
  "        no_edge_counts = cur_tot_bond - transition_E_temp.sum(axis=-1)\n",
  "        assert (no_edge_counts >= 0).all(), f'negative no-edge count in graph {graph_idx}'\n",
  "        transition_E[:, :, 0] += no_edge_counts\n",
  "\n",
  "    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)\n",
  "    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]\n",
  "\n",
  "    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)\n",
  "    print('processed meta info: ------', filename, '------')\n",
  "    print('len atom_count_list', len(atom_count_list))\n",
  "    print('len atom_name_list', len(atom_name_list))\n",
  "    atom_count_list = atom_count_list.tolist()\n",
  "\n",
  "    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)\n",
  "    bond_count_list = bond_count_list.tolist()\n",
  "    valencies = np.array(valencies) / np.sum(valencies)\n",
  "    valencies = valencies.tolist()\n",
  "\n",
  "    no_edge = np.sum(transition_E, axis=-1) == 0\n",
  "    for i, j in np.argwhere(~no_edge):\n",
  "        print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')\n",
  "    # type pairs that never co-occur get a pure 'no edge' distribution\n",
  "    transition_E[no_edge, 0] = 1\n",
  "\n",
  "    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)\n",
"\n", " # find non-zero element in transition_E again\n", " for i in range(118):\n", " for j in range(118):\n", " if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:\n", " print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')\n", "\n", " meta_dict = {\n", " 'source': 'nasbench-201',\n", " 'num_graph': num_graph,\n", " 'n_atoms_per_mol_dist': n_atoms_per_mol[:51],\n", " 'max_node': max(n_atom_list),\n", " 'max_bond': max(n_bond_list),\n", " 'atom_type_dist': atom_count_list,\n", " 'bond_type_dist': bond_count_list,\n", " 'valencies': valencies,\n", " 'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],\n", " 'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),\n", " 'transition_E': transition_E.tolist(),\n", " }\n", "\n", " with open(f'{filename}.meta.json', 'w') as f:\n", " json.dump(meta_dict, f)\n", " return meta_dict\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "graphs_to_json(graphs, 'nasbench-201')" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "def gen_adj_matrix_and_ops(nasbench):\n", " i = 0\n", " epoch = 108\n", "\n", " for unique_hash in nasbench.hash_iterator():\n", " fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(unique_hash)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch_geometric.data import InMemoryDataset, Data\n", "import os.path as osp\n", "import pandas as pd\n", "from tqdm import tqdm\n", "import networkx as nx\n", "import numpy as np\n", "\n", "class Dataset(InMemoryDataset):\n", " def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):\n", " self.target_prop = target_prop\n", " source = './NAS-Bench-201-v1_1-096897.pth'\n", " self.source = source\n", " self.api = 
API(source) # Initialize NAS-Bench-201 API\n", " super().__init__(root, transform, pre_transform, pre_filter)\n", " self.data, self.slices = torch.load(self.processed_paths[0])\n", "\n", " @property\n", " def raw_file_names(self):\n", " return [] # NAS-Bench-201 data is loaded via the API, no raw files needed\n", " \n", " @property\n", " def processed_file_names(self):\n", " return [f'{self.source}.pt']\n", "\n", " def process(self):\n", " def parse_architecture_string(arch_str):\n", " stages = arch_str.split('+')\n", " nodes = ['input']\n", " edges = []\n", " \n", " for stage in stages:\n", " operations = stage.strip('|').split('|')\n", " for op in operations:\n", " operation, idx = op.split('~')\n", " idx = int(idx)\n", " edges.append((idx, len(nodes))) # Add edge from idx to the new node\n", " nodes.append(operation)\n", " nodes.append('output') # Add the output node\n", " return nodes, edges\n", "\n", " def create_graph(nodes, edges):\n", " G = nx.DiGraph()\n", " for i, node in enumerate(nodes):\n", " G.add_node(i, label=node)\n", " G.add_edges_from(edges)\n", " return G\n", "\n", " def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):\n", " nodes, edges = parse_architecture_string(arch_str)\n", "\n", " node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary\n", " assert 0 not in node_labels, f'Invalid node label: {node_labels}'\n", " x = torch.LongTensor(node_labels)\n", "\n", " edges_list = [(start, end) for start, end in edges]\n", " edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type\n", " edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()\n", " edge_type = torch.tensor(edge_type, dtype=torch.long)\n", " edge_attr = edge_type.view(-1, 1)\n", "\n", " if target3 is not None:\n", " y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)\n", " elif target2 is not None:\n", " y = torch.tensor([sa, sc, target, 
target2], dtype=torch.float).view(1, -1)\n", " else:\n", " y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)\n", "\n", " data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)\n", " return data, nodes\n", "\n", " bonds = {\n", " 'nor_conv_1x1': 1,\n", " 'nor_conv_3x3': 2,\n", " 'avg_pool_3x3': 3,\n", " 'skip_connect': 4,\n", " 'output': 5,\n", " 'none': 6,\n", " 'input': 7\n", " }\n", "\n", " # Prepare to process NAS-Bench-201 data\n", " data_list = []\n", " len_data = len(self.api) # Number of architectures\n", " with tqdm(total=len_data) as pbar:\n", " for arch_index in range(len_data):\n", " arch_info = self.api.query_meta_info_by_index(arch_index)\n", " arch_str = arch_info.arch_str\n", " sa = np.random.rand() # Placeholder for synthetic accessibility\n", " sc = np.random.rand() # Placeholder for substructure count\n", " target = np.random.rand() # Placeholder for target value\n", " target2 = np.random.rand() # Placeholder for second target value\n", " target3 = np.random.rand() # Placeholder for third target value\n", "\n", " data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)\n", " data_list.append(data)\n", " pbar.update(1)\n", "\n", " torch.save(self.collate(data_list), self.processed_paths[0])\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pathlib\n", "import torch\n", "from torch_geometric.data import DataLoader\n", "from sklearn.model_selection import train_test_split\n", "import pandas as pd\n", "from tqdm import tqdm\n", "# import nas_bench_201 as nb201\n", "\n", "import utils as utils\n", "from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule\n", "from diffusion.distributions import DistributionNodes\n", "\n", "from rdkit.Chem 
import rdchem\n", "\n", "class DataModule(AbstractDataModule):\n", " def __init__(self, cfg):\n", " self.datadir = cfg.dataset.datadir\n", " self.task = cfg.dataset.task_name\n", " print(\"DataModule\")\n", " print(\"task\", self.task)\n", " print(\"datadir\", self.datadir)\n", " super().__init__(cfg)\n", "\n", " def prepare_data(self) -> None:\n", " target = getattr(self.cfg.dataset, 'guidance_target', None)\n", " print(\"target\", target)\n", " # try:\n", " # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n", " # except NameError:\n", " # base_path = pathlib.Path(os.getcwd()).parent[2]\n", " base_path = '/home/stud/hanzhang/Graph-Dit'\n", " root_path = os.path.join(base_path, self.datadir)\n", " self.root_path = root_path\n", "\n", " batch_size = self.cfg.train.batch_size\n", " \n", " num_workers = self.cfg.train.num_workers\n", " pin_memory = self.cfg.dataset.pin_memory\n", "\n", " # Load the dataset to the memory\n", " # Dataset has target property, root path, and transform\n", " source = './NAS-Bench-201-v1_1-096897.pth'\n", " dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n", "\n", " # if len(self.task.split('-')) == 2:\n", " # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n", " # else:\n", " train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n", "\n", " self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n", " train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)\n", " if len(unlabeled_index) > 0:\n", " train_index = torch.cat([train_index, unlabeled_index], dim=0)\n", " \n", " train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]\n", " self.train_dataset = train_dataset \n", " print('train 
len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))\n", " print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))\n", " print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))\n", " self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)\n", "\n", " self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)\n", " self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)\n", "\n", " training_iterations = len(train_dataset) // batch_size\n", " self.training_iterations = training_iterations\n", " \n", " def random_data_split(self, dataset):\n", " nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()\n", " labeled_len = len(dataset) - nan_count\n", " full_idx = list(range(labeled_len))\n", " train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2\n", " train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)\n", " train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)\n", " unlabeled_index = list(range(labeled_len, len(dataset)))\n", " print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))\n", " return train_index, val_index, test_index, unlabeled_index\n", " \n", " def fixed_split(self, dataset):\n", " if self.task == 'O2-N2':\n", " test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]\n", " else:\n", " raise ValueError('Invalid task name: {}'.format(self.task))\n", " full_idx = list(range(len(dataset)))\n", " full_idx = list(set(full_idx) - 
set(test_index))\n", " train_ratio = 0.8\n", " train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)\n", " print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))\n", " return train_index, val_index, test_index, []\n", "\n", " def parse_architecture_string(arch_str):\n", " stages = arch_str.split('+')\n", " nodes = ['input']\n", " edges = []\n", " \n", " for stage in stages:\n", " operations = stage.strip('|').split('|')\n", " for op in operations:\n", " operation, idx = op.split('~')\n", " idx = int(idx)\n", " edges.append((idx, len(nodes))) # Add edge from idx to the new node\n", " nodes.append(operation)\n", " nodes.append('output') # Add the output node\n", " return nodes, edges\n", "\n", " def create_molecule_from_graph(nodes, edges):\n", " mol = Chem.RWMol() # RWMol allows for building the molecule step by step\n", " atom_indices = {}\n", " num_to_op = {\n", " 1 :'nor_conv_1x1',\n", " 2 :'nor_conv_3x3',\n", " 3 :'avg_pool_3x3',\n", " 4 :'skip_connect',\n", " 5 :'output',\n", " 6 :'none',\n", " 7 :'input'\n", " } \n", "\n", " # Extract node operations from the data object\n", "\n", " # Add atoms to the molecule\n", " for i, op_tensor in enumerate(nodes):\n", " op = op_tensor.item()\n", " if op == 0: continue\n", " op = num_to_op[op]\n", " atom_symbol = op_to_atom[op]\n", " atom = Chem.Atom(atom_symbol)\n", " atom_idx = mol.AddAtom(atom)\n", " atom_indices[i] = atom_idx\n", " \n", " # Add bonds to the molecule\n", " edge_number = edges.shape[1]\n", " for i in range(edge_number):\n", " start = edges[0, i].item()\n", " end = edges[1, i].item()\n", " mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)\n", " \n", " return mol\n", "\n", " def arch_str_to_smiles(self, arch_str):\n", " nodes, edges = self.parse_architecture_string(arch_str)\n", " mol = self.create_molecule_from_graph(nodes, edges)\n", " smiles = 
Chem.MolToSmiles(mol)\n", " return smiles\n", "\n", " def get_train_smiles(self):\n", " # raise NotImplementedError(\"This method is not applicable for NAS-Bench-201 data.\")\n", " # train_arch_strs = []\n", " # test_arch_strs = []\n", "\n", " # for idx in self.train_index:\n", " # arch_info = self.train_dataset[idx]\n", " # arch_str = arch_info.arch_str\n", " # train_arch_strs.append(arch_str)\n", " # for idx in self.test_index:\n", " # arch_info = self.train_dataset[idx]\n", " # arch_str = arch_info.arch_str\n", " # test_arch_strs.append(arch_str)\n", "\n", " train_smiles = [] \n", " test_smiles = []\n", "\n", " for idx in self.train_index:\n", " graph = self.train_dataset[idx]\n", " print(graph.x, graph.edge_index)\n", " print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')\n", " mol = self.create_molecule_from_graph(graph.x, graph.edge_index)\n", " train_smiles.append(Chem.MolToSmiles(mol))\n", " \n", " for idx in self.test_index:\n", " graph = self.train_dataset[idx]\n", " mol = self.create_molecule_from_graph(graph.x, graph.edge_index)\n", " test_smiles.append(Chem.MolToSmiles(mol))\n", " \n", " # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]\n", " # test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]\n", " return train_smiles, test_smiles\n", "\n", " def get_data_split(self):\n", " raise NotImplementedError(\"This method is not applicable for NAS-Bench-201 data.\")\n", "\n", " def example_batch(self):\n", " return next(iter(self.val_loader))\n", " \n", " def train_dataloader(self):\n", " return self.train_loader\n", "\n", " def val_dataloader(self):\n", " return self.val_loader\n", " \n", " def test_dataloader(self):\n", " return self.test_loader\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "from omegaconf import DictConfig, OmegaConf\n", "import argparse\n", "import hydra\n", 
"\n", "def parse_arg():\n", " parser = argparse.ArgumentParser(description='Diffusion')\n", " parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n", " return parser.parse_args()\n", "\n", "def task1(cfg: DictConfig):\n", " datamodule = DataModule(cfg=cfg)\n", " datamodule.prepare_data()\n", " return datamodule\n", "\n", "cfg = {\n", " 'general':{\n", " 'name': 'graph_dit',\n", " 'wandb': 'disabled' ,\n", " 'gpus': 1,\n", " 'resume': 'null',\n", " 'test_only': 'null',\n", " 'sample_every_val': 2500,\n", " 'samples_to_generate': 512,\n", " 'samples_to_save': 3,\n", " 'chains_to_save': 1,\n", " 'log_every_steps': 50,\n", " 'number_chain_steps': 8,\n", " 'final_model_samples_to_generate': 10000,\n", " 'final_model_samples_to_save': 20,\n", " 'final_model_chains_to_save': 1,\n", " 'enable_progress_bar': False,\n", " 'save_model': True,\n", " },\n", " 'model':{\n", " 'type': 'discrete',\n", " 'transition': 'marginal',\n", " 'model': 'graph_dit',\n", " 'diffusion_steps': 500,\n", " 'diffusion_noise_schedule': 'cosine',\n", " 'guide_scale': 2,\n", " 'hidden_size': 1152,\n", " 'depth': 6,\n", " 'num_heads': 16,\n", " 'mlp_ratio': 4,\n", " 'drop_condition': 0.01,\n", " 'lambda_train': [1, 10], # node and edge training weight \n", " 'ensure_connected': True,\n", " },\n", " 'train':{\n", " 'n_epochs': 10000,\n", " 'batch_size': 1200,\n", " 'lr': 0.0002,\n", " 'clip_grad': 'null',\n", " 'num_workers': 0,\n", " 'weight_decay': 0,\n", " 'seed': 0,\n", " 'val_check_interval': 'null',\n", " 'check_val_every_n_epoch': 1,\n", " },\n", " 'dataset':{\n", " 'datadir': 'data',\n", " 'task_name': 'nasbench-201',\n", " 'guidance_target': 'nasbench-201',\n", " 'pin_memory': False,\n", " },\n", "}\n" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DataModule\n", "task nasbench-201\n", "datadir data\n", "target nasbench-201\n", "try to create the NAS-Bench-201 api from 
./NAS-Bench-201-v1_1-096897.pth\n" ] } ], "source": [ "cfg = OmegaConf.create(cfg)\n", "dm = task1(cfg)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "\n", "def create_molecule_from_graph(nodes, edges):\n", " mol = Chem.RWMol() # RWMol allows for building the molecule step by step\n", " atom_indices = {}\n", " num_to_op = {\n", " 1 :'nor_conv_1x1',\n", " 2 :'nor_conv_3x3',\n", " 3 :'avg_pool_3x3',\n", " 4 :'skip_connect',\n", " 5 :'output',\n", " 6 :'none',\n", " 7 :'input'\n", " } \n", "\n", " # Extract node operations from the data object\n", "\n", " # Add atoms to the molecule\n", " for i, op_tensor in enumerate(nodes):\n", " op = op_tensor.item()\n", " if op == 0: continue\n", " op = num_to_op[op]\n", " atom_symbol = op_to_atom[op]\n", " atom = Chem.Atom(atom_symbol)\n", " atom_idx = mol.AddAtom(atom)\n", " atom_indices[i] = atom_idx\n", " \n", " # Add bonds to the molecule\n", " edge_number = edges.shape[1]\n", " for i in range(edge_number):\n", " start = edges[0, i].item()\n", " end = edges[1, i].item()\n", " mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)\n", " \n", " return mol" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAcIAAACWCAIAAADCEh9HAAAABmJLR0QA/wD/AP+gvaeTAAAYNUlEQVR4nO3deVxU9d4H8O9szDAzMgz7KoK4ZhRhliKSlkKIqBTmdi1bbTPvvaXeevVkPnVLe26pZa/UMjVTCa+JkIKmV0Ul0SR3BcQFF0AEZphhBmbmnOeP8SIBLjBwzhz4vP+K3zln5tMr+nDO+Z1FxLIsAQBAW4n5DgAAIGyoUQAAh6BGAQAcghoFAHAIahQAwCGoUQAAh0j5DgDQFbFsncGwz2wuZNl6F5dgpfIhF5cQ+yKj8TedLkulitJoxjTfsLJyndlcoNWmuLrex21kuC3UKADXqqs3X7z4itVa3mhM1K3biN69fyUio/G3a9c+9PJ66XY1qtP94uraHzXqPFCjAJwyGvOKi1NY1urunuzuPlYq9ayvL6mp2SkWq/iOBm2EGgXgVHn5Fyxr9fJ6OSRkWcOgt/cMHiOBgzDFBMAps/kMEbm5jeI7CLQb1CgAp8TibkRkNObyHQTaDQ7qATjl5hZnMOSUlf2LYQy+vu/I5T1bXI1hai2Wq83HWbaugwNCq4nwhCcALjGMubg4WafbRkREIrU62svrRa02RSxW2lcoL19UUvLXO39IWFiqVjuhg5PCvcLeKACnxGJFePgvOl1WZeUP1dWbDYZ9BsO+K1fm9uy5SaUa3LCaQtFPrY5pvrlen1Vff4nDvHB3qFEA7ok0mic1midttprq6s2lpZ+YzaeLisbdf39xw2VPavXQxlP5DYqKElGjzgZTTAC8kUi6eXr+pW/fXJks0Got1+ky+U4EbYEaBeCZRKLp1i2GiOrqLvKdBdoCNQrAP7O5gIikUi3fQaAtUKMAnLp2bX5l5VqbTW//kWXrr137sLb2iEjk4uYWx282aBtMMQFwh2FqS0s/ZRiTSCRzcQkRi1X19RdsNh2RODj4CxeX7nwHhLZAjQJwRyxW3nff6crK9Xp9ttlcYLXekEq9NJoxPj6vq1SP2tdRKqN8ff/e8GMT7u7JCkVfhaIPh6nhLnD5PQCAQ7A3CsC1mpo9DGNQq4dKJBq+s0A7wBQTANfKyhYWFSXq9Vl8B4H2gRoF4BhrNP5GRCrVEL6TQPtAjQJwymw+Y7VWymQBLi7BfGeB9oEaBeCUwZBLRGp1NN9BoN2gRgE4ZX9gc+OHOYHQoUYBOGWvUbUaNdp5oEYBuGOzVZtMp0UiuVIZyXcWaDeoUQDuGI2/ETEq1UCRSM53Fmg3qFEA7tjnl3BitJNBjQJwBydGOyXUKABnGKMxj4hUqkf4TgLtCTUKwBGT6aTNpnNx6SGTBfKdBdoTahSAIwbDASJSq3EPaGeDGgXgCC6876xQowAcwfxSZ4UaBeCCzXrDbC4Ui5WurhF8Z4F2hsc2A3BBUpz7wD4P84OjRSIZ31mgnWFvFIATJQek1TfUVszRd0KoUQBOXM4lIgrGidFOCDUK0PEYK109TCSioJbf9wmChhqFdnPy5Mlr167xncIplR2jegN5hJPSm+8o0P5Qo+Co8+fPz5gxw93dfcCAAQEBAaNGjeI7kfMpOUBEFIwL7zsnzNRDW7Ase/jw4Z9//jk9Pf3UqVONF+3YsePZZ59dtWqVSCTiK57TwYnRTg01Cq1gs9lyc3PT0tI2bdp0+fJl+6BWqx0+fHhQUNB77723cOHCJUuWrFmzhmXZlStXSqX4BSMiopJcIqIg1GjnJGJZlu8M4Oxqa2t37tyZlpaWkZFRXV1tH+zevXt8fHxiYmJ8fLxMdutayF27do0bN66mpmbMmDGpqamurq48pXYaxjL6Pz9y6UZzq0gk4TsNtD/UKNxWRUXF1q1bMzMzt27dajQa7YP9+/cfM2ZMYmJidHT07Q7bDx06lJCQUFFRERsbm56ertFoOEztfE5vop+eorCR9JftfEeBDoFjLmjq/PnzW7ZsyczM3L17t9VqJSKxWBwVFZWYmDhp0qQ+ffr
c9RMefvjhvXv3xsXF7dmzZ8SIEdu2bfPx8en44M4KJ0Y7O9Qo3HTy5Mm0tLTMzMzff//dPqJQKB577LHExMRnnnnGz8+vVZ/Wr1+/nJycuLi4I0eODB48ePv27T179uyA1EKAE6OdHQ7qu7TbTRk98cQTiYmJ48eP79atmyOfX1lZmZCQcPDgQX9//6ysrIiIrvdUDsZCn2jIaqbZFeTqwXca6BCo0a6oVVNGDjIYDMnJyTt27NBqtZmZmUOGdLFrJy//Rt8NJu/+9NpJvqNAR8FBfRfS5ikjR6jV6oyMjKlTp27cuHHUqFEbN26Mj49v929xXjdPjHaxPx5dDPZGO78Wp4wiIyPvfcrIcTab7dVXX12xYoWLi8vq1asnTpzIwZc6hbQJdCqNkr6jyOf5jgIdBXujndYdpowmTJjg7+/PZRiJRLJs2TIPD48FCxZMnjy5tLR01qxZXAbgTXgcMVbqPpTvHNCRWOhErFZrTk7OzJkzg4KCGv4Ta7XalJSU1atX6/V6vgOyixYtsp86mDNnTts+4csvv4yNjf3xxx9bXPrcc8/FxsaeOXPGgYwArYO90c7gzlNGcXFxLi4u/CZs8NZbb2m12hdeeGHBggUGg2HJkiViceuej1NYWLhnz564uLgWl+bl5Z06dUqv17dH2Nb7z/9QxZlbP7p6kGcv6juOtF31Yq+uATUqYLxMGTlu2rRp7u7uEydOXLp0aVVV1apVq9rxwgCeXfgPXdrXdHDHHBrxEQ2dy0cg4AJqVJDWrl37zTff5ObmMgxDRBKJZNiwYWPHjh03blxYWBjf6e4uKSlp27ZtSUlJ69atKy8v//nnn9VqNd+h2s+Ef1PocCIiQxkd/5Fy/kk7/0G+EdQrge9k0CFQo4JUWFi4f/9+hUIxdOhQXqaMHBcbG7tr166EhIRff/318ccf37p1q6enJ9+h2omLmhRaIiKFlob/L9VW0OFvKH8larSzQo0K0rRp0x544IG4uDiVSsV3lraLiorKzc0dNWpUXl7esGHDtm/fHhjYGd/41uMxOvwNVZ3jOwd0FNSoIPXs2bNz3KIeFhaWk5MTHx9/7NixoUOHbt++vVevXvey4dGjR9etW9d8XKfTtXdGh5mqiIhkAv6DB3eGGm271atXX716derUqcHBwU0W1dXVff755xKJZPbs2bxkExB/f//du3cnJiYeOHAgJiYmKyvrwQcfvOtWqampqampHMRzVH0N5X1FRNQjlu8o0FFQo2331VdfHT58eMiQIc1r1Gw2v/vuu1KpFDV6L7Ra7fbt259++umsrKzhw4dv2bIlJibmzptMnDgxKSmp+fjs2bMbnrHCm5M/UWk+MVbSXaKzW8hQSpoQevSvPKeCDoMaBaegUqnS09OnTZuWmpo6cuTIdevWJScn32H9iIiISZMmNR//6KOPOizjPcv/7tY/S+Q0YBKNXEhKL/4CQcdCjYKzcHFxWb9+vb+//6JFiyZMmLB8+fLnn3f6+9Cvn6SCTDqbQYPeoAH/fVBAwlc330cvU5E2jCTOcu8DdBDUKKeqqqoyMzPPnTsnEon69euXmJioVCr5DuVERCLRF1984efnN3fu3BdffLGysvLtt9/mO1QzjJUu5dCZzXQmnXQXbw5qw27VqEcv8o/iKx1wDzXKnY0bN7700ksNN2sSUUhIyI4dO+5xbrrrmDNnjlqtnjlz5jvvvFNRUfHJJ584xe1YVhMV/0oFmXQmnYxlNweV3hQeT/elUM+Wb06FrgA16ii9Xn/jxo0mgzU1NU1G8vLyJk+eLJFIFi9ePH78eKvVOm/evDVr1kydOvXgwYNchRWM119/XavVPvfccwsWLCgrK1uxYgVv72o2VVJBJhVkUtE2qjfcHNSGUe9E6p9CwdHkDBUPvEKNOqrF+eLm5s+fb7FY5s+fP3PmTPvI8uXLd+7cmZeXl5+fHxkZ2ZEZBWny5MkajWbChAmrVq2qrq5ev369QqHg7ut1F6komwo
yqCibGAsRkUhM/lHUO5Hum0De/dv4sdXn6fA39Pg/8ablzgQ16qj4+HhfX98mgxaLpfHF4RaLZdeuXUQUExNTXFzcMP7AAw9cuXLl4MGDqNEWjR49eteuXaNHj968eXNCQsLmzZvd3NxefvnlkSNH9uvXr8VNvv7665qamt69e7ftG/Pz8wMrdvqcW0tlR28OSeQU/iT1HUd9kkjduvf6NcUytGE8lR2lyiJ6ah1J5A59GjgPvp/UJ2ADBw4kot27dzdfZD8BKpVK7T/e+UrGDz74gNPcQnP06FH7EwNWrlzZEZ9vf0jrnDlz7OX779mx7DxiP1ay6xLZP1az5upWfNaZdPbwMlZXctsVruSxC73YecSueow16xwPD84Ae6NcsD+HSaFQLF68uPnSe7lppyuLiIjYt29fWlra9OnT7SMsy27bti0zM7O4uFgsFgcGBg4ZMiQpKcnT09NsNkdHR3t4eOzYsePOH1tbW5udnZ2enp6ZmdlwdjsgIOCCahBNeZd6PNaWC5X63O0MT8DD9NweWhtHF3bT6hE0ZSupfFr9LeBs+O5xAbv3vVGTyWR/pGZlZSW3GTuhurq6sWPHNvwCN0zix8bGsixrMBiIyNvb+3ab37hxY/Xq1SkpKY3fHR0WFjZz5sycnBybzcbFv0P1BfbL3uw8YheHsZVFXHwjdCTsjXJBoVA88sgj+/bt++mnn1555RW+4wjbwoUL09PTfX19ly5dGhcXp1AoSkpKsrOz7c/ZU6lUer2++RP1L168mJ2dnZGRkZ2dbbFYiEgsFkdFRdkfM9i/f1unjNpGE0LT99LaJ6k0n1bG0F+yyed+TgNA++K7xwXs3vdGWZZNT08nIi8vr4b1rVZrVlbWt99+y1HczsI+ubRhw4Z7WfnEiROffvpp4xcBSKXS6OjoRYsWXb58uaOj3kVdDbvmCXYesZ9q2Uv7eQ4DDkCNtl2rapRl2fnz59v/Z3Z3dw8NDbW/H6lHjx4Wi4WryJ2Bh4cHEeXm5ra41Gw2JycnT58+nWXZjz/+uGF3QaPRTJw4MTU1Vadzpokdq5lNfYqdR+zHKrZwG99poI3wnvq2++yzzy5evDhr1qzw8PAmi8xm89tvv22/2L7x+B9//LF27dqzZ8+KxeKQkJCYmJixY8c6z/vmBCEiIuL48eNTpkz5/vvvm7/EyWg0qtVqb2/v8vLyvLy8xMTE+Pj4lJSUUaNGyeVOeYERa6PMGXTkW5K40Pg1dN8zfAeCVkONgsAsWbLkrbfeIqLIyMjXXnstJSVFo9E0LG1co/bfbae4kfTOWJZ2zqX9C0kkoYSvaOAMvgNB67Tu3bYAvHvzzTfff/99hUKRn5//0ksvBQQEPPvss0eOHGm+pkgkEkCHEpFIRE8soPhFxDL0y6v0K94hKjDYGwVBqq6uTktLW7t27b59+xiGkUqlP/zww8SJExvvjfKdsfWOrqEtLxBjpUFvUPxiEmEvRxhQoyBsJSUls2bN2rRpk0ajKSkpEYvFAq5RIjq9if492RQYfP2x4cE9vhaJcEmiAODPHQhbcHDwhg0bAgMDdTrd/v37+Y7jsH7J7NRtxQPpeuWKc+fGM4yJ70Bwd6hREDyZTNa9e3ciqqys5DtLOxD1GB7ac4NU6q3TZRYUDLdamz6GEZwNahQExmazNRkpKys7fvw4EYWGhvKRqP0plVF9++bK5WFG48GCgliL5QrfieBOUKMgJAzDDBgw4Pnnn8/MzCwoKCgpKcnIyIiLizMYDBEREYMGDeI7YLuRy3v26bPP1fV+k+nk2bMxdXWFfCeC28IUEwhJSUnJiBEjioqKmowPGDAgPT09LCxM2DP1zdhsVYWFo43GXJnMLzx8m1KJh4E5I9QoCM+hQ4d27txZXFxsMpl8fX2HDRuWkJBgf8sIwzB79+51cXEZMmQI3zHbB8MYz517Sq/Plkjcw8Mz1OqhfCeCplCjAM6OZevPn59WVZUqEsnDwta7u4/nOxH8Cc6
NglClpKS8//77JlPnvyRIJHIJDf3R23sGy9YVF6dUVKzkOxH8CfZGQZAuXLgQGhrq6el5/fp1Ydzx2R5KSxdcuTKXSBQU9Jmv79/5jgM3YW8UBCk3N5eIBg8e3HU6lIj8/OYEB39JJLp8+e0rV3DrvbPArWYgSA01yncQrvn4vCGVai9cmF5ausBm03XvvtS+M8SyVr0+y2g8aLGUS6WecnmYRvOkTBZo3+r8+aksa+3R43ux2LXJB5rNp69e/VAu7xkY+HHTL4N7gxoFQTpw4AARdZrp+Fbx8JgikWiKiyfU1uYzjFksVtbVFRUVJZnNp/+8orhnz5/d3ZOIqKoqlWWtISHLiJrWqMVSXlWVqlQORI22GWoUhMdkMh07dkwikdhfQNAFaTSJvXvvkst7icVKlrWdOzfebD6tUj3i5/euQtGHYWpNpmPV1ZvU6kf5TtoloEZBePLy8iwWy0MPPaRWq/nOwhuV6mZFGo25JtMJqdSrV68dEsnN150qlZGens/yl65rwRQTCE+XPTHaovr6C0Qkl4c3dChwDDUKwoMabUwq9SIik+mY2VzAd5YuCgf1IDAsy9prtGvOLzWnVsfKZL4WS9mZM4P8/N719JzSMEHfhMGwVyxWNRk0mY51fMZODpffg8AUFhb27t3bx8enrKyM7yzOwmg8WFz8dH39ZSISiSRubnFeXi+6u49tONw8ckTGstY7fIJSObBfv0NcZO2MsDcKAmPfFY2OjuY7iBNRqR4ZMKCoujrjxo01en22TrdVp9vq5jYyPDxDJLr1WunAwH+KRIom29bVnbt+fSm3eTsb1CgIDE6Mtkgkkmu1T2u1T1utlTdurL569T29fkdp6af+/h80rOPt/ZpEommyYU3NHtSogzDFBAKDGr0zqdTD1/evQUGfE1FVVRrfcboE1CgISU1NzYkTJ2QyWVRUFN9ZnJpK9TARWSylfAfpElCjICQHDx602WyRkZGurk1vaoTGjMZDRCSTBfAdpEtAjYKQmM2nHnooAEf0jen1v547N66qKs1mqyIihjFXVf105co/iMjDYxLf6boETDGBkPTtm718+dWQEFwxeovBkFNdnV5dnU5EYrGSYUxELBG5uyfjmaTcQI2CgLBG429E5OaGvdFbAgI+9PCYXFW1sbb2kMVSLha7yOW93N2TNZonG9bp3n0ZESMWK5tvrlD0CQlZIZV6cxi5s8Hl9yAYZvPpkyf7y2QBERF4bzs4EZwbBcEwGA4QEV6NCc4GNQqCYTTmEpFajSN6cC6oURAMgyGXiFQq1Cg4F9QoCIPNVm02nxGJ5Erlg3xnAfgT1CgIg8GQS8SoVAMbP2sDwBmgRkEY7CdGVSpcMQpOBzUKwoD5JXBaqFEQBMZozKNG73EDcB6oURAAk+mEzaaXy0NlMn++swA0hRoFAbBfeI8To+CcUKMgADgxCs4MNQoCgAvvwZmhRsHZWa0VdXVFYrHK1TWC7ywALcCD8sDZSSTdevXaZrFcFYnw6wrOCA/KAwBwCA7qwRmxrK2iYmVR0egTJ8KPHw85ezb60qU3amr+Y19aUbEiP1996dKMFrc9f35Kfr66qiqVw7zQpeEoCZwOy1qKipL0+iwikkp9pFKP2tpjBsMBvX77gAEF9hUYxsgw5hY3ZxgTwxhZ1sJpaOjCUKPgdK5fX6bXZ8lkvmFhaWp1DBERsUbjIau1nOdkAC1BjYLT0em2EJGf33v/7VAiEqlUg3iMBHAHODcKTsdm0xORRKLmOwjAPUGNgtORy8OIqKzsX/X1F/nOAnB3OKgHp+PjM7OqaqPJdPLUqQgPj0keHtPU6hbupjcaD12+/Lfm4ybT8Y7PCHALahScjkr1aHj4L5cv/81kOnH9+rLr15cpFP19fN7w9n6l8fGT2XzKbD7FY04AO9QoOCM3t5H9+x+vrc2vrFxbWZlqNp+6dOm12trfQ0K+bVjH3T05KOjT5tteuvSGXr+dw7DQ1aFGwXkplZFKZWR
Q0Gfl5UtLSt6qqPjO2/s1pfIh+1KJpJtc3qv5VmKxituY0NVhigmcn9jH5003t1FEZDDs5TsMQFOoURAGqVRLRDabke8gAE2hRsHpWK0VzUfsN9S7uvbjIxHAneDcKDidwsKRLMtotU+5uj4okajM5oLy8kUWS5lC0VejGc13OoCmUKPgXBjGLJf30ekyrl79oPG4Wj0kNPRHkUjOVzCA28HzRsEZMYyhpmav2XyWZeslErVK9ahSGdVoaa3NphOLlRKJpvm2Nls1w5gkEq1YrOAwMnRdqFEAAIdgigkAwCGoUQAAh6BGAQAcghoFAHAIahQAwCGoUQAAh/w/5kM91lORMsgAAACyelRYdHJka2l0UEtMIHJka2l0IDIwMjMuMDkuNAAAeJx7v2/tPQYg4GdAAA4gZgPiBkY+hgyQADMjP4MGmMEGEWBiEoAwGBkxGVA1jMxMYJMYuBkYgTqAfAYGFgZGVgYmNgYRkLh4FkgV3NKrdkkOcie794E4N97Pc+BebWQHYt9eIO2wzcbOHsS+anfIoaksfD+I/dXa12Gy1z0w26NwpcPatRftIUY12J+Mmm7HgAbEANcxIdFj5zaeAAABGXpUWHRNT0wgcmRraXQgMjAyMy4wOS40AAB4nH2SSW7DMAxF9z7Fv0AEUgMlLbrwELRFG7uondygi+x7f5RKodhBjVJacHgEqQ81KPY5vF2/cTc7NA1A/9ycMy6OiJoTioPu+Pw6ol/armb66TwuMxJEO/Q8ku0ynWqGMV/hTHCBUsCBjI2OOIMM3Uy7L+37k6+8xQe8ySIxJoU4ep94Q1fOoYc13rFkp+WoHvsdzmOGGIrEkcr45FmC3QGDgs5Ykpg8DmyiEwl7k0XBYGwOpBuyEZeY3Q4XdUO+Z0uD7G6Y8PK1Sf91bgIduPLHcXhQ+FfzbhqHVfNy7CppCd2qnNbgV4FKNawylKqsj9VOxO3s7aQS13+hfvMDBK93Gkl4MxsAAACQelRYdFNNSUxFUyByZGtpdCAyMDIzLjA5LjQAAHicFY07CsQwEEOvsmUCk2Hk+dm4TJMykDKk3GLvkMOvXQjE4wnt+339jme51nPk4vv4Pp93AQmX5gGjXtgU0QbJUSZRdnVptA0pVdAmKhJZaQOnRjh14xaTCCPNKqgHSwpyzqohvFD3+SKVwKEVwxESWt8/z+gf06IW8IIAAAAASUVORK5CYII=", "text/plain": [ "" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nodes = torch.tensor([7,4,1,6,6,6,1,5])\n", "edges = torch.tensor([[0,0,1,0,1,2],[1,2,3,4,5,6]])\n", "create_molecule_from_graph(nodes, edges)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "create_molecule_from_graph() takes 2 positional arguments but 3 were given", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[83], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m smiles \u001b[38;5;241m=\u001b[39m 
\u001b[43mdm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_train_smiles\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[75], line 167\u001b[0m, in \u001b[0;36mDataModule.get_train_smiles\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 165\u001b[0m graph \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_dataset[idx]\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# print(graph.x, graph.edge_index)\u001b[39;00m\n\u001b[0;32m--> 167\u001b[0m mol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_molecule_from_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m train_smiles\u001b[38;5;241m.\u001b[39mappend(Chem\u001b[38;5;241m.\u001b[39mMolToSmiles(mol))\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_index:\n", "\u001b[0;31mTypeError\u001b[0m: create_molecule_from_graph() takes 2 positional arguments but 3 were given" ] } ], "source": [ "smiles = dm.get_train_smiles()" ] } ], "metadata": { "kernelspec": { "display_name": "graphdit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 2 }