From 3950a8438d3de906032e25c0bc1f11191b288ba6 Mon Sep 17 00:00:00 2001 From: mhz Date: Tue, 20 Aug 2024 22:15:25 +0200 Subject: [PATCH] set batch_y to 1 and want to test 15625 --- graph_dit/datasets/dataset.py | 2 +- graph_dit/diffusion_model.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 94d9437..9db49eb 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -781,7 +781,7 @@ class Dataset(InMemoryDataset): print(f'idx={idx}, y={y}') y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) - return None + # return None return data graph_list = [] class Args: diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py index 79cee7d..a3f0993 100644 --- a/graph_dit/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -356,7 +356,8 @@ class Graph_DiT(pl.LightningModule): to_generate = min(samples_left_to_generate, bs) to_save = min(samples_left_to_save, bs) chains_save = min(chains_left_to_save, bs) - batch_y = test_y_collection[batch_id : batch_id + to_generate] + # batch_y = test_y_collection[batch_id : batch_id + to_generate] + batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, keep_chain=chains_save, number_chain_steps=self.number_chain_steps)