try to fix the dimension problem in apply noise method

This commit is contained in:
mhz 2024-06-26 21:58:08 +02:00
parent 82299e5213
commit 8bbadce19c

View File

@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def validation_step(self, data, i): def validation_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask) dense_data = dense_data.mask(node_mask, collapse=False)
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = self.forward(noisy_data) pred = self.forward(noisy_data)
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
@ -444,9 +444,11 @@ class Graph_DiT(pl.LightningModule):
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1) beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1) alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1) alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
print(f"alpha_t_bar.shape {alpha_t_bar.shape}")
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out) Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
print(f"E.shape {E.shape}")
print(f"X.shape {X.shape}")
bs, n, d = X.shape bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X prob_all = X_all @ Qtb.X