add some shape commits
parent 7e83bf1401
commit 0b9da26eda
@@ -78,8 +78,8 @@ class Graph_DiT(pl.LightningModule):
                                             timesteps=cfg.model.diffusion_steps)


-        print("__init__")
-        print("dataset_info.node_types", self.dataset_info.node_types)
+        # print("__init__")
+        # print("dataset_info.node_types", self.dataset_info.node_types)
         # dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())

@@ -123,8 +123,8 @@ class Graph_DiT(pl.LightningModule):
         return pred

     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
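A quick sanity check on the narrowed encoding above (a standalone sketch, not the repo's code): F.one_hot raises an error if any index is >= num_classes, so dropping from 118 to 8 classes assumes data.x already holds remapped atom-type indices below 8. The `active_index` values here are hypothetical.

    import torch
    import torch.nn.functional as F

    x = torch.tensor([0, 5, 6, 7])              # remapped atom-type indices, all < 8
    active_index = torch.tensor([0, 5, 6, 7])   # hypothetical subset of columns the model keeps
    data_x = F.one_hot(x, num_classes=8).float()[:, active_index]
    print(data_x.shape)  # torch.Size([4, 4]): one row per atom, one column per active type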
@@ -138,6 +138,9 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                            log=i % self.log_every_steps == 0)
         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+        print(f"training loss: {loss}")
+        with open("training-loss.csv", "a") as f:
+            f.write(f"{loss}, {i}\n")
         return {'loss': loss}

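One caveat on the logging added above: `self.log(..., sync_dist=True)` reduces the loss across processes, while the `print` and the `open("training-loss.csv", "a")` append run independently on every rank in a multi-GPU run, so rows can interleave. A hedged alternative (not what this commit does) is to let Lightning's CSVLogger collect the scalars already passed to `self.log`:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger

    # Scalars logged via self.log(...) end up in logs/graph_dit/version_*/metrics.csv
    trainer = Trainer(logger=CSVLogger(save_dir="logs", name="graph_dit"), max_epochs=10)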
@@ -150,7 +153,7 @@ class Graph_DiT(pl.LightningModule):
     def on_fit_start(self) -> None:
         self.train_iterations = self.trainer.datamodule.training_iterations
         print('on fit train iteration:', self.train_iterations)
-        print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
+        # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))

     def on_train_epoch_start(self) -> None:
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
@@ -160,10 +163,12 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics.reset()

     def on_train_epoch_end(self) -> None:
+
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
+        log = True
         self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
         self.train_metrics.log_epoch_metrics(self.current_epoch, log)

@@ -178,24 +183,33 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
-        dense_data = dense_data.mask(node_mask, collapse=True)
+        dense_data = dense_data.mask(node_mask, collapse=False)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
         pred = self.forward(noisy_data)
         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
         self.val_y_collection.append(data.y)
         self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
+        print(f'validation loss: {nll}, epoch: {self.current_epoch}')
         return {'loss': nll}

     def on_validation_epoch_end(self) -> None:
         metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
                    self.val_X_logp.compute(), self.val_E_logp.compute()]

+        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+            print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
+                  f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
+        with open("validation-metrics.csv", "a") as f:
+            # save the metrics as csv file
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
+
         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
               f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))

         # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
         self.log("val/NLL", metrics[0], sync_dist=True)
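The validation path above leans on `utils.to_dense` to pad the sparse PyG batch into fixed-size tensors before masking and noising. That helper is not shown in this diff; a rough equivalent built from torch_geometric's own utilities (an assumption about its behaviour, not the repo's implementation) would be:

    import torch
    from torch_geometric.utils import to_dense_batch, to_dense_adj

    def to_dense_sketch(x, edge_index, edge_attr, batch, max_num_nodes):
        # Pad node features to (B, N, dx) and return a boolean node mask of shape (B, N).
        X, node_mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes)
        # Scatter edge features into a dense (B, N, N, de) adjacency tensor.
        E = to_dense_adj(edge_index, batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
        return X, E, node_mask

Whether `collapse=True` argmaxes the one-hots back to class indices depends on the repo's mask implementation, which this hunk does not show; switching to `collapse=False` presumably keeps the full one-hot tensors for `apply_noise`.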
@@ -241,15 +255,15 @@ class Graph_DiT(pl.LightningModule):
             samples_left_to_generate -= to_generate
             chains_left_to_save -= chains_save

-        # print(f"Computing sampling metrics", ' ...')
-        # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-        # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+        print(f"Computing sampling metrics", ' ...')
+        valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')

-        current_path = os.getcwd()
-        result_path = os.path.join(current_path,
-                                   f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
+        # current_path = os.getcwd()
+        # result_path = os.path.join(current_path,
+        #                            f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
         # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-        # self.sampling_metrics.reset()
+        self.sampling_metrics.reset()

     def on_test_epoch_start(self) -> None:
         print("Starting test...")
@@ -262,8 +276,8 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
@@ -277,6 +291,8 @@ class Graph_DiT(pl.LightningModule):
         """ Measure likelihood on a test set and compute stability metrics. """
         metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
                    self.test_X_logp.compute(), self.test_E_logp.compute()]
+        with open("test-metrics.csv", "a") as f:
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")

         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
               f"Test Edge type KL: {metrics[2] :.2f}")
@@ -433,10 +449,12 @@ class Graph_DiT(pl.LightningModule):

         # Sample a timestep t.
         # When evaluating, the loss for t=0 is computed separately
+        # print(f"apply_noise X shape: {X.shape}, E shape: {E.shape}, y shape: {y.shape}, node_mask shape: {node_mask.shape}")
         lowest_t = 0 if self.training else 1
         t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
         s_int = t_int - 1
+

         t_float = t_int / self.T
         s_float = s_int / self.T
@@ -444,10 +462,23 @@ class Graph_DiT(pl.LightningModule):
         beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)
+        # print(f"alpha_s_bar: {alpha_s_bar.shape}, alpha_t_bar: {alpha_t_bar.shape}")

         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out)
+        # print(f"X shape: {X.shape}, E shape: {E.shape}, node_mask shape: {node_mask.shape}")
+        # print(f"Qtb shape: {Qtb.X.shape}")
+        """
+        X shape: torch.Size([1200, 8]),
+        E shape: torch.Size([1200, 8, 8]),
+        y shape: torch.Size([1200, 1]),
+        node_mask shape: torch.Size([1200, 8])
+        alpha_s_bar: torch.Size([1200, 1]), alpha_t_bar: torch.Size([1200, 1])
+        """

+        # print(X.shape)
         bs, n, d = X.shape
+        E = E[..., :2]
+        # bs, n = X.shape
         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
         prob_all = X_all @ Qtb.X
         probX = prob_all[:, :, :self.Xdim_output]
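For context on the matmul at the end of this hunk: the rows of X_all are concatenated one-hot blocks and Qtb.X is row-stochastic, so the product turns each one-hot block into a categorical distribution over its noised state at step t, from which probX then slices out the node part. A toy, self-contained sketch of that node-side step (the uniform transition matrix is an assumption; the repo builds Qtb from its own marginals):

    import torch

    bs, n, d = 2, 4, 8                                    # batch, nodes, one-hot width
    X = torch.eye(d)[torch.randint(d, (bs, n))]           # one-hot node states, (bs, n, d)
    alpha_bar = torch.full((bs, 1, 1), 0.7)               # probability mass kept at step t
    Qtb = alpha_bar * torch.eye(d) + (1 - alpha_bar) / d  # (bs, d, d), each row sums to 1
    probX = X @ Qtb                                       # (bs, n, d) distributions over noised types
    X_t = torch.distributions.Categorical(probs=probX).sample()  # (bs, n) sampled noisy types

The new `E = E[..., :2]` line appears to keep the node/edge concatenation consistent with the 2-class edge encoding introduced earlier in this commit.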
@@ -457,6 +488,7 @@ class Graph_DiT(pl.LightningModule):

         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
+        # print(f"X.shape: {X.shape}, X_t shape: {X_t.shape}, E.shape: {E.shape}, E_t shape: {E_t.shape}")
         assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

         y_t = y