write graph code for the absctract dataset
This commit is contained in:
		| @@ -127,4 +127,19 @@ class AbstractDatasetInfos: | ||||
|         print('input dims') | ||||
|         print(self.input_dims) | ||||
|         print('output dims') | ||||
|         print(self.output_dims) | ||||
|     def compute_graph_input_output_dims(self, datamodule): | ||||
|         example_batch = datamodule.example_batch() | ||||
|         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index] | ||||
|         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() | ||||
|  | ||||
|         self.input_dims = {'X': example_batch_x.size(1), | ||||
|                            'E': example_batch_edge_attr.size(1), | ||||
|                            'y': example_batch['y'].size(1)} | ||||
|         self.output_dims = {'X': example_batch_x.size(1), | ||||
|                             'E': example_batch_edge_attr.size(1), | ||||
|                             'y': example_batch['y'].size(1)} | ||||
|         print('input dims') | ||||
|         print(self.input_dims) | ||||
|         print('output dims') | ||||
|         print(self.output_dims) | ||||
		Reference in New Issue
	
	Block a user