try update the api in DataInfo

This commit is contained in:
mhz 2024-06-26 22:10:07 +02:00
parent 0c7c525680
commit 062a27b83f

View File

@ -78,7 +78,7 @@ def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
# train_smiles, reference_smiles = datamodule.get_train_smiles()
# get input output dimensions