some onehot issue
This commit is contained in:
parent
be8bb16f61
commit
f5911be781
@ -116,7 +116,7 @@ class AbstractDatasetInfos:
|
||||
def compute_input_output_dims(self, datamodule):
|
||||
example_batch = datamodule.example_batch()
|
||||
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float()
|
||||
|
||||
self.input_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
|
Loading…
Reference in New Issue
Block a user