find the guidance part

This commit is contained in:
mhz 2024-07-16 13:27:44 +02:00
parent 0b9da26eda
commit fcdd8efc4f

View File

@ -609,6 +609,7 @@ class Graph_DiT(pl.LightningModule):
Qt = self.transition_model.get_Qt(beta_t, self.device)
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
# p(x_0|x_t)
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
Qt=Qt.X,
Qsb=Qsb.X,