try to fix the dimension problem in apply noise method
This commit is contained in:
parent
82299e5213
commit
8bbadce19c
@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validation_step(self, data, i):
|
def validation_step(self, data, i):
|
||||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
|
||||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||||
dense_data = dense_data.mask(node_mask)
|
dense_data = dense_data.mask(node_mask, collapse=False)
|
||||||
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
||||||
pred = self.forward(noisy_data)
|
pred = self.forward(noisy_data)
|
||||||
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
|
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
|
||||||
@ -444,9 +444,11 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
|
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
|
||||||
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
|
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
|
||||||
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
|
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
|
||||||
|
print(f"alpha_t_bar.shape {alpha_t_bar.shape}")
|
||||||
|
|
||||||
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
|
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
|
||||||
|
print(f"E.shape {E.shape}")
|
||||||
|
print(f"X.shape {X.shape}")
|
||||||
bs, n, d = X.shape
|
bs, n, d = X.shape
|
||||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
||||||
prob_all = X_all @ Qtb.X
|
prob_all = X_all @ Qtb.X
|
||||||
|
Loading…
Reference in New Issue
Block a user