comment some output statements

This commit is contained in:
mhz 2024-07-01 10:03:40 +02:00
parent 572f030677
commit dd31fda8d5

View File

@ -46,13 +46,17 @@ def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
# print(f"to dense X: {x.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
# node_mask = node_mask.float()
edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
if max_num_nodes is None:
max_num_nodes = X.size(1)
# print(f"to dense X: {X.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
E = encode_no_edge(E)
# print(f"to dense X: {X.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
# print(f"to dense X: {X.shape}, E: {E.shape}, batch: {batch}, lenE: {len(E)}")
return PlaceHolder(X=X, E=E, y=None), node_mask
@ -119,6 +123,7 @@ class PlaceHolder:
x_mask = node_mask.unsqueeze(-1) # bs, n, 1
e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
# print(f"mask X: {self.X.shape}, E: {self.E.shape}, node_mask: {node_mask.shape}, x_mask: {x_mask.shape}, e_mask1: {e_mask1.shape}, e_mask2: {e_mask2.shape}")
if collapse:
self.X = torch.argmax(self.X, dim=-1)
@ -127,8 +132,13 @@ class PlaceHolder:
self.X[node_mask == 0] = - 1
self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
else:
# print(f"X: {self.X.shape}, E: {self.E.shape}")
# print(f"X: {self.X}, E: {self.E}")
# print(f"x_mask: {x_mask}, e_mask1: {e_mask1}, e_mask2: {e_mask2}")
self.X = self.X * x_mask
self.E = self.E * e_mask1 * e_mask2
# print(f"X: {self.X.shape}, E: {self.E.shape}")
# print(f"X: {self.X}, E: {self.E}")
assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
return self