update_name
This commit is contained in:
		
							
								
								
									
										126
									
								
								graph_dit/metrics/molecular_metrics_train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								graph_dit/metrics/molecular_metrics_train.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,126 @@ | ||||
| import torch | ||||
| from torchmetrics import Metric, MetricCollection | ||||
| from torch import Tensor | ||||
| import torch.nn as nn | ||||
|  | ||||
| class CEPerClass(Metric): | ||||
|     full_state_update = False | ||||
|     def __init__(self, class_id): | ||||
|         super().__init__() | ||||
|         self.class_id = class_id | ||||
|         self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.softmax = torch.nn.Softmax(dim=-1) | ||||
|         self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum') | ||||
|      | ||||
|     def update(self, preds: Tensor, target: Tensor) -> None: | ||||
|         """Update state with predictions and targets. | ||||
|         Args: | ||||
|             preds: Predictions from model   (bs, n, d) or (bs, n, n, d) | ||||
|             target: Ground truth values     (bs, n, d) or (bs, n, n, d) | ||||
|         """ | ||||
|         target = target.reshape(-1, target.shape[-1]) | ||||
|         mask = (target != 0.).any(dim=-1) | ||||
|  | ||||
|         prob = self.softmax(preds)[..., self.class_id] | ||||
|         prob = prob.flatten()[mask] | ||||
|  | ||||
|         target = target[:, self.class_id] | ||||
|         target = target[mask] | ||||
|  | ||||
|         output = self.binary_cross_entropy(prob, target) | ||||
|  | ||||
|         self.total_ce += output | ||||
|         self.total_samples += prob.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
|  | ||||
|  | ||||
| class AtomCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class NoBondCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class SingleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class DoubleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class TripleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AromaticCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AtomMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_atoms): | ||||
|         metrics_list = [] | ||||
|          | ||||
|         for i, atom_type in enumerate(active_atoms): | ||||
|             metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i)) | ||||
|  | ||||
|         super().__init__(metrics_list) | ||||
|  | ||||
|  | ||||
| class BondMetricsCE(MetricCollection): | ||||
|     def __init__(self): | ||||
|         ce_no_bond = NoBondCE(0) | ||||
|         ce_SI = SingleCE(1) | ||||
|         ce_DO = DoubleCE(2) | ||||
|         ce_TR = TripleCE(3) | ||||
|         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) | ||||
|  | ||||
|  | ||||
| class TrainMolecularMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|         super().__init__() | ||||
|         active_atoms = dataset_infos.active_atoms | ||||
|         self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms) | ||||
|         self.train_bond_metrics = BondMetricsCE() | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||
|         self.train_atom_metrics(masked_pred_X, true_X) | ||||
|         self.train_bond_metrics(masked_pred_E, true_E) | ||||
|         if log: | ||||
|             to_log = {} | ||||
|             for key, val in self.train_atom_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|             for key, val in self.train_bond_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|  | ||||
|     def reset(self): | ||||
|         for metric in [self.train_atom_metrics, self.train_bond_metrics]: | ||||
|             metric.reset() | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch, log=True): | ||||
|         epoch_atom_metrics = self.train_atom_metrics.compute() | ||||
|         epoch_bond_metrics = self.train_bond_metrics.compute() | ||||
|  | ||||
|         to_log = {} | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|  | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             epoch_atom_metrics[key] = round(val.item(),4) | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             epoch_bond_metrics[key] = round(val.item(),4) | ||||
|  | ||||
|         if log: | ||||
|             print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}") | ||||
|  | ||||
		Reference in New Issue
	
	Block a user