write graph code for the absctract dataset

This commit is contained in:
mhz 2024-06-26 23:42:01 +02:00
parent 14186fa97f
commit a7f7010da7

View File

@ -127,4 +127,19 @@ class AbstractDatasetInfos:
print('input dims')
print(self.input_dims)
print('output dims')
print(self.output_dims)
def compute_graph_input_output_dims(self, datamodule):
example_batch = datamodule.example_batch()
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float()
self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}
self.output_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}
print('input dims')
print(self.input_dims)
print('output dims')
print(self.output_dims)