write test code
This commit is contained in:
		| @@ -80,14 +80,18 @@ def main(cfg: DictConfig): | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||
|     # train_smiles, reference_smiles = datamodule.get_train_smiles() | ||||
|     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||
|  | ||||
|     # get input output dimensions | ||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||
|     # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) | ||||
|     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) | ||||
|  | ||||
|     # sampling_metrics = SamplingMolecularMetrics( | ||||
|     #     dataset_infos, train_smiles, reference_smiles | ||||
|     # ) | ||||
|     sampling_metrics = SamplingGraphMetrics( | ||||
|         dataset_infos, train_graphs, reference_graphs | ||||
|     ) | ||||
|     visualization_tools = MolecularVisualization(dataset_infos) | ||||
|  | ||||
|     model_kwargs = { | ||||
| @@ -135,5 +139,16 @@ def main(cfg: DictConfig): | ||||
|     else: | ||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||
|  | ||||
| @hydra.main( | ||||
|     version_base="1.1", config_path="../configs", config_name="config" | ||||
| ) | ||||
| def test(cfg: DictConfig): | ||||
|     datamodule = dataset.DataModule(cfg) | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||
|     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||
|  | ||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|     test() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user