make the metrics code back
This commit is contained in:
parent
7274b3f606
commit
d57575586d
@ -13,11 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
|
||||
import utils
|
||||
|
||||
class Graph_DiT(pl.LightningModule):
|
||||
# def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
def __init__(self, cfg, dataset_infos, visualization_tools):
|
||||
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
# def __init__(self, cfg, dataset_infos, visualization_tools):
|
||||
|
||||
super().__init__()
|
||||
# self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
self.test_only = cfg.general.test_only
|
||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||
|
||||
@ -57,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.test_E_logp = SumExceptBatchMetric()
|
||||
self.test_y_collection = []
|
||||
|
||||
# self.train_metrics = train_metrics
|
||||
# self.sampling_metrics = sampling_metrics
|
||||
self.train_metrics = train_metrics
|
||||
self.sampling_metrics = sampling_metrics
|
||||
|
||||
self.visualization_tools = visualization_tools
|
||||
self.max_n_nodes = dataset_infos.max_n_nodes
|
||||
@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
|
||||
@torch.no_grad()
|
||||
def validation_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask)
|
||||
dense_data = dense_data.mask(node_mask, collapse=True)
|
||||
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
||||
pred = self.forward(noisy_data)
|
||||
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
|
||||
|
Loading…
Reference in New Issue
Block a user