diff --git a/configs/config.yaml b/configs/config.yaml
index e3fade7..234f679 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -41,6 +41,6 @@ train:
     check_val_every_n_epoch: 1
 dataset:
   datadir: 'data/'
-  task_name: null
-  guidance_target: null
+  task_name: 'nasbench-201'
+  guidance_target: 'nasbench-201'
   pin_memory: False
diff --git a/graph_dit/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py
index c8e82c5..63f1ea5 100644
--- a/graph_dit/datasets/abstract_dataset.py
+++ b/graph_dit/datasets/abstract_dataset.py
@@ -116,7 +116,7 @@ class AbstractDatasetInfos:
     def compute_input_output_dims(self, datamodule):
         example_batch = datamodule.example_batch()
         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
-        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
+        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()

         self.input_dims = {'X': example_batch_x.size(1),
                            'E': example_batch_edge_attr.size(1),
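The `num_classes` bump matters because the edge labels reuse the `bonds` mapping defined later in `dataset.py`, whose values run up to 7 (`'input'`). A minimal check, assuming integer edge labels as in this patch:

```python
import torch

# Edge labels follow the `bonds` mapping in dataset.py: values run from 1
# ('nor_conv_1x1') up to 7 ('input'). one_hot requires num_classes to exceed
# the largest label, so num_classes=5 raises for labels 5, 6, or 7;
# num_classes=10 leaves headroom.
edge_attr = torch.tensor([1, 4, 7])  # example labels, including 'input' (7)
ok = torch.nn.functional.one_hot(edge_attr, num_classes=10)  # shape (3, 10)
try:
    torch.nn.functional.one_hot(edge_attr, num_classes=5)    # fails: 7 >= 5
except RuntimeError as e:
    print(e)
```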
diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index c3c0fc9..225969d 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -81,6 +81,7 @@ class DataModule(AbstractDataModule):
         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
         print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
@@ -93,8 +94,9 @@ class DataModule(AbstractDataModule):
         self.training_iterations = training_iterations

     def random_data_split(self, dataset):
-        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
-        labeled_len = len(dataset) - nan_count
+        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
+        # labeled_len = len(dataset) - nan_count
+        labeled_len = len(dataset)
         full_idx = list(range(labeled_len))
         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
@@ -115,7 +117,7 @@ class DataModule(AbstractDataModule):
         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         return train_index, val_index, test_index, []

-    def parse_architecture_string(arch_str):
+    def parse_architecture_string(self, arch_str):
         stages = arch_str.split('+')
         nodes = ['input']
         edges = []
@@ -130,19 +132,39 @@ class DataModule(AbstractDataModule):
         nodes.append('output')  # Add the output node
         return nodes, edges

-    def create_molecule_from_graph(nodes, edges):
+    # def create_molecule_from_graph(nodes, edges):
+    def create_molecule_from_graph(self, graph):
+        nodes = graph.x
+        edges = graph.edge_index
         mol = Chem.RWMol()  # RWMol allows for building the molecule step by step
         atom_indices = {}
-
+        num_to_op = {
+            1: 'nor_conv_1x1',
+            2: 'nor_conv_3x3',
+            3: 'avg_pool_3x3',
+            4: 'skip_connect',
+            5: 'output',
+            6: 'none',
+            7: 'input'
+        }
+
+        # Extract node operations from the data object
+
         # Add atoms to the molecule
-        for i, node in enumerate(nodes):
-            atom_symbol = op_to_atom[node]
+        for i, op_tensor in enumerate(nodes):
+            op = op_tensor.item()
+            if op == 0: continue
+            op = num_to_op[op]
+            atom_symbol = op_to_atom[op]
             atom = Chem.Atom(atom_symbol)
             atom_idx = mol.AddAtom(atom)
             atom_indices[i] = atom_idx

         # Add bonds to the molecule
-        for start, end in edges:
+        edge_number = edges.shape[1]
+        for i in range(edge_number):
+            start = edges[0, i].item()
+            end = edges[1, i].item()
             mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
         return mol

@@ -154,30 +176,23 @@ class DataModule(AbstractDataModule):
         return smiles

     def get_train_smiles(self):
-        # raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
-        # train_arch_strs = []
-        # test_arch_strs = []
-
-        # for idx in self.train_index:
-        #     arch_info = self.train_dataset[idx]
-        #     arch_str = arch_info.arch_str
-        #     train_arch_strs.append(arch_str)
-        # for idx in self.test_index:
-        #     arch_info = self.train_dataset[idx]
-        #     arch_str = arch_info.arch_str
-        #     test_arch_strs.append(arch_str)
-
         train_smiles = []
         test_smiles = []
-        for idx in self.train_index:
-            graph = self.train_dataset[idx]
-            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+        for graph in self.train_dataset:
+            # print(f'idx={idx}')
+            # graph = self.train_dataset[idx]
+            print(graph.x)
+            print(graph.edge_index)
+            print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
+            mol = self.create_molecule_from_graph(graph)
             train_smiles.append(Chem.MolToSmiles(mol))
-        for idx in self.test_index:
-            graph = self.train_dataset[idx]
-            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+        # for idx in self.test_index:
+        for graph in self.test_dataset:
+            # graph = self.dataset[idx]
+            # mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+            mol = self.create_molecule_from_graph(graph)
            test_smiles.append(Chem.MolToSmiles(mol))

         # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
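For context, NAS-Bench-201 encodes a cell as stage-separated `|op~source|` tokens. A standalone sketch of the parsing the patched `parse_architecture_string` performs; the body between the two hunks above is elided, so the exact edge bookkeeping is an assumption (here, `~k` is read as an index into the running node list):

```python
# Sketch under the assumption of the standard NAS-Bench-201 cell encoding:
# stages joined by '+', each stage holding '|op~source_index|' tokens.
def parse_architecture_string(arch_str):
    nodes, edges = ['input'], []
    for stage in arch_str.split('+'):
        for op_str in stage.strip('|').split('|'):
            op, source = op_str.split('~')
            nodes.append(op)
            edges.append((int(source), len(nodes) - 1))
    nodes.append('output')
    return nodes, edges

arch = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'
nodes, edges = parse_architecture_string(arch)
print(nodes)  # ['input', 'nor_conv_3x3', 'skip_connect', ..., 'output']
print(edges)  # [(0, 1), (0, 2), (1, 3), ...]
```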
@@ -199,161 +214,8 @@ class DataModule(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader

-def graphs_to_json(graphs, filename):
-    bonds = {
-        'nor_conv_1x1': 1,
-        'nor_conv_3x3': 2,
-        'avg_pool_3x3': 3,
-        'skip_connect': 4,
-        'input': 7,
-        'output': 5,
-        'none': 6
-    }
-
-    source_name = "nas-bench-201"
-    num_graph = len(graphs)
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
-    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
-
-    n_atom_list = []
-    n_bond_list = []
-    # graphs = [(adj_matrix, ops), ...]
-    for graph in graphs:
-        ops = graph[1]
-        adj = graph[0]
-        n_atom = len(ops)
-        n_bond = len(ops)
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for op in ops:
-            symbol = op_to_atom[op]
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
-                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
-                # print('symbol', symbol)
-                # print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
-                # print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
-            try:
-                valencies[int(pt.GetDefaultValence(symbol))] += 1
-            except:
-                print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
-        # print(n_atom)
-        for i in range(n_atom):
-            for j in range(n_atom):
-                if i == j or adj[i][j] == 0:
-                    continue
-                start_atom, end_atom = i, j
-                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
-                    continue
-                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
-                    continue
-                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
-                    continue
-
-                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
-                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
-                bond_index = bonds[ops[end_atom]]
-                bond_count_list[bond_index] += 2
-
-                # print(start_index, end_index, bond_index)
-
-                transition_E[start_index, end_index, bond_index] += 2
-                transition_E[end_index, start_index, bond_index] += 2
-                transition_E_temp[start_index, end_index, bond_index] += 2
-                transition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        print(bond_count_list)
-        cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
-        # print(f'cur_tot_bond={cur_tot_bond}')
-        # find non-zero element in cur_tot_bond
-        # for i in range(118):
-        #     for j in range(118):
-        #         if cur_tot_bond[i][j] != 0:
-        #             print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
-        # n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
-        # print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
-        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
-        # find non-zero element in transition_E
-        # for i in range(118):
-        #     for j in range(118):
-        #         if transition_E[i][j][0] != 0:
-        #             print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
-        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(transition_E, axis=-1) == 0
-    for i in range(118):
-        for j in range(118):
-            if no_edge[i][j] == False:
-                print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
-    # print(f'no_edge: {no_edge}')
-    first_elt = transition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    transition_E[:, :, 0] = first_elt
-
-    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
-
-    # find non-zero element in transition_E again
-    for i in range(118):
-        for j in range(118):
-            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
-                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
-
-    meta_dict = {
-        'source': 'nasbench-201',
-        'num_graph': num_graph,
-        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
-        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
-        'transition_E': transition_E.tolist(),
-    }
-
-    with open(f'{filename}.meta.json', 'w') as f:
-        json.dump(meta_dict, f)
-    return meta_dict

 class DataModule_original(AbstractDataModule):
     def __init__(self, cfg):
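The same label-range argument drives the tensor resize in the next hunk: channel 0 of the last axis encodes "no edge" and the `bonds` mapping assigns edge types 1 through 7, so a depth of 5 overflows as soon as an `'output'`, `'none'`, or `'input'` edge is counted. A small sketch:

```python
import numpy as np

# Channel 0 is "no edge"; edge types occupy channels 1..7, so the last axis
# needs 1 + max(bonds.values()) = 8 slots. With depth 5, indexing channel 7
# would be out of bounds.
bonds = {'nor_conv_1x1': 1, 'nor_conv_3x3': 2, 'avg_pool_3x3': 3,
         'skip_connect': 4, 'output': 5, 'none': 6, 'input': 7}
transition_E = np.zeros((118, 118, 1 + max(bonds.values())))  # (118, 118, 8)
transition_E[10, 12, bonds['input']] += 2                     # safe with 8 channels
```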
@@ -482,7 +344,7 @@ def graphs_to_json(graphs, filename):
     bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
     bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
     valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
+    transition_E = np.zeros((118, 118, 8))

     n_atom_list = []
     n_bond_list = []
@@ -515,7 +377,7 @@ def graphs_to_json(graphs, filename):
             valencies[int(pt.GetDefaultValence(symbol))] += 1
         except:
             print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
+        transition_E_temp = np.zeros((118, 118, 8))
         # print(n_atom)
         for i in range(n_atom):
             for j in range(n_atom):
@@ -612,14 +474,16 @@ def graphs_to_json(graphs, filename):
     with open(f'{filename}.meta.json', 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict
-
 class Dataset(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
         source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
         self.api = API(source)  # Initialize NAS-Bench-201 API
+        print('API loaded')
         super().__init__(root, transform, pre_transform, pre_filter)
+        print('Dataset initialized')
+        print(self.processed_paths[0])
         self.data, self.slices = torch.load(self.processed_paths[0])

     @property
@@ -655,8 +519,11 @@ class Dataset(InMemoryDataset):
         def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
             nodes, edges = parse_architecture_string(arch_str)
+
             node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
+            assert 0 not in node_labels, f'Invalid node label: {node_labels}'
             x = torch.LongTensor(node_labels)
+            print(f'in initialize Dataset, arch_to_Graph x={x}')

             edges_list = [(start, end) for start, end in edges]
             edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
@@ -671,6 +538,7 @@ class Dataset(InMemoryDataset):
             else:
                 y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

+            print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
             return data, nodes

@@ -679,9 +547,9 @@ class Dataset(InMemoryDataset):
             'nor_conv_3x3': 2,
             'avg_pool_3x3': 3,
             'skip_connect': 4,
-            'input': 7,
             'output': 5,
-            'none': 6
+            'none': 6,
+            'input': 7
         }

         # Prepare to process NAS-Bench-201 data
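To make the encoding concrete, here is an illustrative construction of the `Data` object that `arch_to_graph` returns, assuming the `bonds` mapping above; the node list, edges, and `y` values are hypothetical placeholders, not taken from the benchmark:

```python
import torch
from torch_geometric.data import Data

bonds = {'nor_conv_1x1': 1, 'nor_conv_3x3': 2, 'avg_pool_3x3': 3,
         'skip_connect': 4, 'output': 5, 'none': 6, 'input': 7}
nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'output']   # toy cell
edges = [(0, 1), (1, 2), (2, 3)]

x = torch.LongTensor([bonds[n] for n in nodes])               # node labels, never 0
edge_index = torch.tensor(edges, dtype=torch.long).t()        # shape (2, num_edges)
edge_attr = torch.LongTensor([bonds[nodes[end]] for _, end in edges])  # end-node type as edge type
y = torch.tensor([[0.5, 0.5, 90.0]])                          # placeholder (sa, sc, target)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
print(data)  # Data(x=[4], edge_index=[2, 3], edge_attr=[3], y=[1, 3])
```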
diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 5595c1c..9d26ecf 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils

 class Graph_DiT(pl.LightningModule):
-    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    def __init__(self, cfg, dataset_infos, visualization_tools):
+
         super().__init__()
-        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
+        # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])

         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp = SumExceptBatchMetric()
         self.test_y_collection = []

-        self.train_metrics = train_metrics
-        self.sampling_metrics = sampling_metrics
+        # self.train_metrics = train_metrics
+        # self.sampling_metrics = sampling_metrics

         self.visualization_tools = visualization_tools
         self.max_n_nodes = dataset_infos.max_n_nodes
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
         self.val_E_kl.reset()
         self.val_X_logp.reset()
         self.val_E_logp.reset()
-        self.sampling_metrics.reset()
+        # self.sampling_metrics.reset()
         self.val_y_collection = []

     @torch.no_grad()
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
             samples_left_to_generate -= to_generate
             chains_left_to_save -= chains_save

-        print(f"Computing sampling metrics", ' ...')
-        valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+        # print(f"Computing sampling metrics", ' ...')
+        # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+        # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+
         current_path = os.getcwd()
         result_path = os.path.join(current_path, f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
-        self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-        self.sampling_metrics.reset()
+        # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
+        # self.sampling_metrics.reset()

     def on_test_epoch_start(self) -> None:
         print("Starting test...")
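The metrics plumbing above is commented out rather than removed. A sketch of one alternative, not part of this patch: accept the collaborators as optional and guard each use, so the molecule-specific metrics can be re-enabled without touching every call site:

```python
import pytorch_lightning as pl

# Hypothetical refactor: train_metrics/sampling_metrics keep their original
# meaning; the None default and the guard are the only new behavior.
class Graph_DiT(pl.LightningModule):
    def __init__(self, cfg, dataset_infos, visualization_tools,
                 train_metrics=None, sampling_metrics=None):
        super().__init__()
        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

    def maybe_reset_sampling_metrics(self):
        # No-op when metrics are disabled for the NAS-Bench-201 setting.
        if self.sampling_metrics is not None:
            self.sampling_metrics.reset()
```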
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 2dcd97a..d7fbe14 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
     # Fetch path to this file to get base path
     current_path = os.path.dirname(os.path.realpath(__file__))
     root_dir = current_path.split("outputs")[0]
-
     resume_path = os.path.join(root_dir, cfg.general.resume)

     if cfg.model.type == "discrete":
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
-    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_smiles, reference_smiles = datamodule.get_train_smiles()

     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
-    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)

-    sampling_metrics = SamplingMolecularMetrics(
-        dataset_infos, train_smiles, reference_smiles
-    )
+    # sampling_metrics = SamplingMolecularMetrics(
+    #     dataset_infos, train_smiles, reference_smiles
+    # )
     visualization_tools = MolecularVisualization(dataset_infos)

     model_kwargs = {
         "dataset_infos": dataset_infos,
-        "train_metrics": train_metrics,
-        "sampling_metrics": sampling_metrics,
+        # "train_metrics": train_metrics,
+        # "sampling_metrics": sampling_metrics,
         "visualization_tools": visualization_tools,
     }
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
-        accelerator="gpu"
-        if torch.cuda.is_available() and cfg.general.gpus > 0
-        else "cpu",
+        # accelerator="gpu"
+        # if torch.cuda.is_available() and cfg.general.gpus > 0
+        # else "cpu",
+        accelerator="cpu",
         devices=cfg.general.gpus
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
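Not part of the patch: if the CPU pin is meant to be temporary, a config-driven sketch (reusing the existing `cfg.general.gpus` and `cfg.train.clip_grad` fields) avoids editing code to get back onto GPU:

```python
import torch
from pytorch_lightning import Trainer

def build_trainer(cfg):
    # Pick the accelerator from config instead of hard-coding "cpu".
    use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0
    return Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        accelerator="gpu" if use_gpu else "cpu",
        # Depending on the Lightning version, CPU runs may want a device
        # count (or "auto") here rather than None.
        devices=cfg.general.gpus if use_gpu else 1,
    )
```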