Compare commits

2 Commits: 99163a5150 ... 82299e5213

Commits in this range:
  82299e5213
  e04ad5fbe7
@@ -41,6 +41,6 @@ train:
     check_val_every_n_epoch: 1
 dataset:
     datadir: 'data/'
-    task_name: null
-    guidance_target: null
+    task_name: 'nasbench-201'
+    guidance_target: 'nasbench-201'
     pin_memory: False

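Note: this pins the dataset to NAS-Bench-201 rather than leaving the task unset. A minimal sketch of how these keys read back through OmegaConf (the surrounding Hydra schema is assumed, only the keys shown in the hunk are taken from the change):

from omegaconf import OmegaConf

# Sketch: the two updated fields as consumed downstream via cfg.dataset.*
cfg = OmegaConf.create("""
dataset:
  datadir: 'data/'
  task_name: 'nasbench-201'
  guidance_target: 'nasbench-201'
  pin_memory: False
""")
assert cfg.dataset.task_name == cfg.dataset.guidance_target == 'nasbench-201'
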
@@ -116,7 +116,7 @@ class AbstractDatasetInfos:
     def compute_input_output_dims(self, datamodule):
         example_batch = datamodule.example_batch()
         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
-        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
+        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
 
         self.input_dims = {'X': example_batch_x.size(1),
                            'E': example_batch_edge_attr.size(1),

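Note: F.one_hot requires num_classes to be strictly greater than the largest label it sees. The NAS-Bench-201 encoding introduced in this change uses edge labels up to 7 (the op-keyed bonds mapping further down), so 5 classes would fail on the first batch; 10 leaves headroom. A quick check:

import torch

# one_hot needs num_classes > max label; NAS-Bench-201 edge labels reach 7.
edge_attr = torch.tensor([1, 4, 7])  # example labels, chosen for illustration
torch.nn.functional.one_hot(edge_attr, num_classes=10)   # ok: shape (3, 10)
# torch.nn.functional.one_hot(edge_attr, num_classes=5)  # would raise RuntimeError
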
@@ -13,6 +13,7 @@ import torch
 import torch.nn.functional as F
 from rdkit import Chem, RDLogger
 from rdkit.Chem.rdchem import BondType as BT
+from rdkit.Chem import rdchem
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -24,6 +25,9 @@ import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
 
+import networkx as nx
+
+
 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
 
 op_to_atom = {
@@ -77,6 +81,7 @@ class DataModule(AbstractDataModule):
 
         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
         print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
@@ -89,8 +94,9 @@ class DataModule(AbstractDataModule):
         self.training_iterations = training_iterations
 
     def random_data_split(self, dataset):
-        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
-        labeled_len = len(dataset) - nan_count
+        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
+        # labeled_len = len(dataset) - nan_count
+        labeled_len = len(dataset)
         full_idx = list(range(labeled_len))
         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
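Note: with the NaN filtering gone, every architecture counts as labeled (NAS-Bench-201 entries all carry targets). The hunk only shows the first train_test_split call; how the remaining validation share is carved out is an assumption in this sketch:

from sklearn.model_selection import train_test_split

# Sketch of the 0.6/0.2/0.2 split arithmetic; the second split is assumed.
full_idx = list(range(15625))  # NAS-Bench-201 has 15,625 architectures
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio / (1 - test_ratio), random_state=42)
print(len(train_index), len(val_index), len(test_index))  # roughly 60/20/20
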
@@ -111,8 +117,87 @@ class DataModule(AbstractDataModule):
         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         return train_index, val_index, test_index, []
 
+    def parse_architecture_string(self, arch_str):
+            stages = arch_str.split('+')
+            nodes = ['input']
+            edges = []
+
+            for stage in stages:
+                operations = stage.strip('|').split('|')
+                for op in operations:
+                    operation, idx = op.split('~')
+                    idx = int(idx)
+                    edges.append((idx, len(nodes)))  # Add edge from idx to the new node
+                    nodes.append(operation)
+            nodes.append('output')  # Add the output node
+            return nodes, edges
+
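Note: for reference, a standalone copy of the parser run on a canonical NAS-Bench-201 architecture string (the string below is illustrative). One quirk worth knowing: the appended 'output' node never receives an incoming edge here.

def parse_architecture_string(arch_str):
    # Standalone copy of the method above, for illustration only.
    stages = arch_str.split('+')
    nodes = ['input']
    edges = []
    for stage in stages:
        for op in stage.strip('|').split('|'):
            operation, idx = op.split('~')
            edges.append((int(idx), len(nodes)))
            nodes.append(operation)
    nodes.append('output')
    return nodes, edges

arch = '|nor_conv_3x3~0|+|none~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|'
nodes, edges = parse_architecture_string(arch)
print(nodes)  # ['input', 'nor_conv_3x3', 'none', 'avg_pool_3x3', 'skip_connect', 'nor_conv_1x1', 'skip_connect', 'output']
print(edges)  # [(0, 1), (0, 2), (1, 3), (0, 4), (1, 5), (2, 6)]
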
+    # def create_molecule_from_graph(nodes, edges):
+    def create_molecule_from_graph(self, graph):
+        nodes = graph.x
+        edges = graph.edge_index
+        mol = Chem.RWMol()  # RWMol allows for building the molecule step by step
+        atom_indices = {}
+        num_to_op = {
+            1 :'nor_conv_1x1',
+            2 :'nor_conv_3x3',
+            3 :'avg_pool_3x3',
+            4 :'skip_connect',
+            5 :'output',
+            6 :'none',
+            7 :'input'
+        }
+
+        # Extract node operations from the data object
+
+        # Add atoms to the molecule
+        for i, op_tensor in enumerate(nodes):
+            op = op_tensor.item()
+            if op == 0: continue
+            op = num_to_op[op]
+            atom_symbol = op_to_atom[op]
+            atom = Chem.Atom(atom_symbol)
+            atom_idx = mol.AddAtom(atom)
+            atom_indices[i] = atom_idx
+
+        # Add bonds to the molecule
+        edge_number = edges.shape[1]
+        for i in range(edge_number):
+            start = edges[0, i].item()
+            end = edges[1, i].item()
+            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
+
+        return mol
+
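Note: a tiny end-to-end run of the RWMol pattern above. The op-to-atom symbols here are placeholders; the repository's actual op_to_atom mapping is defined elsewhere in this file and not shown in this hunk.

from rdkit import Chem
from rdkit.Chem import rdchem

# Build a 3-node chain with surrogate atoms and single bonds, then emit SMILES.
op_to_atom = {'input': 'Si', 'nor_conv_3x3': 'C', 'output': 'P'}  # assumed symbols
ops = ['input', 'nor_conv_3x3', 'output']
mol = Chem.RWMol()
atom_indices = {i: mol.AddAtom(Chem.Atom(op_to_atom[op])) for i, op in enumerate(ops)}
mol.AddBond(atom_indices[0], atom_indices[1], rdchem.BondType.SINGLE)
mol.AddBond(atom_indices[1], atom_indices[2], rdchem.BondType.SINGLE)
Chem.SanitizeMol(mol)
print(Chem.MolToSmiles(mol))  # e.g. '[SiH3]CP'; the exact form may vary
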
+    def arch_str_to_smiles(self, arch_str):
+        nodes, edges = self.parse_architecture_string(arch_str)
+        mol = self.create_molecule_from_graph(nodes, edges)
+        smiles = Chem.MolToSmiles(mol)
+        return smiles
+
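Note: as committed, arch_str_to_smiles still passes (nodes, edges) to create_molecule_from_graph, which now takes a single PyG graph object, so this path would raise a TypeError if called. A consistent standalone helper matching the older commented-out signature might look like this (hypothetical, not in the repository):

from rdkit import Chem
from rdkit.Chem import rdchem

def create_molecule_from_node_edge_lists(nodes, edges, op_to_atom):
    """Hypothetical helper taking (nodes, edges) lists, matching the older
    commented-out signature; op_to_atom maps op names to atom symbols."""
    mol = Chem.RWMol()
    atom_indices = {}
    for i, op in enumerate(nodes):
        atom_indices[i] = mol.AddAtom(Chem.Atom(op_to_atom[op]))
    for start, end in edges:
        mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
    return mol
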
+    def get_train_smiles(self):
+        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
+        train_smiles = []
+        test_smiles = []
+
+        for graph in self.train_dataset:
+            # print(f'idx={idx}')
+            # graph = self.train_dataset[idx]
+            print(graph.x)
+            print(graph.edge_index)
+            print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
+            mol = self.create_molecule_from_graph(graph)
+            train_smiles.append(Chem.MolToSmiles(mol))
+
+        # for idx in self.test_index:
+        for graph in self.test_dataset:
+            # graph = self.dataset[idx]
+            # mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+            mol = self.create_molecule_from_graph(graph)
+            test_smiles.append(Chem.MolToSmiles(mol))
+
+        # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
+        # test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
+        return train_smiles, test_smiles
 
     def get_data_split(self):
         raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
@@ -129,161 +214,8 @@ class DataModule(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader
 
-def graphs_to_json(graphs, filename):
-    bonds = {
-        'nor_conv_1x1': 1,
-        'nor_conv_3x3': 2,
-        'avg_pool_3x3': 3,
-        'skip_connect': 4,
-        'input': 7,
-        'output': 5,
-        'none': 6
-    }
-
-    source_name = "nas-bench-201"
-    num_graph = len(graphs)
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
-    bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
-
-    n_atom_list = []
-    n_bond_list = []
-    # graphs = [(adj_matrix, ops), ...]
-    for graph in graphs:
-        ops = graph[1]
-        adj = graph[0]
-        n_atom = len(ops)
-        n_bond = len(ops)
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for op in ops:
-            symbol = op_to_atom[op]
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
-                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
-                # print('symbol', symbol)
-                # print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
-                # print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
-                try:
-                    valencies[int(pt.GetDefaultValence(symbol))] += 1
-                except:
-                    print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
-        # print(n_atom)
-        for i in range(n_atom):
-            for j in range(n_atom):
-                if i == j or adj[i][j] == 0:
-                    continue
-                start_atom, end_atom = i, j
-                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
-                    continue
-                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
-                    continue
-                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
-                    continue
-
-                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
-                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
-                bond_index = bonds[ops[end_atom]]
-                bond_count_list[bond_index] += 2
-
-                # print(start_index, end_index, bond_index)
-
-                transition_E[start_index, end_index, bond_index] += 2
-                transition_E[end_index, start_index, bond_index] += 2
-                transition_E_temp[start_index, end_index, bond_index] += 2
-                transition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        print(bond_count_list)
-        cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
-        # print(f'cur_tot_bond={cur_tot_bond}')
-        # find non-zero element in cur_tot_bond
-        # for i in range(118):
-        #     for j in range(118):
-        #         if cur_tot_bond[i][j] != 0:
-        #             print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
-        # n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
-        # print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
-        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
-        # find non-zero element in transition_E
-        # for i in range(118):
-        #     for j in range(118):
-        #         if transition_E[i][j][0] != 0:
-        #             print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
-        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(transition_E, axis=-1) == 0
-    for i in range(118):
-        for j in range(118):
-            if no_edge[i][j] == False:
-                print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
-    # print(f'no_edge: {no_edge}')
-    first_elt = transition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    transition_E[:, :, 0] = first_elt
-
-    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
-
-    # find non-zero element in transition_E again
-    for i in range(118):
-        for j in range(118):
-            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
-                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
-
-    meta_dict = {
-        'source': 'nasbench-201',
-        'num_graph': num_graph,
-        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
-        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
-        'transition_E': transition_E.tolist(),
-    }
-
-    with open(f'{filename}.meta.json', 'w') as f:
-        json.dump(meta_dict, f)
-    return meta_dict
 
 class DataModule_original(AbstractDataModule):
     def __init__(self, cfg):
@@ -412,7 +344,7 @@ def graphs_to_json(graphs, filename):
     bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
     bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
     valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
+    transition_E = np.zeros((118, 118, 8))
 
     n_atom_list = []
     n_bond_list = []
@@ -445,7 +377,7 @@ def graphs_to_json(graphs, filename):
                     valencies[int(pt.GetDefaultValence(symbol))] += 1
                 except:
                     print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
+        transition_E_temp = np.zeros((118, 118, 8))
         # print(n_atom)
         for i in range(n_atom):
             for j in range(n_atom):
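Note: the third axis of transition_E indexes edge types, with slot 0 reserved for "no edge". NAS-Bench-201 edge labels run 1..7 (the op-keyed bonds mapping in this file), so 8 slots are needed where the four molecular bond types fit in 5. A minimal sketch:

import numpy as np

# Edge-type axis: index 0 = no edge; indices 1..7 = NAS-Bench-201 op labels.
# With only 5 slots, a label of 5, 6, or 7 would index out of bounds.
transition_E = np.zeros((118, 118, 8))
start_index, end_index, bond_index = 10, 20, 7  # illustrative indices
transition_E[start_index, end_index, bond_index] += 2
transition_E[end_index, start_index, bond_index] += 2
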
@@ -542,6 +474,102 @@ def graphs_to_json(graphs, filename):
     with open(f'{filename}.meta.json', 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict
+class Dataset(InMemoryDataset):
+    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
+        self.target_prop = target_prop
+        source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.source = source
+        self.api = API(source)  # Initialize NAS-Bench-201 API
+        print('API loaded')
+        super().__init__(root, transform, pre_transform, pre_filter)
+        print('Dataset initialized')
+        print(self.processed_paths[0])
+        self.data, self.slices = torch.load(self.processed_paths[0])
+
+    @property
+    def raw_file_names(self):
+        return []  # NAS-Bench-201 data is loaded via the API, no raw files needed
+
+    @property
+    def processed_file_names(self):
+        return [f'{self.source}.pt']
+
+    def process(self):
+        def parse_architecture_string(arch_str):
+            stages = arch_str.split('+')
+            nodes = ['input']
+            edges = []
+
+            for stage in stages:
+                operations = stage.strip('|').split('|')
+                for op in operations:
+                    operation, idx = op.split('~')
+                    idx = int(idx)
+                    edges.append((idx, len(nodes)))  # Add edge from idx to the new node
+                    nodes.append(operation)
+            nodes.append('output')  # Add the output node
+            return nodes, edges
+
+        def create_graph(nodes, edges):
+            G = nx.DiGraph()
+            for i, node in enumerate(nodes):
+                G.add_node(i, label=node)
+            G.add_edges_from(edges)
+            return G
+
+        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
+            nodes, edges = parse_architecture_string(arch_str)
+
+            node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
+            assert 0 not in node_labels, f'Invalid node label: {node_labels}'
+            x = torch.LongTensor(node_labels)
+            print(f'in initialize Dataset, arch_to_Graph x={x}')
+
+            edges_list = [(start, end) for start, end in edges]
+            edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
+            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
+            edge_type = torch.tensor(edge_type, dtype=torch.long)
+            edge_attr = edge_type.view(-1, 1)
+
+            if target3 is not None:
+                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
+            elif target2 is not None:
+                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
+            else:
+                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
+
+            print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
+            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
+            return data, nodes
+
+        bonds = {
+            'nor_conv_1x1': 1,
+            'nor_conv_3x3': 2,
+            'avg_pool_3x3': 3,
+            'skip_connect': 4,
+            'output': 5,
+            'none': 6,
+            'input': 7
+        }
+
+        # Prepare to process NAS-Bench-201 data
+        data_list = []
+        len_data = len(self.api)  # Number of architectures
+        with tqdm(total=len_data) as pbar:
+            for arch_index in range(len_data):
+                arch_info = self.api.query_meta_info_by_index(arch_index)
+                arch_str = arch_info.arch_str
+                sa = np.random.rand()  # Placeholder for synthetic accessibility
+                sc = np.random.rand()  # Placeholder for substructure count
+                target = np.random.rand()  # Placeholder for target value
+                target2 = np.random.rand()  # Placeholder for second target value
+                target3 = np.random.rand()  # Placeholder for third target value
+
+                data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
+                data_list.append(data)
+                pbar.update(1)
+
+        torch.save(self.collate(data_list), self.processed_paths[0])
 
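Note: to make the encoding concrete, here is a hedged sketch of the Data object that process() would build for the example architecture parsed earlier; the y targets follow the three-target branch of arch_to_graph and are placeholders, just like the random values in the loop above.

import torch
from torch_geometric.data import Data

# Node labels via the op-keyed bonds dict; edge type = label of the end node.
bonds = {'nor_conv_1x1': 1, 'nor_conv_3x3': 2, 'avg_pool_3x3': 3,
         'skip_connect': 4, 'output': 5, 'none': 6, 'input': 7}
nodes = ['input', 'nor_conv_3x3', 'none', 'avg_pool_3x3',
         'skip_connect', 'nor_conv_1x1', 'skip_connect', 'output']
edges = [(0, 1), (0, 2), (1, 3), (0, 4), (1, 5), (2, 6)]
x = torch.LongTensor([bonds[n] for n in nodes])
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor([bonds[nodes[end]] for _, end in edges], dtype=torch.long).view(-1, 1)
y = torch.tensor([[0.5, 0.5, 0.5]])  # placeholder sa, sc, target
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
print(data)  # Data(x=[8], edge_index=[2, 6], edge_attr=[6, 1], y=[1, 3])
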
 class Dataset_origin(InMemoryDataset):
     def __init__(self, source, root, target_prop=None,
@@ -671,7 +699,7 @@ class DataInfos(AbstractDatasetInfos):
         length = 15625
         ops_type = {}
         len_ops = set()
-        api = API('../NAS-Bench-201-v1_0-e61699.pth')
+        api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
         for i in range(length):
             arch_info = api.query_meta_info_by_index(i)
             nodes, edges = parse_architecture_string(arch_info.arch_str)

@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils
 
 class Graph_DiT(pl.LightningModule):
-    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    def __init__(self, cfg, dataset_infos, visualization_tools):
 
         super().__init__()
-        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
+        # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
 
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp = SumExceptBatchMetric()
         self.test_y_collection = []
 
-        self.train_metrics = train_metrics
-        self.sampling_metrics = sampling_metrics
+        # self.train_metrics = train_metrics
+        # self.sampling_metrics = sampling_metrics
 
         self.visualization_tools = visualization_tools
         self.max_n_nodes = dataset_infos.max_n_nodes
 
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
         self.val_E_kl.reset()
         self.val_X_logp.reset()
         self.val_E_logp.reset()
-        self.sampling_metrics.reset()
+        # self.sampling_metrics.reset()
         self.val_y_collection = []
 
     @torch.no_grad()
 
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
                 samples_left_to_generate -= to_generate
                 chains_left_to_save -= chains_save
 
-            print(f"Computing sampling metrics", ' ...')
-            valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-            print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+            # print(f"Computing sampling metrics", ' ...')
+            # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+            # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
 
             current_path = os.getcwd()
             result_path = os.path.join(current_path,
                                        f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
-            self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-            self.sampling_metrics.reset()
+            # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
+            # self.sampling_metrics.reset()
 
     def on_test_epoch_start(self) -> None:
         print("Starting test...")

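Note: commenting the metrics out line by line works, but the constructor, save_hyperparameters, and every call site now have to stay in sync by hand. A sketch of an alternative design (not the committed code) that makes the metrics optional and guards each use:

import pytorch_lightning as pl

class GraphDiTSketch(pl.LightningModule):
    # Hypothetical variant: metrics default to None instead of being removed.
    def __init__(self, cfg=None, dataset_infos=None, visualization_tools=None,
                 train_metrics=None, sampling_metrics=None):
        super().__init__()
        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

    def reset_sampling_metrics(self):
        # Guarded call: a no-op when sampling metrics are disabled.
        if self.sampling_metrics is not None:
            self.sampling_metrics.reset()
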
@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
     # Fetch path to this file to get base path
     current_path = os.path.dirname(os.path.realpath(__file__))
     root_dir = current_path.split("outputs")[0]
-
     resume_path = os.path.join(root_dir, cfg.general.resume)
 
     if cfg.model.type == "discrete":
 
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
-    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_smiles, reference_smiles = datamodule.get_train_smiles()
 
     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
-    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
 
-    sampling_metrics = SamplingMolecularMetrics(
-        dataset_infos, train_smiles, reference_smiles
-    )
+    # sampling_metrics = SamplingMolecularMetrics(
+    #     dataset_infos, train_smiles, reference_smiles
+    # )
     visualization_tools = MolecularVisualization(dataset_infos)
 
     model_kwargs = {
         "dataset_infos": dataset_infos,
-        "train_metrics": train_metrics,
-        "sampling_metrics": sampling_metrics,
+        # "train_metrics": train_metrics,
+        # "sampling_metrics": sampling_metrics,
         "visualization_tools": visualization_tools,
     }
 
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
-        accelerator="gpu"
-        if torch.cuda.is_available() and cfg.general.gpus > 0
-        else "cpu",
+        # accelerator="gpu"
+        # if torch.cuda.is_available() and cfg.general.gpus > 0
+        # else "cpu",
+        accelerator="cpu",
         devices=cfg.general.gpus
         if torch.cuda.is_available() and cfg.general.gpus > 0
        else None,

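Note: as committed, the accelerator is forced to "cpu" while devices is still derived from the GPU count whenever CUDA is available; Lightning reads that value as a number of CPU processes. If the intent is simply to force CPU execution, a plain configuration (a sketch of the assumed intent, not the committed code) is:

import pytorch_lightning as pl

# Force CPU regardless of CUDA availability; devices=1 runs a single process.
trainer = pl.Trainer(
    gradient_clip_val=1.0,  # placeholder for cfg.train.clip_grad
    accelerator="cpu",
    devices=1,
)
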
File diff suppressed because one or more lines are too long