write test code
This commit is contained in:
parent
a222c514d9
commit
14186fa97f
@ -80,14 +80,18 @@ def main(cfg: DictConfig):
|
||||
datamodule.prepare_data()
|
||||
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
|
||||
# train_smiles, reference_smiles = datamodule.get_train_smiles()
|
||||
train_graphs, reference_graphs = datamodule.get_train_graphs()
|
||||
|
||||
# get input output dimensions
|
||||
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||
# train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||
|
||||
# sampling_metrics = SamplingMolecularMetrics(
|
||||
# dataset_infos, train_smiles, reference_smiles
|
||||
# )
|
||||
sampling_metrics = SamplingGraphMetrics(
|
||||
dataset_infos, train_graphs, reference_graphs
|
||||
)
|
||||
visualization_tools = MolecularVisualization(dataset_infos)
|
||||
|
||||
model_kwargs = {
|
||||
@ -135,5 +139,16 @@ def main(cfg: DictConfig):
|
||||
else:
|
||||
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
|
||||
|
||||
@hydra.main(
|
||||
version_base="1.1", config_path="../configs", config_name="config"
|
||||
)
|
||||
def test(cfg: DictConfig):
|
||||
datamodule = dataset.DataModule(cfg)
|
||||
datamodule.prepare_data()
|
||||
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
|
||||
train_graphs, reference_graphs = datamodule.get_train_graphs()
|
||||
|
||||
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
test()
|
||||
|
Loading…
Reference in New Issue
Block a user