comment some output statements
This commit is contained in:
parent
dd31fda8d5
commit
817ef04c58
@ -87,7 +87,7 @@ class Denoiser(nn.Module):
|
||||
def forward(self, x, e, node_mask, y, t, unconditioned):
|
||||
|
||||
print("Denoiser Forward")
|
||||
print(x.shape, e.shape, y.shape, t.shape, unconditioned)
|
||||
# print(x.shape, e.shape, y.shape, t.shape, unconditioned)
|
||||
force_drop_id = torch.zeros_like(y.sum(-1))
|
||||
# drop the nan values
|
||||
force_drop_id[torch.isnan(y.sum(-1))] = 1
|
||||
@ -98,32 +98,32 @@ class Denoiser(nn.Module):
|
||||
# bs = batch size, n = number of nodes
|
||||
bs, n, _ = x.size()
|
||||
x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
|
||||
print("X after concat with E")
|
||||
print(x.shape)
|
||||
# print("X after concat with E")
|
||||
# print(x.shape)
|
||||
# self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
|
||||
x = self.x_embedder(x)
|
||||
print("X after x_embedder")
|
||||
print(x.shape)
|
||||
# print("X after x_embedder")
|
||||
# print(x.shape)
|
||||
|
||||
# self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
c1 = self.t_embedder(t)
|
||||
print("C1 after t_embedder")
|
||||
print(c1.shape)
|
||||
# print("C1 after t_embedder")
|
||||
# print(c1.shape)
|
||||
for i in range(1, self.ydim):
|
||||
if i == 1:
|
||||
c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
|
||||
else:
|
||||
c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
|
||||
print("C2 after y_embedding_list")
|
||||
print(c2.shape)
|
||||
print("C1 + C2")
|
||||
# print("C2 after y_embedding_list")
|
||||
# print(c2.shape)
|
||||
# print("C1 + C2")
|
||||
c = c1 + c2
|
||||
print(c.shape)
|
||||
# print(c.shape)
|
||||
|
||||
for i, block in enumerate(self.encoders):
|
||||
x = block(x, c, node_mask)
|
||||
print("X after block")
|
||||
print(x.shape)
|
||||
# print("X after block")
|
||||
# print(x.shape)
|
||||
|
||||
# X: B * N * dx, E: B * N * N * de
|
||||
X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
|
||||
|
Loading…
Reference in New Issue
Block a user