update the main function

This commit is contained in:
mhz 2024-07-01 10:02:51 +02:00
parent f5911be781
commit ba008ae54c

View File

@ -11,9 +11,13 @@ import utils
from datasets import dataset
from diffusion_model import Graph_DiT
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from metrics.molecular_metrics_sampling import SamplingGraphMetrics
from analysis.visualization import MolecularVisualization
from analysis.visualization import GraphVisualization
warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")
@ -79,19 +83,20 @@ def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
# train_smiles, reference_smiles = datamodule.get_train_smiles()
train_graphs, reference_graphs = datamodule.get_train_graphs()
train_smiles, reference_smiles = datamodule.get_train_smiles()
# train_graphs, reference_graphs = datamodule.get_train_graphs()
# get input output dimensions
dataset_infos.compute_input_output_dims(datamodule=datamodule)
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
# train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
# sampling_metrics = SamplingMolecularMetrics(
# dataset_infos, train_smiles, reference_smiles
# )
sampling_metrics = SamplingGraphMetrics(
dataset_infos, train_graphs, reference_graphs
sampling_metrics = SamplingMolecularMetrics(
dataset_infos, train_smiles, reference_smiles
)
# sampling_metrics = SamplingGraphMetrics(
# dataset_infos, train_graphs, reference_graphs
# )
visualization_tools = MolecularVisualization(dataset_infos)
model_kwargs = {
@ -149,6 +154,54 @@ def test(cfg: DictConfig):
train_graphs, reference_graphs = datamodule.get_train_graphs()
dataset_infos.compute_input_output_dims(datamodule=datamodule)
train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
sampling_metrics = SamplingGraphMetrics(
dataset_infos, train_graphs, reference_graphs
)
visulization_tools = GraphVisualization(dataset_infos)
model_kwargs = {
"dataset_infos": dataset_infos,
"train_metrics": train_metrics,
"sampling_metrics": sampling_metrics,
"visualization_tools": visulization_tools,
}
if cfg.general.test_only:
cfg, _ = get_resume(cfg, model_kwargs)
os.chdir(cfg.general.test_only.split("checkpoints")[0])
elif cfg.general.resume is not None:
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
os.chdir(cfg.general.resume.split("checkpoints")[0])
model = Graph_DiT(cfg=cfg, **model_kwargs)
trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad,
# accelerator="cpu",
accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu",
devices=cfg.general.gpus
if torch.cuda.is_available() and cfg.general.gpus > 0
else None,
max_epochs=cfg.train.n_epochs,
enable_checkpointing=False,
check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
val_check_interval=cfg.train.val_check_interval,
strategy="ddp" if cfg.general.gpus > 1 else "auto",
enable_progress_bar=cfg.general.enable_progress_bar,
callbacks=[],
reload_dataloaders_every_n_epochs=0,
logger=[],
)
if not cfg.general.test_only:
print("start testing fit method")
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
if cfg.general.save_model:
trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
trainer.test(model, datamodule=datamodule)
# Script entry point: run `test` only when this file is executed directly,
# not when it is imported as a module.
# NOTE(review): `test` is declared as `test(cfg: DictConfig)` but is called
# here with no arguments; presumably a decorator outside this view (e.g. a
# Hydra entry-point decorator) supplies `cfg` — confirm against the full file.
if __name__ == "__main__":
    test()