add some shape commits
@@ -78,8 +78,8 @@ class Graph_DiT(pl.LightningModule):
                                                               timesteps=cfg.model.diffusion_steps)
 
 
-        print("__init__")
-        print("dataset_info.node_types", self.dataset_info.node_types)
+        # print("__init__")
+        # print("dataset_info.node_types", self.dataset_info.node_types)
         # dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
 
@@ -123,8 +123,8 @@ class Graph_DiT(pl.LightningModule):
         return pred
 
     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
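The class counts in this hunk are dataset-specific: `data.x` carries integer node-type labels, `data.edge_attr` integer edge-type labels, and `self.active_index` keeps only the node-type columns that actually occur in the dataset. A minimal, standalone sketch of this encoding step, with toy tensors and illustrative `active_index` values (none of the concrete numbers below are taken from the repository):

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for data.x / data.edge_attr: integer class labels per node / edge.
x = torch.tensor([0, 2, 5, 2])           # 4 nodes, node types in [0, 8)
edge_attr = torch.tensor([0, 1, 1])      # 3 edges, edge types in [0, 2)
active_index = [0, 2, 5]                 # hypothetical: node-type columns present in the dataset

data_x = F.one_hot(x, num_classes=8).float()[:, active_index]   # (4, 3) after column selection
data_edge_attr = F.one_hot(edge_attr, num_classes=2).float()    # (3, 2)

print(data_x.shape, data_edge_attr.shape)  # torch.Size([4, 3]) torch.Size([3, 2])
```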
@@ -138,6 +138,9 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                         log=i % self.log_every_steps == 0)
         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+        print(f"training loss: {loss}")
+        with open("training-loss.csv", "a") as f:
+            f.write(f"{loss}, {i}\n")
         return {'loss': loss}
 
 
@@ -150,7 +153,7 @@ class Graph_DiT(pl.LightningModule):
     def on_fit_start(self) -> None:
         self.train_iterations = self.trainer.datamodule.training_iterations
         print('on fit train iteration:', self.train_iterations)
-        print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
+        # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
 
     def on_train_epoch_start(self) -> None:
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
@@ -160,10 +163,12 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics.reset()
 
     def on_train_epoch_end(self) -> None:
+
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
+        log = True
         self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
         self.train_metrics.log_epoch_metrics(self.current_epoch, log)
 
@@ -178,22 +183,31 @@ class Graph_DiT(pl.LightningModule):
 
     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
-        dense_data = dense_data.mask(node_mask, collapse=True)
+        dense_data = dense_data.mask(node_mask, collapse=False)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
         pred = self.forward(noisy_data)
         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
         self.val_y_collection.append(data.y)
         self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
+        print(f'validation loss: {nll}, epoch: {self.current_epoch}')
         return {'loss': nll}
 
     def on_validation_epoch_end(self) -> None:
         metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
 
                    self.val_X_logp.compute(), self.val_E_logp.compute()]
+
+        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+            print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
+                f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
+        with open("validation-metrics.csv", "a") as f:
+            # save the metrics as csv file
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
+
 
         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
               f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
 
@@ -241,15 +255,15 @@ class Graph_DiT(pl.LightningModule):
                 samples_left_to_generate -= to_generate
                 chains_left_to_save -= chains_save
 
-            # print(f"Computing sampling metrics", ' ...')
-            # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-            # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+            print(f"Computing sampling metrics", ' ...')
+            valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+            print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
 
-            current_path = os.getcwd()
-            result_path = os.path.join(current_path,
-                                       f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
+            # current_path = os.getcwd()
+            # result_path = os.path.join(current_path,
+                                    #    f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
             # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-            # self.sampling_metrics.reset()
+            self.sampling_metrics.reset()
 
     def on_test_epoch_start(self) -> None:
         print("Starting test...")
@@ -262,8 +276,8 @@ class Graph_DiT(pl.LightningModule):
     
     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
@@ -277,6 +291,8 @@ class Graph_DiT(pl.LightningModule):
         """ Measure likelihood on a test set and compute stability metrics. """
         metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
                    self.test_X_logp.compute(), self.test_E_logp.compute()]
+        with open("test-metrics.csv", "a") as f:
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
 
         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
               f"Test Edge type KL: {metrics[2] :.2f}")
@@ -433,10 +449,12 @@ class Graph_DiT(pl.LightningModule):
 
         # Sample a timestep t.
         # When evaluating, the loss for t=0 is computed separately
+        # print(f"apply_noise X shape: {X.shape}, E shape: {E.shape}, y shape: {y.shape}, node_mask shape: {node_mask.shape}")
         lowest_t = 0 if self.training else 1
         t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
         s_int = t_int - 1
 
+
         t_float = t_int / self.T
         s_float = s_int / self.T
 
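For reference, `torch.randint`'s upper bound is exclusive, so passing `self.T + 1` makes the sampled timestep range inclusive of `T`, and `lowest_t` skips `t=0` outside of training, matching the comment that the `t=0` loss is computed separately. A self-contained sketch of the same sampling logic with placeholder values (`T=500` and the batch size are examples only):

```python
import torch

T = 500          # example number of diffusion steps
bs = 4           # example batch size
training = True

lowest_t = 0 if training else 1
t_int = torch.randint(lowest_t, T + 1, size=(bs, 1)).float()  # (bs, 1), values in [lowest_t, T]
s_int = t_int - 1
t_float = t_int / T   # normalized timestep fed to the noise schedule
s_float = s_int / T

print(t_int.squeeze(1).tolist(), t_float.squeeze(1).tolist())
```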
@@ -444,10 +462,23 @@ class Graph_DiT(pl.LightningModule):
         beta_t = self.noise_schedule(t_normalized=t_float)                         # (bs, 1)
         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)      # (bs, 1)
         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)      # (bs, 1)
+        # print(f"alpha_s_bar: {alpha_s_bar.shape}, alpha_t_bar: {alpha_t_bar.shape}")
 
         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out)
+        # print(f"X shape: {X.shape}, E shape: {E.shape}, node_mask shape: {node_mask.shape}")
+        # print(f"Qtb shape: {Qtb.X.shape}")
+        """
+        X shape: torch.Size([1200, 8]),
+        E shape: torch.Size([1200, 8, 8]),
+        y shape: torch.Size([1200, 1]),
+        node_mask shape: torch.Size([1200, 8])
+        alpha_s_bar: torch.Size([1200, 1]), alpha_t_bar: torch.Size([1200, 1])
+        """
 
+        # print(X.shape)
         bs, n, d = X.shape
+        E = E[..., :2]
+        # bs, n = X.shape
         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
         prob_all = X_all @ Qtb.X
         probX = prob_all[:, :, :self.Xdim_output]
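The block above flattens each node's row of edge one-hots, concatenates it with the node one-hot, and noises both in a single batched matmul against the transition matrix before splitting the result back into node and edge probabilities. A shape-only sketch of that manipulation (the dimensions mirror the commented shapes; the uniform `Qtb` below is a stand-in for the model's actual transition matrix, and the toy `dx`/`de` take the place of `Xdim_output`/`Edim_output`):

```python
import torch

bs, n, dx, de = 2, 8, 8, 2                         # batch, nodes, node classes, edge classes
X = torch.rand(bs, n, dx)                          # node one-hot probabilities
E = torch.rand(bs, n, n, de)                       # edge one-hot probabilities
d_all = dx + n * de                                # joint feature width per node

Qtb = torch.full((bs, d_all, d_all), 1.0 / d_all)  # placeholder transition matrix, (bs, d_all, d_all)

X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)   # (bs, n, dx + n*de)
prob_all = X_all @ Qtb                                  # batched matmul -> (bs, n, dx + n*de)
probX = prob_all[:, :, :dx]                             # node-type probabilities, (bs, n, dx)
probE = prob_all[:, :, dx:].reshape(bs, n, n, de)       # edge-type probabilities, (bs, n, n, de)

print(X_all.shape, probX.shape, probE.shape)
```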
@@ -457,6 +488,7 @@ class Graph_DiT(pl.LightningModule):
 
         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
+        # print(f"X.shape: {X.shape}, X_t shape: {X_t.shape}, E.shape: {E.shape},  E_t shape: {E_t.shape}")
         assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
 
         y_t = y