find the guidance part
This commit is contained in:
parent
0b9da26eda
commit
fcdd8efc4f
@ -609,6 +609,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
Qt = self.transition_model.get_Qt(beta_t, self.device)
|
||||
|
||||
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
|
||||
# p(x_0|x_t)
|
||||
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
|
||||
Qt=Qt.X,
|
||||
Qsb=Qsb.X,
|
||||
|
Loading…
Reference in New Issue
Block a user