write graph code for the absctract dataset
This commit is contained in:
		| @@ -127,4 +127,19 @@ class AbstractDatasetInfos: | |||||||
|         print('input dims') |         print('input dims') | ||||||
|         print(self.input_dims) |         print(self.input_dims) | ||||||
|         print('output dims') |         print('output dims') | ||||||
|  |         print(self.output_dims) | ||||||
|  |     def compute_graph_input_output_dims(self, datamodule): | ||||||
|  |         example_batch = datamodule.example_batch() | ||||||
|  |         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index] | ||||||
|  |         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() | ||||||
|  |  | ||||||
|  |         self.input_dims = {'X': example_batch_x.size(1), | ||||||
|  |                            'E': example_batch_edge_attr.size(1), | ||||||
|  |                            'y': example_batch['y'].size(1)} | ||||||
|  |         self.output_dims = {'X': example_batch_x.size(1), | ||||||
|  |                             'E': example_batch_edge_attr.size(1), | ||||||
|  |                             'y': example_batch['y'].size(1)} | ||||||
|  |         print('input dims') | ||||||
|  |         print(self.input_dims) | ||||||
|  |         print('output dims') | ||||||
|         print(self.output_dims) |         print(self.output_dims) | ||||||
		Reference in New Issue
	
	Block a user