write graph code for the absctract dataset
This commit is contained in:
		| @@ -118,6 +118,21 @@ class AbstractDatasetInfos: | |||||||
|         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] |         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] | ||||||
|         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() |         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() | ||||||
|  |  | ||||||
|  |         self.input_dims = {'X': example_batch_x.size(1),  | ||||||
|  |                            'E': example_batch_edge_attr.size(1),  | ||||||
|  |                            'y': example_batch['y'].size(1)} | ||||||
|  |         self.output_dims = {'X': example_batch_x.size(1), | ||||||
|  |                             'E': example_batch_edge_attr.size(1), | ||||||
|  |                             'y': example_batch['y'].size(1)} | ||||||
|  |         print('input dims') | ||||||
|  |         print(self.input_dims) | ||||||
|  |         print('output dims') | ||||||
|  |         print(self.output_dims) | ||||||
|  |     def compute_graph_input_output_dims(self, datamodule): | ||||||
|  |         example_batch = datamodule.example_batch() | ||||||
|  |         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index] | ||||||
|  |         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() | ||||||
|  |  | ||||||
|         self.input_dims = {'X': example_batch_x.size(1), |         self.input_dims = {'X': example_batch_x.size(1), | ||||||
|                            'E': example_batch_edge_attr.size(1), |                            'E': example_batch_edge_attr.size(1), | ||||||
|                            'y': example_batch['y'].size(1)} |                            'y': example_batch['y'].size(1)} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user