make the metrics code back
This commit is contained in:
		| @@ -13,11 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL | |||||||
| import utils | import utils | ||||||
|  |  | ||||||
| class Graph_DiT(pl.LightningModule): | class Graph_DiT(pl.LightningModule): | ||||||
|     # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): |     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): | ||||||
|     def __init__(self, cfg, dataset_infos, visualization_tools): |     # def __init__(self, cfg, dataset_infos, visualization_tools): | ||||||
|  |  | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) |         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) | ||||||
|         self.test_only = cfg.general.test_only |         self.test_only = cfg.general.test_only | ||||||
|         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) |         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) | ||||||
|  |  | ||||||
| @@ -57,8 +57,8 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         self.test_E_logp = SumExceptBatchMetric() |         self.test_E_logp = SumExceptBatchMetric() | ||||||
|         self.test_y_collection = [] |         self.test_y_collection = [] | ||||||
|  |  | ||||||
|         # self.train_metrics = train_metrics |         self.train_metrics = train_metrics | ||||||
|         # self.sampling_metrics = sampling_metrics |         self.sampling_metrics = sampling_metrics | ||||||
|  |  | ||||||
|         self.visualization_tools = visualization_tools |         self.visualization_tools = visualization_tools | ||||||
|         self.max_n_nodes = dataset_infos.max_n_nodes |         self.max_n_nodes = dataset_infos.max_n_nodes | ||||||
| @@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule): | |||||||
|     @torch.no_grad() |     @torch.no_grad() | ||||||
|     def validation_step(self, data, i): |     def validation_step(self, data, i): | ||||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] |         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() |         data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float() | ||||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) |         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||||
|         dense_data = dense_data.mask(node_mask) |         dense_data = dense_data.mask(node_mask, collapse=True) | ||||||
|         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) |         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||||
|         pred = self.forward(noisy_data) |         pred = self.forward(noisy_data) | ||||||
|         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) |         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user