Add a sampling phase and attempt to compute sample log-probabilities
This commit is contained in:
		| @@ -286,7 +286,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, |                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||||
|                                                 save_final=to_save, |                                                 save_final=to_save, | ||||||
|                                                 keep_chain=chains_save, |                                                 keep_chain=chains_save, | ||||||
|                                                 number_chain_steps=self.number_chain_steps)) |                                                 number_chain_steps=self.number_chain_steps)[0]) | ||||||
|                 ident += to_generate |                 ident += to_generate | ||||||
|                 start_index += to_generate |                 start_index += to_generate | ||||||
|  |  | ||||||
| @@ -360,7 +360,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) |             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) | ||||||
|  |  | ||||||
|             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, |             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) |                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0] | ||||||
|             samples = samples + cur_sample |             samples = samples + cur_sample | ||||||
|              |              | ||||||
|             all_ys.append(batch_y) |             all_ys.append(batch_y) | ||||||
| @@ -601,6 +601,8 @@ class Graph_DiT(pl.LightningModule): | |||||||
|  |  | ||||||
|         assert (E == torch.transpose(E, 1, 2)).all() |         assert (E == torch.transpose(E, 1, 2)).all() | ||||||
|  |  | ||||||
|  |         total_log_probs = torch.zeros(batch_size, device=self.device) | ||||||
|  |  | ||||||
|         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. |         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. | ||||||
|         for s_int in reversed(range(0, self.T)): |         for s_int in reversed(range(0, self.T)): | ||||||
|             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) |             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) | ||||||
| @@ -609,21 +611,22 @@ class Graph_DiT(pl.LightningModule): | |||||||
|             t_norm = t_array / self.T |             t_norm = t_array / self.T | ||||||
|  |  | ||||||
|             # Sample z_s |             # Sample z_s | ||||||
|             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) |             sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||||
|             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||||
|  |             total_log_probs += log_probs | ||||||
|  |  | ||||||
|         # Sample |         # Sample | ||||||
|         sampled_s = sampled_s.mask(node_mask, collapse=True) |         sampled_s = sampled_s.mask(node_mask, collapse=True) | ||||||
|         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||||
|          |          | ||||||
|         molecule_list = [] |         graph_list = [] | ||||||
|         for i in range(batch_size): |         for i in range(batch_size): | ||||||
|             n = n_nodes[i] |             n = n_nodes[i] | ||||||
|             atom_types = X[i, :n].cpu() |             node_types = X[i, :n].cpu() | ||||||
|             edge_types = E[i, :n, :n].cpu() |             edge_types = E[i, :n, :n].cpu() | ||||||
|             molecule_list.append([atom_types, edge_types]) |             graph_list.append([node_types, edge_types]) | ||||||
|          |          | ||||||
|         return molecule_list |         return graph_list, total_log_probs | ||||||
|  |  | ||||||
|     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): |     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): | ||||||
|         """Samples from zs ~ p(zs | zt). Only used during sampling. |         """Samples from zs ~ p(zs | zt). Only used during sampling. | ||||||
| @@ -635,6 +638,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|  |  | ||||||
|         # Neural net predictions |         # Neural net predictions | ||||||
|         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} |         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} | ||||||
|  |         print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}") | ||||||
|          |          | ||||||
|         def get_prob(noisy_data, unconditioned=False): |         def get_prob(noisy_data, unconditioned=False): | ||||||
|             pred = self.forward(noisy_data, unconditioned=unconditioned) |             pred = self.forward(noisy_data, unconditioned=unconditioned) | ||||||
| @@ -674,6 +678,17 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         # with condition = P_t(G_{t-1} |G_t, C) |         # with condition = P_t(G_{t-1} |G_t, C) | ||||||
|         # with condition = P_t(A_{t-1} |A_t, y) |         # with condition = P_t(A_{t-1} |A_t, y) | ||||||
|         prob_X, prob_E, pred = get_prob(noisy_data) |         prob_X, prob_E, pred = get_prob(noisy_data) | ||||||
|  |         print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}') | ||||||
|  |         print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}') | ||||||
|  |         print(f'X_t: {X_t}') | ||||||
|  |         log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n | ||||||
|  |         log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n | ||||||
|  |  | ||||||
|  |         # Sum the log_prob across dimensions for total log_prob | ||||||
|  |         log_prob_X = log_prob_X.sum(dim=-1) | ||||||
|  |         log_prob_E = log_prob_E.sum(dim=(1, 2)) | ||||||
|  |         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}') | ||||||
|  |         log_probs = log_prob_E + log_prob_X | ||||||
|  |  | ||||||
|         ### Guidance |         ### Guidance | ||||||
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: |         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||||
| @@ -810,4 +825,4 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) |         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||||
|         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) |         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||||
|  |  | ||||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) |         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| # These imports are tricky because they use c++, do not move them | # These imports are tricky because they use c++, do not move them | ||||||
| import tqdm | from tqdm import tqdm | ||||||
| import os, shutil | import os, shutil | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| @@ -233,27 +233,62 @@ def test(cfg: DictConfig): | |||||||
|             optimizer.zero_grad() |             optimizer.zero_grad() | ||||||
|             # return {'loss': loss} |             # return {'loss': loss} | ||||||
|      |      | ||||||
|  |     # start testing | ||||||
|  |     print("start testing") | ||||||
|  |     graph_dit_model.eval() | ||||||
|  |     test_dataloader = accelerator.prepare(datamodule.test_dataloader()) | ||||||
|  |     for data in test_dataloader: | ||||||
|  |         data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||||
|  |         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||||
|  |  | ||||||
|  |         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||||
|  |         dense_data = dense_data.mask(node_mask) | ||||||
|  |         noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||||
|  |         pred = graph_dit_model.forward(noisy_data) | ||||||
|  |         nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) | ||||||
|  |         graph_dit_model.test_y_collection.append(data.y) | ||||||
|  |         print(f'test loss: {nll}') | ||||||
|  |  | ||||||
|     # start sampling |     # start sampling | ||||||
|  |  | ||||||
|     samples = [] |     samples_left_to_generate = cfg.general.final_model_samples_to_generate | ||||||
|  |     samples_left_to_save = cfg.general.final_model_samples_to_save | ||||||
|  |     chains_left_to_save = cfg.general.final_model_chains_to_save | ||||||
|  |  | ||||||
|     for i in tqdm( |     samples, all_ys, batch_id = [], [], 0 | ||||||
|         range(cfg.general.n_samples), desc="Sampling", disable=not cfg.general.enable_progress_bar |     test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0) | ||||||
|     ): |     num_examples = test_y_collection.size(0) | ||||||
|         batch_size = cfg.train.batch_size |     if cfg.general.final_model_samples_to_generate > num_examples: | ||||||
|         num_steps = cfg.model.diffusion_steps |         ratio = cfg.general.final_model_samples_to_generate // num_examples | ||||||
|         y = torch.ones(batch_size, num_steps, 1, 1, device=accelerator.device, dtype=inference_dtype) |         test_y_collection = test_y_collection.repeat(ratio+1, 1) | ||||||
|  |         num_examples = test_y_collection.size(0) | ||||||
|      |      | ||||||
|         # sample from the model |     while samples_left_to_generate > 0: | ||||||
|         samples_batch = graph_dit_model.sample_batch( |         print(f'samples left to generate: {samples_left_to_generate}/' | ||||||
|             batch_id=i, |             f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) | ||||||
|             batch_size=batch_size, |         bs = 1 * cfg.train.batch_size | ||||||
|             y=y, |         to_generate = min(samples_left_to_generate, bs) | ||||||
|             keep_chain=1, |         to_save = min(samples_left_to_save, bs) | ||||||
|             number_chain_steps=num_steps, |         chains_save = min(chains_left_to_save, bs) | ||||||
|             save_final=batch_size |         # batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||||
|         ) |         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device) | ||||||
|         samples.append(samples_batch) |  | ||||||
|  |         cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||||
|  |                                         keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0] | ||||||
|  |         samples = samples + cur_sample | ||||||
|  |          | ||||||
|  |         all_ys.append(batch_y) | ||||||
|  |         batch_id += to_generate | ||||||
|  |  | ||||||
|  |         samples_left_to_save -= to_save | ||||||
|  |         samples_left_to_generate -= to_generate | ||||||
|  |         chains_left_to_save -= chains_save | ||||||
|  |          | ||||||
|  |     print(f"final Computing sampling metrics...") | ||||||
|  |     graph_dit_model.sampling_metrics.reset() | ||||||
|  |     graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True) | ||||||
|  |     graph_dit_model.sampling_metrics.reset() | ||||||
|  |     print(f"Done.") | ||||||
|  |  | ||||||
|     # save samples |     # save samples | ||||||
|     print("Samples:") |     print("Samples:") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user