Bring the metrics code back

mhz 2024-06-30 16:43:08 +02:00
parent 7274b3f606
commit d57575586d


@@ -13,11 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils
 class Graph_DiT(pl.LightningModule):
-    # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
-    def __init__(self, cfg, dataset_infos, visualization_tools):
+    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    # def __init__(self, cfg, dataset_infos, visualization_tools):
         super().__init__()
-        # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
+        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
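For context on the restored `save_hyperparameters(ignore=...)` call: in PyTorch Lightning, `ignore` keeps the listed constructor arguments out of the saved hyperparameters, which is the usual choice for stateful objects such as metric collections. A minimal standalone sketch (a toy module, not the Graph_DiT class itself):

```python
import pytorch_lightning as pl
import torch.nn as nn

class Toy(pl.LightningModule):
    def __init__(self, lr, train_metrics, sampling_metrics):
        super().__init__()
        # Keep plain config values (lr) in hparams, but exclude the stateful
        # metric objects so they are not serialized into every checkpoint.
        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

toy = Toy(lr=1e-3, train_metrics=nn.ModuleDict(), sampling_metrics=nn.ModuleDict())
print(toy.hparams)  # contains 'lr' only; the metric objects are ignored
```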
@@ -57,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp = SumExceptBatchMetric()
         self.test_y_collection = []
-        # self.train_metrics = train_metrics
-        # self.sampling_metrics = sampling_metrics
+        self.train_metrics = train_metrics
+        self.sampling_metrics = sampling_metrics
         self.visualization_tools = visualization_tools
         self.max_n_nodes = dataset_infos.max_n_nodes
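General background on the restored `self.train_metrics` / `self.sampling_metrics` attributes (the project's own metric classes are not shown in this diff; the sketch below uses plain `torchmetrics` as a stand-in): metric objects are stored as module attributes so they follow the model across devices, are updated per step, and are computed and reset at epoch boundaries.

```python
import torch
import torchmetrics

class MetricHolder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registering the metric on an nn.Module lets it be moved to the
        # module's device together with the rest of the model.
        self.train_nll = torchmetrics.MeanMetric()

    def training_update(self, batch_nll: torch.Tensor):
        self.train_nll.update(batch_nll)      # accumulate per-batch values

    def on_epoch_end(self):
        epoch_nll = self.train_nll.compute()  # aggregate over the whole epoch
        self.train_nll.reset()                # clear state for the next epoch
        return epoch_nll

holder = MetricHolder()
holder.training_update(torch.tensor(1.5))
holder.training_update(torch.tensor(2.5))
print(holder.on_epoch_end())  # tensor(2.)
```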
@@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
     @torch.no_grad()
     def validation_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
-        dense_data = dense_data.mask(node_mask)
+        dense_data = dense_data.mask(node_mask, collapse=True)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
         pred = self.forward(noisy_data)
         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
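A standalone shape check of the one-hot encoding used in `validation_step` (the class counts 118 and 10 mirror the new line in this diff; `self.active_index`, `utils.to_dense`, and `mask(..., collapse=True)` are project-specific and not reproduced here):

```python
import torch
import torch.nn.functional as F

# Toy batch: 4 nodes indexed by atomic number, 3 edges indexed by bond type.
x = torch.tensor([6, 7, 8, 6])        # node class indices
edge_attr = torch.tensor([1, 2, 1])   # edge class indices

data_x = F.one_hot(x, num_classes=118).float()                 # (4, 118)
data_edge_attr = F.one_hot(edge_attr, num_classes=10).float()  # (3, 10)

print(data_x.shape, data_edge_attr.shape)
# torch.Size([4, 118]) torch.Size([3, 10])
```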