Compare commits
	
		
			2 Commits
		
	
	
		
			7274b3f606
			...
			0fc6f6e686
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0fc6f6e686 | |||
| d57575586d | 
| @@ -13,11 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL | ||||
| import utils | ||||
|  | ||||
| class Graph_DiT(pl.LightningModule): | ||||
|     # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): | ||||
|     def __init__(self, cfg, dataset_infos, visualization_tools): | ||||
|     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): | ||||
|     # def __init__(self, cfg, dataset_infos, visualization_tools): | ||||
|  | ||||
|         super().__init__() | ||||
|         # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) | ||||
|         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) | ||||
|         self.test_only = cfg.general.test_only | ||||
|         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) | ||||
|  | ||||
| @@ -57,8 +57,8 @@ class Graph_DiT(pl.LightningModule): | ||||
|         self.test_E_logp = SumExceptBatchMetric() | ||||
|         self.test_y_collection = [] | ||||
|  | ||||
|         # self.train_metrics = train_metrics | ||||
|         # self.sampling_metrics = sampling_metrics | ||||
|         self.train_metrics = train_metrics | ||||
|         self.sampling_metrics = sampling_metrics | ||||
|  | ||||
|         self.visualization_tools = visualization_tools | ||||
|         self.max_n_nodes = dataset_infos.max_n_nodes | ||||
| @@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule): | ||||
|     @torch.no_grad() | ||||
|     def validation_step(self, data, i): | ||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() | ||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float() | ||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||
|         dense_data = dense_data.mask(node_mask) | ||||
|         dense_data = dense_data.mask(node_mask, collapse=True) | ||||
|         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||
|         pred = self.forward(noisy_data) | ||||
|         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) | ||||
|   | ||||
| @@ -23,6 +23,103 @@ def result_to_csv(path, dict_data): | ||||
|             writer.writeheader() | ||||
|         writer.writerow(dict_data) | ||||
|  | ||||
class SamplingGraphMetrics(nn.Module):
    """Score sampled graphs and write them, with their metrics, to disk.

    On construction one ``TaskModel`` evaluator is created per task named in
    ``dataset_infos.task`` (a ``-``-separated string). ``forward`` delegates
    the scoring to ``compute_molecular_metrics`` and then persists the
    results as text files plus one row in ``output.csv``.
    """

    def __init__(
        self,
        dataset_infos,
        train_graphs,
        reference_graphs,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """Store the dataset description and build the per-task evaluators.

        Args:
            dataset_infos: object exposing ``task``, ``active_nodes`` and
                ``base_path``.
            train_graphs: training graphs, forwarded to the metric routine.
            reference_graphs: accepted for interface compatibility; it is
                not used and no reference statistics are computed.
            n_jobs: stored under ``compute_config["n_jobs"]``.
            device: stored under ``compute_config["device"]``.
            batch_size: stored under ``compute_config["batch_size"]``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_nodes = dataset_infos.active_nodes
        self.train_graphs = train_graphs

        # Reference statistics are never computed in this class.
        self.stat_ref = None

        self.compute_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        self.task_evaluator = {
            'meta_taskname': dataset_infos.task,
            'sas': None,
            'scs': None
        }

        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            # Make sure the evaluator directory exists before handing the
            # path to TaskModel.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
        """Compute metrics for ``graphs`` and persist graphs and metrics.

        Args:
            graphs: sampled graphs, passed straight to
                ``compute_molecular_metrics``.
            targets: a tensor, or a list of tensors that is concatenated
                along dim 0; converted to a CPU NumPy array.
            name: sub-directory of ``./graphs`` used for validation output.
            current_epoch: used in the validation file name and log name.
            val_counter: used in the validation file name and log name.
            test: when true, write ``final_graphs.txt`` in the working
                directory instead of the per-epoch validation file.

        Returns:
            The metrics mapping (third value returned by
            ``compute_molecular_metrics``) with a ``"log_name"`` entry added.
        """
        if isinstance(targets, list):
            targets = torch.cat(targets, dim=0)
        targets_np = targets.detach().cpu().numpy()

        # NOTE(review): assumes the helper returns
        # (unique graphs, all graphs, metrics dict, per-target log) - confirm.
        # The previous code bound the 2nd and 3rd values to the same name, so
        # the graph list was lost and the test writer iterated the metrics.
        unique_graphs, all_graphs, all_metrics, targets_log = compute_molecular_metrics(
            graphs,
            targets_np,
            self.train_graphs,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.compute_config,
        )

        if test:
            self._write_test_graphs(all_graphs, targets_log)
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_graphs_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager closes the file even if a write fails.
            with open(text_path, "w") as textfile:
                textfile.writelines(graph + "\n" for graph in unique_graphs)

        if test:
            all_metrics["log_name"] = "test"
        else:
            all_metrics["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_metrics)
        return all_metrics

    def _write_test_graphs(self, all_graphs, targets_log):
        """Write every sampled graph, with its target columns, to final_graphs.txt."""
        # Every evaluator key except the bookkeeping entry becomes a column.
        task_names = [
            task for task in self.task_evaluator if task != 'meta_taskname'
        ]
        header_cols = (
            [f"input_{task}" for task in task_names]
            + [f"output_{task}" for task in task_names]
        )

        with open("final_graphs.txt", "w") as fp:
            fp.write("graph, " + ", ".join(header_cols) + "\n")
            for i, graph in enumerate(all_graphs):
                if targets_log is not None:
                    values = (
                        [f"{targets_log['input_' + task][i]}" for task in task_names]
                        + [f"{targets_log['output_' + task][i]}" for task in task_names]
                    )
                    fp.write(f"{graph}, " + ", ".join(values) + "\n")
                else:
                    # No per-target log available: emit the graph only.
                    fp.write(f"{graph}\n")
        print("All graphs saved")

    def reset(self):
        """No state to reset; present so callers can treat this like a metric."""
        pass
|              | ||||
| class SamplingMolecularMetrics(nn.Module): | ||||
|     def __init__( | ||||
| @@ -40,20 +137,20 @@ class SamplingMolecularMetrics(nn.Module): | ||||
|         self.active_atoms = dataset_infos.active_atoms | ||||
|         self.train_smiles = train_smiles | ||||
|  | ||||
|         if reference_smiles is not None: | ||||
|             print( | ||||
|                 f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" | ||||
|             ) | ||||
|             start_time = time.time() | ||||
|             self.stat_ref = compute_intermediate_statistics( | ||||
|                 reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size | ||||
|             ) | ||||
|             end_time = time.time() | ||||
|             elapsed_time = end_time - start_time | ||||
|             print( | ||||
|                 f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" | ||||
|             ) | ||||
|         else: | ||||
|         # if reference_smiles is not None: | ||||
|         #     print( | ||||
|         #         f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" | ||||
|         #     ) | ||||
|         #     start_time = time.time() | ||||
|         #     self.stat_ref = compute_intermediate_statistics( | ||||
|         #         reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size | ||||
|         #     ) | ||||
|         #     end_time = time.time() | ||||
|         #     elapsed_time = end_time - start_time | ||||
|         #     print( | ||||
|         #         f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" | ||||
|         #     ) | ||||
|         # else: | ||||
|         self.stat_ref = None | ||||
|      | ||||
|         self.comput_config = { | ||||
|   | ||||
| @@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection): | ||||
|  | ||||
|         for i, node_type in enumerate(active_nodes) : | ||||
|             metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i)) | ||||
|         super().__init__(metrics_list) | ||||
|  | ||||
class EdgeMetricsCE(MetricCollection):
    """Collection of the four per-edge-class cross-entropy metrics."""

    def __init__(self):
        # Each metric class is instantiated with its position (0..3) as the
        # sole argument, in this order.
        edge_metric_classes = (NoBondCE, SingleCE, DoubleCE, TripleCE)
        edge_metrics = [
            metric_cls(class_id)
            for class_id, metric_cls in enumerate(edge_metric_classes)
        ]
        super().__init__(edge_metrics)
|  | ||||
| class AtomMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_atoms): | ||||
| @@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection): | ||||
class TrainGraphMetricsDiscrete(nn.Module):
    """Track per-node-type and per-edge-type training cross-entropy metrics."""

    def __init__(self, dataset_infos):
        """Create the node and edge metric collections.

        Args:
            dataset_infos: object exposing ``active_nodes``, which is passed
                to ``NodeMetricsCE``.
        """
        super().__init__()
        active_nodes = dataset_infos.active_nodes
        self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
        self.train_edge_metrics = EdgeMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Update both metric collections with one batch.

        Args:
            masked_pred_X: node predictions, fed to the node metrics.
            masked_pred_E: edge predictions, fed to the edge metrics.
            true_X: node targets.
            true_E: edge targets.
            log: when true, also compute the current metric values.

        Returns:
            A dict mapping ``'train/<metric>'`` to a Python number when
            ``log`` is true, otherwise ``None``. Previously the dict was
            built and then discarded, so it could never be logged.
        """
        self.train_node_metrics(masked_pred_X, true_X)
        self.train_edge_metrics(masked_pred_E, true_E)
        if not log:
            return None

        to_log = {}
        for collection in (self.train_node_metrics, self.train_edge_metrics):
            for key, val in collection.compute().items():
                to_log['train/' + key] = val.item()
        return to_log

    def reset(self):
        """Reset the node and edge metric collections."""
        for metric in [self.train_node_metrics, self.train_edge_metrics]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        """Compute the epoch-level metrics and optionally print them.

        Args:
            current_epoch: epoch number shown in the printed summary.
            log: when true, print the metrics rounded to 4 decimals.

        Returns:
            A dict mapping ``'train_epoch/<metric>'`` to the unrounded
            Python number. Previously this dict was built and dropped.
        """
        epoch_node_metrics = self.train_node_metrics.compute()
        epoch_edge_metrics = self.train_edge_metrics.compute()

        to_log = {}
        for key, val in epoch_node_metrics.items():
            to_log['train_epoch/' + key] = val.item()
        for key, val in epoch_edge_metrics.items():
            to_log['train_epoch/' + key] = val.item()

        # Build rounded copies for display rather than overwriting the
        # mappings returned by compute().
        rounded_node = {
            key: round(val.item(), 4) for key, val in epoch_node_metrics.items()
        }
        rounded_edge = {
            key: round(val.item(), 4) for key, val in epoch_edge_metrics.items()
        }

        if log:
            print(f"Epoch {current_epoch}: {rounded_node} -- {rounded_edge}")
        return to_log
|  | ||||
| class TrainMolecularMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user