138 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			138 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import torch | ||
|  | from torch import Tensor | ||
|  | from torch.nn import functional as F | ||
|  | from torchmetrics import Metric, MeanSquaredError | ||
|  | 
 | ||
|  | 
 | ||
class TrainAbstractMetricsDiscrete(torch.nn.Module):
    """Do-nothing training-metrics hooks for the discrete setting.

    Every hook below accepts its arguments and returns ``None`` without
    touching any state; subclasses are expected to override the hooks they
    actually need.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Receive one batch of masked predictions and targets; ignored here.

        The ``X`` / ``E`` naming presumably refers to node and edge tensors
        (NOTE(review): inferred from names only — confirm with subclasses).
        Always returns ``None``.
        """

    def reset(self):
        """Clear accumulated metric state; the base class holds none."""

    def log_epoch_metrics(self, current_epoch):
        """Report metrics for ``current_epoch``; the base class reports nothing."""
|  | 
 | ||
|  | 
 | ||
class TrainAbstractMetrics(torch.nn.Module):
    """Do-nothing training-metrics hooks for the continuous (eps-prediction) setting.

    All hooks accept their arguments and return ``None`` without touching any
    state; concrete metric trackers override what they need.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
        """Receive one batch of predicted and true values; ignored here.

        The ``eps`` prefix presumably denotes predicted/true noise terms
        (NOTE(review): inferred from parameter names — confirm with
        subclasses). Always returns ``None``.
        """

    def reset(self):
        """Clear accumulated metric state; the base class holds none."""

    def log_epoch_metrics(self, current_epoch):
        """Report metrics for ``current_epoch``; the base class reports nothing."""
|  | 
 | ||
|  | 
 | ||
class SumExceptBatchMetric(Metric):
    """Running total of all entries of ``values``, divided by the batch count.

    Each ``update`` adds the sum over every element of ``values`` and counts
    ``values.shape[0]`` samples; ``compute`` returns total / samples. Both
    states are registered with ``dist_reduce_fx="sum"``.
    """

    def __init__(self):
        super().__init__()
        # Same two zero-initialised scalar states, registered in the same order.
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Accumulate one batch; the leading dimension of ``values`` is the batch."""
        self.total_value += values.sum()
        self.total_samples += values.shape[0]

    def compute(self):
        """Return the accumulated sum divided by the number of batch elements seen."""
        numerator, denominator = self.total_value, self.total_samples
        return numerator / denominator
|  | 
 | ||
|  | 
 | ||
class SumExceptBatchMSE(MeanSquaredError):
    """Squared error summed over all non-batch dimensions, averaged over the batch.

    ``update`` adds the sum of squared differences over every element to
    ``sum_squared_error`` but increments ``total`` only by the batch size
    (``preds.shape[0]``). Assuming the parent's ``compute`` divides
    ``sum_squared_error`` by ``total`` (as torchmetrics' ``MeanSquaredError``
    does — confirm for the installed version), the result is the per-sample
    sum of squared errors rather than the per-element mean.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model; must have the same shape as ``target``.
            target: Ground truth values.

        Raises:
            ValueError: if ``preds`` and ``target`` have different shapes.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a mismatched shape would then silently broadcast
        # in ``preds - target`` and produce a wrong error value.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}"
            )
        sum_squared_error, n_obs = self._mean_squared_error_update(preds, target)

        self.sum_squared_error += sum_squared_error
        self.total += n_obs

    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return the squared-error sum and observation count for one batch.

        Args:
            preds: Predicted tensor.
            target: Ground truth tensor (shape is validated by ``update``;
                this helper performs no shape check itself).

        Returns:
            Tuple of (sum of squared differences over all elements,
            size of the leading/batch dimension of ``preds``).
        """
        diff = preds - target
        sum_squared_error = torch.sum(diff * diff)
        n_obs = preds.shape[0]
        return sum_squared_error, n_obs
|  | 
 | ||
|  | 
 | ||
class SumExceptBatchKL(Metric):
    """Running KL divergence summed over all entries, averaged over the batch.

    ``update(p, q)`` forwards to ``F.kl_div(q, p, reduction='sum')``.
    ``F.kl_div`` treats its first argument as log-probabilities, so callers
    presumably pass ``q`` as log-probabilities of the predicted distribution
    and ``p`` as target probabilities (NOTE(review): confirm against callers).
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, p, q) -> None:
        """Add the summed KL of one batch; ``p``'s leading dim is the batch size."""
        batch_kl = F.kl_div(input=q, target=p, reduction='sum')
        self.total_value += batch_kl
        self.total_samples += p.shape[0]

    def compute(self):
        """Return the accumulated KL divided by the number of batch elements seen."""
        return torch.div(self.total_value, self.total_samples)
|  | 
 | ||
|  | 
 | ||
class CrossEntropyMetric(Metric):
    """Running cross-entropy summed over rows, divided by the number of rows.

    Target rows are reduced to class indices with ``argmax`` over the last
    dimension before being passed to ``F.cross_entropy``.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
        """Accumulate the summed cross-entropy of one batch.

        Args:
            preds: Predictions from model   (bs * n, d) or (bs * n * n, d)
            target: Ground truth values     (bs * n, d) or (bs * n * n, d)
            weight: Optional per-class weights, cast to ``preds``' tensor type.
        """
        class_idx = target.argmax(dim=-1)
        # ``F.cross_entropy`` already defaults to ``weight=None``, so a single
        # call covers both the weighted and unweighted cases.
        cast_weight = None if weight is None else weight.type_as(preds)
        batch_ce = F.cross_entropy(preds, class_idx, weight=cast_weight, reduction='sum')
        self.total_ce += batch_ce
        # Divisor counts rows, not the sum of the class weights.
        self.total_samples += preds.shape[0]

    def compute(self):
        """Return the accumulated cross-entropy per row."""
        return torch.div(self.total_ce, self.total_samples)
|  | 
 | ||
|  | 
 | ||
class ProbabilityMetric(Metric):
    """Mean of all entries of ``preds`` seen so far.

    Intended to track the marginal predicted probability of a class during
    training: each ``update`` adds the element sum and element count.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('prob', 'total'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Add the sum and the number of elements of ``preds``."""
        prob_sum, element_count = preds.sum(), preds.numel()
        self.prob += prob_sum
        self.total += element_count

    def compute(self):
        """Return the element-wise mean over everything accumulated."""
        mean_prob = self.prob / self.total
        return mean_prob
|  | 
 | ||
|  | 
 | ||
class NLL(Metric):
    """Mean of all negative-log-likelihood values passed to ``update``.

    Every element of ``batch_nll`` counts as one sample (``numel``), so the
    result is the element-wise mean over everything accumulated.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('total_nll', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, batch_nll) -> None:
        """Add the sum and the element count of ``batch_nll``."""
        self.total_nll += batch_nll.sum()
        self.total_samples += batch_nll.numel()

    def compute(self):
        """Return accumulated NLL divided by the number of elements seen."""
        return torch.div(self.total_nll, self.total_samples)