Compare commits
	
		
			5 Commits
		
	
	
		
			82299e5213
			...
			a7f7010da7
		
	
| Author | SHA1 | Date | |
|---|---|---|---|
| | a7f7010da7 | | |
| | 14186fa97f | | |
| | a222c514d9 | | |
| | 062a27b83f | | |
| | 0c7c525680 | | |
| @@ -127,4 +127,19 @@ class AbstractDatasetInfos: | |||||||
|         print('input dims') |         print('input dims') | ||||||
|         print(self.input_dims) |         print(self.input_dims) | ||||||
|         print('output dims') |         print('output dims') | ||||||
|  |         print(self.output_dims) | ||||||
|  |     def compute_graph_input_output_dims(self, datamodule): | ||||||
|  |         example_batch = datamodule.example_batch() | ||||||
|  |         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index] | ||||||
|  |         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() | ||||||
|  |  | ||||||
|  |         self.input_dims = {'X': example_batch_x.size(1), | ||||||
|  |                            'E': example_batch_edge_attr.size(1), | ||||||
|  |                            'y': example_batch['y'].size(1)} | ||||||
|  |         self.output_dims = {'X': example_batch_x.size(1), | ||||||
|  |                             'E': example_batch_edge_attr.size(1), | ||||||
|  |                             'y': example_batch['y'].size(1)} | ||||||
|  |         print('input dims') | ||||||
|  |         print(self.input_dims) | ||||||
|  |         print('output dims') | ||||||
|         print(self.output_dims) |         print(self.output_dims) | ||||||
| @@ -50,12 +50,12 @@ class DataModule(AbstractDataModule): | |||||||
|  |  | ||||||
|     def prepare_data(self) -> None: |     def prepare_data(self) -> None: | ||||||
|         target = getattr(self.cfg.dataset, 'guidance_target', None) |         target = getattr(self.cfg.dataset, 'guidance_target', None) | ||||||
|         print("target", target) |         print("target", target) # nasbench-201 | ||||||
|         # try: |         # try: | ||||||
|         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] |         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||||
|         # except NameError: |         # except NameError: | ||||||
|         # base_path = pathlib.Path(os.getcwd()).parent[2] |         # base_path = pathlib.Path(os.getcwd()).parent[2] | ||||||
|         base_path = '/home/stud/hanzhang/Graph-Dit' |         base_path = '/home/stud/hanzhang/nasbenchDiT' | ||||||
|         root_path = os.path.join(base_path, self.datadir) |         root_path = os.path.join(base_path, self.datadir) | ||||||
|         self.root_path = root_path |         self.root_path = root_path | ||||||
|  |  | ||||||
| @@ -68,13 +68,16 @@ class DataModule(AbstractDataModule): | |||||||
|         # Dataset has target property, root path, and transform |         # Dataset has target property, root path, and transform | ||||||
|         source = './NAS-Bench-201-v1_1-096897.pth' |         source = './NAS-Bench-201-v1_1-096897.pth' | ||||||
|         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) |         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) | ||||||
|  |         self.dataset = dataset | ||||||
|  |         self.api = dataset.api | ||||||
|  |  | ||||||
|         # if len(self.task.split('-')) == 2: |         # if len(self.task.split('-')) == 2: | ||||||
|         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) |         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||||
|         # else: |         # else: | ||||||
|         train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) |         train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) | ||||||
|  |  | ||||||
|         self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index |         self.train_index, self.val_index, self.test_index, self.unlabeled_index = ( | ||||||
|  |             train_index, val_index, test_index, unlabeled_index) | ||||||
|         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) |         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) | ||||||
|         if len(unlabeled_index) > 0: |         if len(unlabeled_index) > 0: | ||||||
|             train_index = torch.cat([train_index, unlabeled_index], dim=0) |             train_index = torch.cat([train_index, unlabeled_index], dim=0) | ||||||
| @@ -175,6 +178,27 @@ class DataModule(AbstractDataModule): | |||||||
|         smiles = Chem.MolToSmiles(mol) |         smiles = Chem.MolToSmiles(mol) | ||||||
|         return smiles |         return smiles | ||||||
|  |  | ||||||
|  |     def get_train_graphs(self): | ||||||
|  |         train_graphs = [] | ||||||
|  |         test_graphs = [] | ||||||
|  |         for graph in self.train_dataset: | ||||||
|  |             train_graphs.append(graph) | ||||||
|  |         for graph in self.test_dataset: | ||||||
|  |             test_graphs.append(graph) | ||||||
|  |         return train_graphs, test_graphs | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     # def get_train_smiles(self): | ||||||
|  |     #     filename = f'{self.task}.csv.gz' | ||||||
|  |     #     df = pd.read_csv(f'{self.root_path}/raw/{filename}') | ||||||
|  |     #     df_test = df.iloc[self.test_index] | ||||||
|  |     #     df = df.iloc[self.train_index] | ||||||
|  |     #     smiles_list = df['smiles'].tolist() | ||||||
|  |     #     smiles_list_test = df_test['smiles'].tolist() | ||||||
|  |     #     smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list] | ||||||
|  |     #     smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test] | ||||||
|  |     #     return smiles_list, smiles_list_test | ||||||
|  |  | ||||||
|     def get_train_smiles(self): |     def get_train_smiles(self): | ||||||
|         train_smiles = []    |         train_smiles = []    | ||||||
|         test_smiles = [] |         test_smiles = [] | ||||||
| @@ -477,14 +501,17 @@ def graphs_to_json(graphs, filename): | |||||||
| class Dataset(InMemoryDataset): | class Dataset(InMemoryDataset): | ||||||
|     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): |     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): | ||||||
|         self.target_prop = target_prop |         self.target_prop = target_prop | ||||||
|         source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|         self.source = source |         self.source = source | ||||||
|  |         super().__init__(root, transform, pre_transform, pre_filter) | ||||||
|  |         print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt | ||||||
|         self.api = API(source)  # Initialize NAS-Bench-201 API |         self.api = API(source)  # Initialize NAS-Bench-201 API | ||||||
|         print('API loaded') |         print('API loaded') | ||||||
|         super().__init__(root, transform, pre_transform, pre_filter) |  | ||||||
|         print('Dataset initialized') |         print('Dataset initialized') | ||||||
|         print(self.processed_paths[0]) |  | ||||||
|         self.data, self.slices = torch.load(self.processed_paths[0]) |         self.data, self.slices = torch.load(self.processed_paths[0]) | ||||||
|  |         self.data.edge_attr = self.data.edge_attr.squeeze() | ||||||
|  |         self.data.idx = torch.arange(len(self.data.y)) | ||||||
|  |         print(f"self.data={self.data}, self.slices={self.slices}") | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def raw_file_names(self): |     def raw_file_names(self): | ||||||
| @@ -676,7 +703,7 @@ def create_adj_matrix_and_ops(nodes, edges): | |||||||
|         adj_matrix[src][dst] = 1 |         adj_matrix[src][dst] = 1 | ||||||
|     return adj_matrix, nodes |     return adj_matrix, nodes | ||||||
| class DataInfos(AbstractDatasetInfos): | class DataInfos(AbstractDatasetInfos): | ||||||
|     def __init__(self, datamodule, cfg): |     def __init__(self, datamodule, cfg, dataset): | ||||||
|         tasktype_dict = { |         tasktype_dict = { | ||||||
|             'hiv_b': 'classification', |             'hiv_b': 'classification', | ||||||
|             'bace_b': 'classification', |             'bace_b': 'classification', | ||||||
| @@ -689,6 +716,7 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|         self.task = task_name |         self.task = task_name | ||||||
|         self.task_type = tasktype_dict.get(task_name, "regression") |         self.task_type = tasktype_dict.get(task_name, "regression") | ||||||
|         self.ensure_connected = cfg.model.ensure_connected |         self.ensure_connected = cfg.model.ensure_connected | ||||||
|  |         self.api = dataset.api | ||||||
|  |  | ||||||
|         datadir = cfg.dataset.datadir |         datadir = cfg.dataset.datadir | ||||||
|  |  | ||||||
| @@ -699,9 +727,9 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|         length = 15625 |         length = 15625 | ||||||
|         ops_type = {} |         ops_type = {} | ||||||
|         len_ops = set() |         len_ops = set() | ||||||
|         api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') |         # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||||
|         for i in range(length): |         for i in range(length): | ||||||
|             arch_info = api.query_meta_info_by_index(i) |             arch_info = self.api.query_meta_info_by_index(i) | ||||||
|             nodes, edges = parse_architecture_string(arch_info.arch_str) |             nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||||
|             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     |             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     | ||||||
|             if i < 5: |             if i < 5: | ||||||
| @@ -716,7 +744,6 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|             graphs.append((adj_matrix, ops)) |             graphs.append((adj_matrix, ops)) | ||||||
|  |  | ||||||
|         meta_dict = graphs_to_json(graphs, 'nasbench-201') |         meta_dict = graphs_to_json(graphs, 'nasbench-201') | ||||||
|  |  | ||||||
|         self.base_path = base_path |         self.base_path = base_path | ||||||
|         self.active_atoms = meta_dict['active_atoms'] |         self.active_atoms = meta_dict['active_atoms'] | ||||||
|         self.max_n_nodes = meta_dict['max_node'] |         self.max_n_nodes = meta_dict['max_node'] | ||||||
| @@ -930,4 +957,4 @@ def compute_meta(root, source_name, train_index, test_index): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     pass |     dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None) | ||||||
|   | |||||||
| @@ -78,16 +78,20 @@ def main(cfg: DictConfig): | |||||||
|  |  | ||||||
|     datamodule = dataset.DataModule(cfg) |     datamodule = dataset.DataModule(cfg) | ||||||
|     datamodule.prepare_data() |     datamodule.prepare_data() | ||||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) |     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||||
|     # train_smiles, reference_smiles = datamodule.get_train_smiles() |     # train_smiles, reference_smiles = datamodule.get_train_smiles() | ||||||
|  |     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||||
|  |  | ||||||
|     # get input output dimensions |     # get input output dimensions | ||||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) |     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||||
|     # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) |     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) | ||||||
|  |  | ||||||
|     # sampling_metrics = SamplingMolecularMetrics( |     # sampling_metrics = SamplingMolecularMetrics( | ||||||
|     #     dataset_infos, train_smiles, reference_smiles |     #     dataset_infos, train_smiles, reference_smiles | ||||||
|     # ) |     # ) | ||||||
|  |     sampling_metrics = SamplingGraphMetrics( | ||||||
|  |         dataset_infos, train_graphs, reference_graphs | ||||||
|  |     ) | ||||||
|     visualization_tools = MolecularVisualization(dataset_infos) |     visualization_tools = MolecularVisualization(dataset_infos) | ||||||
|  |  | ||||||
|     model_kwargs = { |     model_kwargs = { | ||||||
| @@ -135,5 +139,16 @@ def main(cfg: DictConfig): | |||||||
|     else: |     else: | ||||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) |         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||||
|  |  | ||||||
|  | @hydra.main( | ||||||
|  |     version_base="1.1", config_path="../configs", config_name="config" | ||||||
|  | ) | ||||||
|  | def test(cfg: DictConfig): | ||||||
|  |     datamodule = dataset.DataModule(cfg) | ||||||
|  |     datamodule.prepare_data() | ||||||
|  |     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||||
|  |     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||||
|  |  | ||||||
|  |     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     main() |     test() | ||||||
|   | |||||||
							
								
								
									
										0
									
								
								graph_dit/workingdoc.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/workingdoc.md
									
									
									
									
									
										Normal file
									
								
							
		Reference in New Issue
	
	Block a user