888 lines
46 KiB
Plaintext
888 lines
46 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from nas_201_api import NASBench201API as API"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Location of the NAS-Bench-201 (v1.1) benchmark file.\n",
"NB201_PATH = './NAS-Bench-201-v1_1-096897.pth'\n",
"api = API(NB201_PATH, verbose=False)"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print every trial (one per seed) recorded for architecture #1 on CIFAR-100.\n",
"arch_index = 1\n",
"results = api.query_by_index(arch_index, 'cifar100')\n",
"print(f'There are {len(results)} trials for this architecture [{api[arch_index]}] on CIFAR-100')\n",
"for seed, result in results.items():\n",
"    print(f'Latency : {result.get_latency()}')\n",
"    print(f'Train Info : {result.get_train()}')\n",
"    print(f'Valid Info : {result.get_eval(\"x-valid\")}')\n",
"    print(f'Test Info : {result.get_eval(\"x-test\")}')\n",
"    print()\n",
"    print(f'Train Info [10-th epoch]: {result.get_train(10)}')"
]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.path.append('../') \n",
|
|
"\n",
|
|
"import os\n",
|
|
"import os.path as osp\n",
|
|
"import pathlib\n",
|
|
"import json\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from rdkit import Chem, RDLogger\n",
|
|
"from rdkit.Chem.rdchem import BondType as BT\n",
|
|
"from tqdm import tqdm\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from torch_geometric.data import Data, InMemoryDataset\n",
|
|
"from torch_geometric.loader import DataLoader\n",
|
|
"from sklearn.model_selection import train_test_split\n"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def random_data_split(task, dataset, train_ratio=0.6, valid_ratio=0.2, test_ratio=0.2, seed=42):\n",
"    \"\"\"Split the labelled graphs of ``dataset`` into train/val/test index lists.\n",
"\n",
"    A graph is unlabelled when its first target (``dataset.y[:, 0]``) is NaN; such graphs\n",
"    are returned separately wherever they appear in the dataset. ``test_ratio`` of the\n",
"    labelled graphs go to test; the rest is divided so validation and training receive\n",
"    ``valid_ratio`` and ``train_ratio`` of the labelled total.\n",
"\n",
"    Args:\n",
"        task: name used only in the printed summary.\n",
"        dataset: object exposing ``y`` (one row per graph) and ``__len__``.\n",
"        train_ratio, valid_ratio, test_ratio: split proportions.\n",
"        seed: ``random_state`` for both ``train_test_split`` calls.\n",
"\n",
"    Returns:\n",
"        ``(train_index, val_index, test_index, unlabeled_index)`` as lists of ints.\n",
"    \"\"\"\n",
"    is_nan = torch.isnan(dataset.y[:, 0])\n",
"    # Select labelled rows explicitly instead of assuming NaN rows sit at the end.\n",
"    labeled_idx = torch.nonzero(~is_nan, as_tuple=True)[0].tolist()\n",
"    unlabeled_index = torch.nonzero(is_nan, as_tuple=True)[0].tolist()\n",
"    train_index, test_index = train_test_split(labeled_idx, test_size=test_ratio, random_state=seed)\n",
"    train_index, val_index = train_test_split(\n",
"        train_index, test_size=valid_ratio / (valid_ratio + train_ratio), random_state=seed)\n",
"    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),\n",
"          'test len', len(test_index), 'unlabeled len', len(unlabeled_index))\n",
"    return train_index, val_index, test_index, unlabeled_index"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parse_architecture_string(arch_str):\n",
"    \"\"\"Convert a NAS-Bench-201 cell string into an operation-as-node DAG.\n",
"\n",
"    ``arch_str`` has the form ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|``: stage ``t``\n",
"    (1-based) lists the operations feeding cell node ``t`` and ``op~j`` reads cell node ``j``.\n",
"\n",
"    Returns:\n",
"        nodes: ``['input', op_1, ..., op_k, 'output']``.\n",
"        edges: ``(src, dst)`` node-index pairs. An op reading cell node ``j`` receives an\n",
"            edge from every op that writes into ``j`` (from the input node when ``j == 0``);\n",
"            the output node receives an edge from every op writing into the last cell node.\n",
"    \"\"\"\n",
"    stages = arch_str.split('+')\n",
"    nodes = ['input']\n",
"    edges = []\n",
"    # producers[j] = graph nodes whose outputs are summed into cell node j.\n",
"    producers = {0: [0]}\n",
"    for cell_node, stage in enumerate(stages, start=1):\n",
"        producers[cell_node] = []\n",
"        for token in stage.strip('|').split('|'):\n",
"            op, src = token.split('~')\n",
"            op_node = len(nodes)\n",
"            nodes.append(op)\n",
"            edges.extend((pred, op_node) for pred in producers[int(src)])\n",
"            producers[cell_node].append(op_node)\n",
"    output_node = len(nodes)\n",
"    nodes.append('output')\n",
"    edges.extend((pred, output_node) for pred in producers[len(stages)])\n",
"    return nodes, edges"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_adj_matrix_and_ops(nodes, edges):\n",
"    \"\"\"Return ``(adjacency, nodes)`` for a graph given as node labels plus an edge list.\n",
"\n",
"    ``adjacency`` is a square ``int`` matrix of side ``len(nodes)`` with a 1 at\n",
"    ``[src, dst]`` for every ``(src, dst)`` in ``edges``; ``nodes`` is returned unchanged.\n",
"    \"\"\"\n",
"    size = len(nodes)\n",
"    adjacency = np.zeros((size, size), dtype=int)\n",
"    for src, dst in edges:\n",
"        adjacency[src, dst] = 1\n",
"    return adjacency, nodes"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert every NAS-Bench-201 architecture into an (adjacency matrix, ops list) pair and\n",
"# collect the op vocabulary and the set of graph sizes.\n",
"graphs = []\n",
"ops_type = {}    # op name -> id in order of first appearance\n",
"len_ops = set()  # distinct node counts\n",
"num_archs = len(api)  # number of architectures in the loaded benchmark\n",
"for i in range(num_archs):\n",
"    arch_info = api.query_meta_info_by_index(i)\n",
"    nodes, edges = parse_architecture_string(arch_info.arch_str)\n",
"    adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)\n",
"    if i < 5:\n",
"        print('Adjacency Matrix:')\n",
"        print(adj_matrix)\n",
"        print('Operations List:')\n",
"        print(ops)\n",
"    for op in ops:\n",
"        ops_type.setdefault(op, len(ops_type))\n",
"    len_ops.add(len(ops))\n",
"    graphs.append((adj_matrix, ops))\n",
"print(graphs[0])"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Vocabulary size, number of distinct graph sizes, then the raw values themselves.\n",
"for stat in (len(ops_type), len(len_ops), ops_type, len_ops):\n",
"    print(stat)"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Map each NAS-Bench-201 operation (plus the synthetic input/output nodes) to a chemical\n",
"# element so architectures can be processed with molecule-oriented tooling (RDKit).\n",
"op_to_atom = {\n",
"    'input': 'Si',         # Silicon for input\n",
"    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution\n",
"    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution\n",
"    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling\n",
"    'skip_connect': 'P',   # Phosphorus for skip connection\n",
"    'none': 'S',           # Sulfur for no operation\n",
"    'output': 'He'         # Helium for output\n",
"}"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def graphs_to_json(graphs, filename):\n",
"    \"\"\"Compute Graph-DiT style marginal statistics for op-graphs and write them to JSON.\n",
"\n",
"    Every op is treated as an atom via the global ``op_to_atom`` map, and every edge is\n",
"    typed by the op of its destination node (``bonds`` below; 0 means no edge). Counts\n",
"    are accumulated over all graphs, normalised to distributions, saved to\n",
"    ``f'{filename}.meta.json'`` and returned.\n",
"\n",
"    Args:\n",
"        graphs: sized sequence of ``(adj_matrix, ops)`` pairs; ``adj_matrix[i][j] != 0``\n",
"            marks an edge ``i -> j`` and ``ops[i]`` is the op name of node ``i``.\n",
"        filename: path prefix of the ``.meta.json`` output file.\n",
"\n",
"    Returns:\n",
"        The metadata dict that was written to disk.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``graphs`` is empty.\n",
"    \"\"\"\n",
"    bonds = {\n",
"        'nor_conv_1x1': 1,\n",
"        'nor_conv_3x3': 2,\n",
"        'avg_pool_3x3': 3,\n",
"        'skip_connect': 4,\n",
"        'input': 7,\n",
"        'output': 5,\n",
"        'none': 6\n",
"    }\n",
"    # One slot per edge type plus the 'no edge' class 0. Bond indices reach 7, so the\n",
"    # transition tensor needs this width (a width of 5 raises IndexError).\n",
"    n_bond_types = max(bonds.values()) + 1\n",
"\n",
"    num_graph = len(graphs)\n",
"    if num_graph == 0:\n",
"        raise ValueError('graphs_to_json needs at least one graph')\n",
"\n",
"    pt = Chem.GetPeriodicTable()\n",
"    # Slot k holds atomic number k + 2 (H is excluded); the final slot is the wildcard '*'.\n",
"    atom_name_list = [pt.GetElementSymbol(z) for z in range(2, 119)] + ['*']\n",
"    n_atom_types = len(atom_name_list)\n",
"    atom_count_list = [0] * n_atom_types\n",
"    n_atoms_per_mol = [0] * 500\n",
"    bond_count_list = [0] * n_bond_types\n",
"    valencies = [0] * 500\n",
"    transition_E = np.zeros((n_atom_types, n_atom_types, n_bond_types))\n",
"\n",
"    def atom_slot(symbol):\n",
"        # Index of an element symbol in atom_name_list / atom_count_list.\n",
"        return n_atom_types - 1 if symbol == '*' else pt.GetAtomicNumber(symbol) - 2\n",
"\n",
"    n_atom_list = []\n",
"    n_bond_list = []\n",
"    for adj, ops in graphs:\n",
"        n_atom = len(ops)\n",
"        # Directed edges between distinct nodes (self-loops are ignored below as well).\n",
"        n_bond = sum(1 for i in range(n_atom) for j in range(n_atom) if i != j and adj[i][j] != 0)\n",
"        n_atom_list.append(n_atom)\n",
"        n_bond_list.append(n_bond)\n",
"        n_atoms_per_mol[n_atom] += 1\n",
"\n",
"        cur_atom_count_arr = np.zeros(n_atom_types)\n",
"        for op in ops:\n",
"            symbol = op_to_atom[op]\n",
"            if symbol == 'H':\n",
"                continue\n",
"            slot = atom_slot(symbol)\n",
"            atom_count_list[slot] += 1\n",
"            cur_atom_count_arr[slot] += 1\n",
"            if symbol != '*':\n",
"                valence = int(pt.GetDefaultValence(symbol))\n",
"                if 0 <= valence < len(valencies):\n",
"                    valencies[valence] += 1\n",
"                else:\n",
"                    print('skipping unexpected default valence', symbol, valence)\n",
"\n",
"        transition_E_temp = np.zeros_like(transition_E)\n",
"        for i in range(n_atom):\n",
"            for j in range(n_atom):\n",
"                if i == j or adj[i][j] == 0:\n",
"                    continue\n",
"                start_index = atom_slot(op_to_atom[ops[i]])\n",
"                end_index = atom_slot(op_to_atom[ops[j]])\n",
"                bond_index = bonds[ops[j]]\n",
"                # Each edge accounts for both ordered node pairs (i, j) and (j, i).\n",
"                bond_count_list[bond_index] += 2\n",
"                for a, b in ((start_index, end_index), (end_index, start_index)):\n",
"                    transition_E[a, b, bond_index] += 2\n",
"                    transition_E_temp[a, b, bond_index] += 2\n",
"\n",
"        # Ordered node pairs without an edge fall into the 'no edge' class 0.\n",
"        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2\n",
"        cur_tot_bond = np.outer(cur_atom_count_arr, cur_atom_count_arr) * 2\n",
"        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2\n",
"        edge_totals = transition_E_temp.sum(axis=-1)\n",
"        assert (cur_tot_bond >= edge_totals).all(), f'more edges than atom pairs in graph #{len(n_atom_list) - 1}'\n",
"        transition_E[:, :, 0] += cur_tot_bond - edge_totals\n",
"\n",
"    n_atoms_per_mol = (np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)).tolist()[:51]\n",
"\n",
"    atom_count_arr = np.array(atom_count_list) / np.sum(atom_count_list)\n",
"    print('processed meta info: ------', filename, '------')\n",
"    print('len atom_count_list', len(atom_count_arr))\n",
"    print('len atom_name_list', len(atom_name_list))\n",
"    active_atoms = np.array(atom_name_list)[atom_count_arr > 0].tolist()\n",
"    atom_count_list = atom_count_arr.tolist()\n",
"\n",
"    bond_count_list = (np.array(bond_count_list) / np.sum(bond_count_list)).tolist()\n",
"    valencies = (np.array(valencies) / np.sum(valencies)).tolist()\n",
"\n",
"    # Atom-type pairs that never co-occur get a pure 'no edge' distribution so the\n",
"    # normalisation below never divides by zero.\n",
"    no_edge = np.sum(transition_E, axis=-1) == 0\n",
"    no_edge_channel = transition_E[:, :, 0]  # view: writes go into transition_E\n",
"    no_edge_channel[no_edge] = 1\n",
"    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)\n",
"\n",
"    meta_dict = {\n",
"        'source': 'nasbench-201',\n",
"        'num_graph': num_graph,\n",
"        'n_atoms_per_mol_dist': n_atoms_per_mol,\n",
"        'max_node': max(n_atom_list),\n",
"        'max_bond': max(n_bond_list),\n",
"        'atom_type_dist': atom_count_list,\n",
"        'bond_type_dist': bond_count_list,\n",
"        'valencies': valencies,\n",
"        'active_atoms': active_atoms,\n",
"        'num_atom_type': len(active_atoms),\n",
"        'transition_E': transition_E.tolist(),\n",
"    }\n",
"\n",
"    with open(f'{filename}.meta.json', 'w') as f:\n",
"        json.dump(meta_dict, f)\n",
"    return meta_dict"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bind the result: displaying the returned dict would dump the whole nested transition_E list.\n",
"nb201_meta = graphs_to_json(graphs, 'nasbench-201')\n",
"nb201_meta['num_graph'], nb201_meta['max_node'], nb201_meta['active_atoms']"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gen_adj_matrix_and_ops(nasbench):\n",
"    \"\"\"Collect ``(adjacency, ops)`` pairs for every architecture of a NAS-Bench-101 API.\n",
"\n",
"    Args:\n",
"        nasbench: object with ``hash_iterator()`` and ``get_metrics_from_hash(h)`` returning\n",
"            ``(fixed_metrics, computed_metrics)``. ``fixed_metrics`` is assumed to hold\n",
"            ``'module_adjacency'`` and ``'module_operations'`` (NAS-Bench-101 layout);\n",
"            confirm against the installed ``nasbench`` package.\n",
"\n",
"    Returns:\n",
"        List of ``(adj_matrix, ops)`` tuples, the format ``graphs_to_json`` consumes.\n",
"    \"\"\"\n",
"    graphs = []\n",
"    for unique_hash in nasbench.hash_iterator():\n",
"        fixed_metrics, _computed_metrics = nasbench.get_metrics_from_hash(unique_hash)\n",
"        graphs.append((fixed_metrics['module_adjacency'], fixed_metrics['module_operations']))\n",
"    return graphs"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch_geometric.data import InMemoryDataset, Data\n",
"import os.path as osp\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"import networkx as nx\n",
"import numpy as np\n",
"\n",
"class Dataset(InMemoryDataset):\n",
"    \"\"\"In-memory PyG dataset with one graph per NAS-Bench-201 architecture.\n",
"\n",
"    Nodes are the cell's operations plus synthetic 'input'/'output' nodes, labelled with\n",
"    ``OP_TO_LABEL``; each edge is typed by the label of its destination node. ``y`` holds\n",
"    five seeded random placeholder targets per graph, not benchmark metrics.\n",
"\n",
"    Args:\n",
"        source: path to the NAS-Bench-201 ``.pth`` file (loaded only when processing).\n",
"        root: PyG root directory; the processed cache lives under its processed dir.\n",
"        target_prop: stored on the instance; not used by this class.\n",
"        transform, pre_transform, pre_filter: standard PyG hooks.\n",
"        seed: seed of the placeholder-target generator.\n",
"    \"\"\"\n",
"\n",
"    # Node/edge type encoding; 0 is left free (no edge / padding).\n",
"    OP_TO_LABEL = {\n",
"        'nor_conv_1x1': 1,\n",
"        'nor_conv_3x3': 2,\n",
"        'avg_pool_3x3': 3,\n",
"        'skip_connect': 4,\n",
"        'output': 5,\n",
"        'none': 6,\n",
"        'input': 7,\n",
"    }\n",
"\n",
"    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None,\n",
"                 pre_filter=None, seed=0):\n",
"        self.target_prop = target_prop\n",
"        self.source = source  # use the caller-supplied benchmark path\n",
"        self.seed = seed\n",
"        # Created lazily so an existing processed cache is used without loading the benchmark.\n",
"        self._api = None\n",
"        super().__init__(root, transform, pre_transform, pre_filter)\n",
"        self.data, self.slices = torch.load(self.processed_paths[0])\n",
"\n",
"    @property\n",
"    def api(self):\n",
"        \"\"\"NAS-Bench-201 API for ``self.source``, created on first access.\"\"\"\n",
"        if self._api is None:\n",
"            self._api = API(self.source)\n",
"        return self._api\n",
"\n",
"    @property\n",
"    def raw_file_names(self):\n",
"        return []  # NAS-Bench-201 data is loaded via the API, no raw files needed\n",
"\n",
"    @property\n",
"    def processed_file_names(self):\n",
"        # basename() keeps the cache inside processed_dir even if source is an absolute path.\n",
"        return [f'{osp.basename(self.source)}.pt']\n",
"\n",
"    @staticmethod\n",
"    def parse_architecture_string(arch_str):\n",
"        \"\"\"Return ``(nodes, edges)`` of the op-as-node DAG for a NAS-Bench-201 cell string.\n",
"\n",
"        ``nodes`` is ``['input', op_1, ..., op_k, 'output']``. An op reading cell node\n",
"        ``j`` (``op~j``) gets an edge from every op writing into ``j`` (from the input node\n",
"        when ``j == 0``); the output node gets an edge from every op writing into the last\n",
"        cell node.\n",
"        \"\"\"\n",
"        stages = arch_str.split('+')\n",
"        nodes = ['input']\n",
"        edges = []\n",
"        producers = {0: [0]}  # cell node -> graph nodes whose outputs are summed into it\n",
"        for cell_node, stage in enumerate(stages, start=1):\n",
"            producers[cell_node] = []\n",
"            for token in stage.strip('|').split('|'):\n",
"                op, src = token.split('~')\n",
"                op_node = len(nodes)\n",
"                nodes.append(op)\n",
"                edges.extend((pred, op_node) for pred in producers[int(src)])\n",
"                producers[cell_node].append(op_node)\n",
"        output_node = len(nodes)\n",
"        nodes.append('output')\n",
"        edges.extend((pred, output_node) for pred in producers[len(stages)])\n",
"        return nodes, edges\n",
"\n",
"    @classmethod\n",
"    def arch_to_graph(cls, arch_str, targets):\n",
"        \"\"\"Build a ``Data`` graph for ``arch_str`` with graph-level targets ``targets``.\"\"\"\n",
"        nodes, edges = cls.parse_architecture_string(arch_str)\n",
"        x = torch.tensor([cls.OP_TO_LABEL[node] for node in nodes], dtype=torch.long)\n",
"        edge_index = torch.tensor(edges, dtype=torch.long).view(-1, 2).t().contiguous()\n",
"        # Edge type = label of the destination node.\n",
"        edge_attr = torch.tensor([cls.OP_TO_LABEL[nodes[dst]] for _, dst in edges], dtype=torch.long).view(-1, 1)\n",
"        y = torch.tensor(targets, dtype=torch.float).view(1, -1)\n",
"        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)\n",
"\n",
"    def process(self):\n",
"        # Placeholder targets (sa, sc, target, target2, target3): seeded uniform draws so the\n",
"        # processed file is reproducible. TODO: replace with real benchmark metrics.\n",
"        rng = np.random.default_rng(self.seed)\n",
"        data_list = []\n",
"        for arch_index in tqdm(range(len(self.api))):\n",
"            arch_str = self.api.query_meta_info_by_index(arch_index).arch_str\n",
"            data_list.append(self.arch_to_graph(arch_str, rng.random(5).tolist()))\n",
"        if self.pre_filter is not None:\n",
"            data_list = [data for data in data_list if self.pre_filter(data)]\n",
"        if self.pre_transform is not None:\n",
"            data_list = [self.pre_transform(data) for data in data_list]\n",
"        torch.save(self.collate(data_list), self.processed_paths[0])"
]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"import torch\n",
"from torch_geometric.loader import DataLoader\n",
"from sklearn.model_selection import train_test_split\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"import utils as utils\n",
"from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule\n",
"from diffusion.distributions import DistributionNodes\n",
"\n",
"from rdkit.Chem import rdchem\n",
"\n",
"class DataModule(AbstractDataModule):\n",
"    \"\"\"Graph-DiT data module that serves NAS-Bench-201 architectures as PyG graphs.\"\"\"\n",
"\n",
"    # Node-label decoding used when turning graphs back into RDKit molecules.\n",
"    NUM_TO_OP = {\n",
"        1: 'nor_conv_1x1',\n",
"        2: 'nor_conv_3x3',\n",
"        3: 'avg_pool_3x3',\n",
"        4: 'skip_connect',\n",
"        5: 'output',\n",
"        6: 'none',\n",
"        7: 'input',\n",
"    }\n",
"    OP_TO_NUM = {op: num for num, op in NUM_TO_OP.items()}\n",
"\n",
"    def __init__(self, cfg):\n",
"        self.datadir = cfg.dataset.datadir\n",
"        self.task = cfg.dataset.task_name\n",
"        print('DataModule')\n",
"        print('task', self.task)\n",
"        print('datadir', self.datadir)\n",
"        super().__init__(cfg)\n",
"\n",
"    def prepare_data(self) -> None:\n",
"        \"\"\"Load the NAS-Bench-201 dataset, split it and build the three DataLoaders.\n",
"\n",
"        Optional config keys: ``cfg.dataset.base_path`` (project root; defaults to the\n",
"        parent of the working directory, the same ``'../'`` the notebook appends to\n",
"        ``sys.path``) and ``cfg.dataset.source`` (benchmark ``.pth`` path).\n",
"        \"\"\"\n",
"        target = getattr(self.cfg.dataset, 'guidance_target', None)\n",
"        print('target', target)\n",
"        base_path = getattr(self.cfg.dataset, 'base_path', None) or str(pathlib.Path.cwd().parent)\n",
"        root_path = os.path.join(base_path, self.datadir)\n",
"        self.root_path = root_path\n",
"\n",
"        batch_size = self.cfg.train.batch_size\n",
"        num_workers = self.cfg.train.num_workers\n",
"        pin_memory = self.cfg.dataset.pin_memory\n",
"\n",
"        source = getattr(self.cfg.dataset, 'source', None) or './NAS-Bench-201-v1_1-096897.pth'\n",
"        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n",
"        self.full_dataset = dataset\n",
"\n",
"        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n",
"        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n",
"\n",
"        train_index, val_index, test_index, unlabeled_index = (\n",
"            torch.LongTensor(idx) for idx in (train_index, val_index, test_index, unlabeled_index))\n",
"        if len(unlabeled_index) > 0:\n",
"            # Unlabelled graphs (NaN targets) are still added to the training split.\n",
"            train_index = torch.cat([train_index, unlabeled_index], dim=0)\n",
"\n",
"        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]\n",
"        self.train_dataset, self.val_dataset, self.test_dataset = train_dataset, val_dataset, test_dataset\n",
"        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))\n",
"        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)\n",
"        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)\n",
"        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)\n",
"\n",
"        self.training_iterations = len(train_dataset) // batch_size\n",
"\n",
"    def random_data_split(self, dataset, train_ratio=0.6, valid_ratio=0.2, test_ratio=0.2, seed=42):\n",
"        \"\"\"Split labelled graphs (``y[:, 0]`` not NaN) into train/val/test index lists.\n",
"\n",
"        Returns ``(train_index, val_index, test_index, unlabeled_index)``; graphs whose first\n",
"        target is NaN go to ``unlabeled_index`` wherever they appear in the dataset.\n",
"        \"\"\"\n",
"        is_nan = torch.isnan(dataset.data.y[:, 0])\n",
"        labeled_idx = torch.nonzero(~is_nan, as_tuple=True)[0].tolist()\n",
"        unlabeled_index = torch.nonzero(is_nan, as_tuple=True)[0].tolist()\n",
"        train_index, test_index = train_test_split(labeled_idx, test_size=test_ratio, random_state=seed)\n",
"        train_index, val_index = train_test_split(train_index, test_size=valid_ratio / (valid_ratio + train_ratio), random_state=seed)\n",
"        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))\n",
"        return train_index, val_index, test_index, unlabeled_index\n",
"\n",
"    def fixed_split(self, dataset):\n",
"        \"\"\"Fixed split for the 'O2-N2' task; prepare_data() does not call it.\"\"\"\n",
"        if self.task == 'O2-N2':\n",
"            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]\n",
"        else:\n",
"            raise ValueError('Invalid task name: {}'.format(self.task))\n",
"        full_idx = list(range(len(dataset)))\n",
"        full_idx = list(set(full_idx) - set(test_index))\n",
"        train_ratio = 0.8\n",
"        train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)\n",
"        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))\n",
"        return train_index, val_index, test_index, []\n",
"\n",
"    @staticmethod\n",
"    def parse_architecture_string(arch_str):\n",
"        \"\"\"Return ``(nodes, edges)`` of the op-as-node DAG for a NAS-Bench-201 cell string.\n",
"\n",
"        ``nodes`` is ``['input', op_1, ..., op_k, 'output']``. An op reading cell node\n",
"        ``j`` gets an edge from every op writing into ``j`` (from the input node when\n",
"        ``j == 0``); the output node gets an edge from every op writing into the last\n",
"        cell node.\n",
"        \"\"\"\n",
"        stages = arch_str.split('+')\n",
"        nodes = ['input']\n",
"        edges = []\n",
"        producers = {0: [0]}  # cell node -> graph nodes whose outputs are summed into it\n",
"        for cell_node, stage in enumerate(stages, start=1):\n",
"            producers[cell_node] = []\n",
"            for token in stage.strip('|').split('|'):\n",
"                op, src = token.split('~')\n",
"                op_node = len(nodes)\n",
"                nodes.append(op)\n",
"                edges.extend((pred, op_node) for pred in producers[int(src)])\n",
"                producers[cell_node].append(op_node)\n",
"        output_node = len(nodes)\n",
"        nodes.append('output')\n",
"        edges.extend((pred, output_node) for pred in producers[len(stages)])\n",
"        return nodes, edges\n",
"\n",
"    @staticmethod\n",
"    def create_molecule_from_graph(nodes, edges):\n",
"        \"\"\"Build an RDKit ``RWMol`` with one atom per node and a single bond per edge.\n",
"\n",
"        Args:\n",
"            nodes: tensor of node labels (decoded via ``NUM_TO_OP`` then ``op_to_atom``);\n",
"                label-0 nodes are skipped.\n",
"            edges: ``(2, E)`` tensor of ``(src, dst)`` node positions. Edges touching a\n",
"                skipped node are ignored rather than raising ``KeyError``.\n",
"        \"\"\"\n",
"        mol = Chem.RWMol()\n",
"        atom_indices = {}  # node position -> RDKit atom index\n",
"        for pos, label in enumerate(nodes.reshape(-1).tolist()):\n",
"            if label == 0:\n",
"                continue\n",
"            symbol = op_to_atom[DataModule.NUM_TO_OP[label]]\n",
"            atom_indices[pos] = mol.AddAtom(Chem.Atom(symbol))\n",
"        for start, end in edges.t().tolist():\n",
"            if start in atom_indices and end in atom_indices:\n",
"                mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)\n",
"        return mol\n",
"\n",
"    @staticmethod\n",
"    def _mol_to_smiles(mol):\n",
"        # The molecule is never sanitized (op 'valences' are arbitrary), so refresh the\n",
"        # property cache non-strictly before writing SMILES.\n",
"        mol.UpdatePropertyCache(strict=False)\n",
"        return Chem.MolToSmiles(mol)\n",
"\n",
"    def arch_str_to_smiles(self, arch_str):\n",
"        \"\"\"SMILES of the molecule built from a NAS-Bench-201 architecture string.\"\"\"\n",
"        nodes, edges = self.parse_architecture_string(arch_str)\n",
"        node_labels = torch.tensor([self.OP_TO_NUM[op] for op in nodes], dtype=torch.long)\n",
"        edge_index = torch.tensor(edges, dtype=torch.long).view(-1, 2).t()\n",
"        return self._mol_to_smiles(self.create_molecule_from_graph(node_labels, edge_index))\n",
"\n",
"    def get_train_smiles(self):\n",
"        \"\"\"Return ``(train_smiles, test_smiles)`` for the labelled train and test graphs.\"\"\"\n",
"        def to_smiles(idx):\n",
"            # train_index / test_index are positions in the full dataset, not in the subsets.\n",
"            graph = self.full_dataset[idx]\n",
"            return self._mol_to_smiles(self.create_molecule_from_graph(graph.x, graph.edge_index))\n",
"\n",
"        train_smiles = [to_smiles(idx) for idx in self.train_index]\n",
"        test_smiles = [to_smiles(idx) for idx in self.test_index]\n",
"        return train_smiles, test_smiles\n",
"\n",
"    def get_data_split(self):\n",
"        raise NotImplementedError('This method is not applicable for NAS-Bench-201 data.')\n",
"\n",
"    def example_batch(self):\n",
"        return next(iter(self.val_loader))\n",
"\n",
"    def train_dataloader(self):\n",
"        return self.train_loader\n",
"\n",
"    def val_dataloader(self):\n",
"        return self.val_loader\n",
"\n",
"    def test_dataloader(self):\n",
"        return self.test_loader"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import DictConfig, OmegaConf\n",
"import argparse\n",
"import hydra\n",
"\n",
"def parse_arg(argv=None):\n",
"    \"\"\"Parse the ``--config`` option from ``argv`` (``None`` means ``sys.argv[1:]``).\n",
"\n",
"    Inside Jupyter pass an explicit list such as ``parse_arg([])``: the kernel's own\n",
"    command-line flags would otherwise make argparse exit with an error.\n",
"    \"\"\"\n",
"    parser = argparse.ArgumentParser(description='Diffusion')\n",
"    parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n",
"    return parser.parse_args(argv)\n",
"\n",
"def task1(cfg: DictConfig):\n",
"    \"\"\"Create a ``DataModule`` for ``cfg``, run ``prepare_data()`` and return it.\"\"\"\n",
"    datamodule = DataModule(cfg=cfg)\n",
"    datamodule.prepare_data()\n",
"    return datamodule\n",
"\n",
"# Run configuration. Unset options use Python None: unlike YAML, a Python dict passed to\n",
"# OmegaConf.create keeps 'null' as an ordinary (truthy) string.\n",
"cfg = {\n",
"    'general': {\n",
"        'name': 'graph_dit',\n",
"        'wandb': 'disabled',\n",
"        'gpus': 1,\n",
"        'resume': None,\n",
"        'test_only': None,\n",
"        'sample_every_val': 2500,\n",
"        'samples_to_generate': 512,\n",
"        'samples_to_save': 3,\n",
"        'chains_to_save': 1,\n",
"        'log_every_steps': 50,\n",
"        'number_chain_steps': 8,\n",
"        'final_model_samples_to_generate': 10000,\n",
"        'final_model_samples_to_save': 20,\n",
"        'final_model_chains_to_save': 1,\n",
"        'enable_progress_bar': False,\n",
"        'save_model': True,\n",
"    },\n",
"    'model': {\n",
"        'type': 'discrete',\n",
"        'transition': 'marginal',\n",
"        'model': 'graph_dit',\n",
"        'diffusion_steps': 500,\n",
"        'diffusion_noise_schedule': 'cosine',\n",
"        'guide_scale': 2,\n",
"        'hidden_size': 1152,\n",
"        'depth': 6,\n",
"        'num_heads': 16,\n",
"        'mlp_ratio': 4,\n",
"        'drop_condition': 0.01,\n",
"        'lambda_train': [1, 10],  # node and edge training weight\n",
"        'ensure_connected': True,\n",
"    },\n",
"    'train': {\n",
"        'n_epochs': 10000,\n",
"        'batch_size': 1200,\n",
"        'lr': 0.0002,\n",
"        'clip_grad': None,\n",
"        'num_workers': 0,\n",
"        'weight_decay': 0,\n",
"        'seed': 0,\n",
"        'val_check_interval': None,\n",
"        'check_val_every_n_epoch': 1,\n",
"    },\n",
"    'dataset': {\n",
"        'datadir': 'data',\n",
"        'task_name': 'nasbench-201',\n",
"        'guidance_target': 'nasbench-201',\n",
"        'pin_memory': False,\n",
"    },\n",
"}"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the OmegaConf view in a separate variable so `cfg` stays the plain dict and\n",
"# this cell can be re-run safely.\n",
"datamodule_cfg = OmegaConf.create(cfg)\n",
"dm = task1(datamodule_cfg)"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_molecule_from_graph(nodes, edges):\n",
"    \"\"\"Build an RDKit ``RWMol`` with one atom per graph node and a single bond per edge.\n",
"\n",
"    Args:\n",
"        nodes: tensor of node labels in the 1..7 op encoding below; label 0 is skipped.\n",
"        edges: ``(2, E)`` tensor of ``(src, dst)`` node positions. Edges that touch a\n",
"            skipped node are ignored rather than raising ``KeyError``.\n",
"\n",
"    Returns:\n",
"        The molecule; atom symbols come from the global ``op_to_atom`` map.\n",
"    \"\"\"\n",
"    num_to_op = {\n",
"        1: 'nor_conv_1x1',\n",
"        2: 'nor_conv_3x3',\n",
"        3: 'avg_pool_3x3',\n",
"        4: 'skip_connect',\n",
"        5: 'output',\n",
"        6: 'none',\n",
"        7: 'input'\n",
"    }\n",
"    mol = Chem.RWMol()  # RWMol allows for building the molecule step by step\n",
"    atom_indices = {}  # node position -> RDKit atom index\n",
"    for pos, label in enumerate(nodes.reshape(-1).tolist()):\n",
"        if label == 0:\n",
"            continue\n",
"        atom_indices[pos] = mol.AddAtom(Chem.Atom(op_to_atom[num_to_op[label]]))\n",
"    for start, end in edges.t().tolist():\n",
"        if start in atom_indices and end in atom_indices:\n",
"            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)\n",
"    return mol"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example: |skip_connect~0|+|nor_conv_1x1~0|none~1|+|none~0|none~1|nor_conv_1x1~2|\n",
"# as an op-as-node graph (labels use the 1..7 encoding; 7 = input, 5 = output).\n",
"demo_nodes = torch.tensor([7, 4, 1, 6, 6, 6, 1, 5])\n",
"demo_edges = torch.tensor([[0, 0, 1, 0, 1, 2, 3, 4, 5, 6],\n",
"                           [1, 2, 3, 4, 5, 6, 6, 7, 7, 7]])\n",
"demo_mol = create_molecule_from_graph(demo_nodes, demo_edges)\n",
"demo_mol"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_train_smiles() returns a (train, test) pair of SMILES lists.\n",
"train_smiles, test_smiles = dm.get_train_smiles()\n",
"len(train_smiles), len(test_smiles)"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "graphdit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.19"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|