# These imports are tricky because they use c++, do not move them
import tqdm
import os, shutil
import warnings

import torch
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer

import utils
from datasets import dataset
from diffusion_model import Graph_DiT
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from metrics.molecular_metrics_sampling import SamplingGraphMetrics
from analysis.visualization import MolecularVisualization
from analysis.visualization import GraphVisualization

warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")


def remove_folder(folder):
    """Delete every file and subdirectory inside `folder`, keeping the folder itself."""
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print("Failed to delete %s. Reason: %s" % (file_path, e))


def get_resume(cfg, model_kwargs):
    """Resume a run: load the previous config without allowing key updates (used for testing)."""
    saved_cfg = cfg.copy()
    name = cfg.general.name + "_resume"
    resume = cfg.general.test_only
    batch_size = cfg.train.batch_size

    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)

    cfg = model.cfg
    cfg.general.test_only = resume
    cfg.general.name = name
    cfg.train.batch_size = batch_size
    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg, model


def get_resume_adaptive(cfg, model_kwargs):
    """Resume a run: load the previous config but allow overriding some keys (used for resuming training)."""
    saved_cfg = cfg.copy()
    # Fetch path to this file to get base path
    current_path = os.path.dirname(os.path.realpath(__file__))
    root_dir = current_path.split("outputs")[0]
    resume_path = os.path.join(root_dir, cfg.general.resume)

    if cfg.model.type == "discrete":
        model = Graph_DiT.load_from_checkpoint(resume_path, **model_kwargs)
    else:
        raise NotImplementedError("Unknown model")

    # Override the checkpoint config with the values passed on this run
    new_cfg = model.cfg
    for category in cfg:
        for arg in cfg[category]:
            new_cfg[category][arg] = cfg[category][arg]

    new_cfg.general.resume = resume_path
    new_cfg.general.name = new_cfg.general.name + "_resume"

    new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
    return new_cfg, model


@hydra.main(
    version_base="1.1", config_path="../configs", config_name="config"
)
def main(cfg: DictConfig):
    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_smiles, reference_smiles = datamodule.get_train_smiles()
    # train_graphs, reference_graphs = datamodule.get_train_graphs()

    # get input/output dimensions
    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
    # train_metrics = TrainGraphMetricsDiscrete(dataset_infos)

    sampling_metrics = SamplingMolecularMetrics(
        dataset_infos, train_smiles, reference_smiles
    )
    # sampling_metrics = SamplingGraphMetrics(
    #     dataset_infos, train_graphs, reference_graphs
    # )

    visualization_tools = MolecularVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        # "train_metrics": train_metrics,
        # "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    if cfg.general.test_only:
        # When testing, the previous configuration is fully loaded
        cfg, _ = get_resume(cfg, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        # When resuming, we can override some parts of the previous configuration
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    model = Graph_DiT(cfg=cfg, **model_kwargs)
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        # accelerator="gpu"
        # if torch.cuda.is_available() and cfg.general.gpus > 0
        # else "cpu",
        accelerator="cpu",
        devices=cfg.general.gpus
        if torch.cuda.is_available() and cfg.general.gpus > 0
        else None,
        max_epochs=cfg.train.n_epochs,
        enable_checkpointing=False,
        check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
        val_check_interval=cfg.train.val_check_interval,
        strategy="ddp" if cfg.general.gpus > 1 else "auto",
        enable_progress_bar=cfg.general.enable_progress_bar,
        callbacks=[],
        reload_dataloaders_every_n_epochs=0,
        logger=[],
    )

    if not cfg.general.test_only:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
        if cfg.general.save_model:
            trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
        trainer.test(model, datamodule=datamodule)
    else:
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)


from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration


@hydra.main(
    version_base="1.1", config_path="../configs", config_name="config"
)
def test(cfg: DictConfig):
    # Accelerate bookkeeping: project directory and automatic checkpoint naming
    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
        automatic_checkpoint_naming=True,
        total_limit=cfg.general.number_checkpoint_limit,
    )
    accelerator = Accelerator(
        mixed_precision="no",
        project_config=accelerator_config,
        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
    )
    set_seed(cfg.train.seed, device_specific=True)

    # Data, dataset statistics, metrics and visualization for general graphs
    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_graphs, reference_graphs = datamodule.get_train_graphs()

    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
    sampling_metrics = SamplingGraphMetrics(
        dataset_infos, train_graphs, reference_graphs
    )
    visualization_tools = GraphVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    if cfg.general.test_only:
        cfg, _ = get_resume(cfg, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
    model = Graph_DiT(cfg=cfg, **model_kwargs)
    graph_dit_model = model

    inference_dtype = torch.float32
    graph_dit_model.to(accelerator.device, dtype=inference_dtype)

    # optional: freeze the model
    # graph_dit_model.model.requires_grad_(True)

    import torch.nn.functional as F

    optimizer = graph_dit_model.configure_optimizers()
    train_dataloader = accelerator.prepare(datamodule.train_dataloader())
    optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)

    # start training
    for epoch in range(cfg.train.n_epochs):
        graph_dit_model.train()  # set the model to training mode
        print(f"Epoch {epoch}")
        for data in train_dataloader:  # fetch one batch from the dataloader
            data.to(accelerator.device)
            # One-hot encode node and edge categories, then densify the batch
            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
            dense_data, node_mask = utils.to_dense(
                data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes
            )
            dense_data = dense_data.mask(node_mask)
            X, E = dense_data.X, dense_data.E

            # Forward pass through the diffusion model and compute the training loss
            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
            pred = graph_dit_model.forward(noisy_data)
            loss = graph_dit_model.train_loss(
                masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
                true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
                log=(epoch % graph_dit_model.log_every_steps == 0),
            )
            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
            graph_dit_model.train_metrics(
                masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                log=(epoch % graph_dit_model.log_every_steps == 0),
            )
            graph_dit_model.log("loss", loss, batch_size=X.size(0), sync_dist=True)
            print(f"training loss: {loss}")
            with open("training-loss.csv", "a") as f:
                f.write(f"{loss.item()}, {epoch}\n")

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            # return {'loss': loss}

    # start sampling
    samples = []
    for i in tqdm.tqdm(
        range(cfg.general.n_samples), desc="Sampling", disable=not cfg.general.enable_progress_bar
    ):
        batch_size = cfg.train.batch_size
        num_steps = cfg.model.diffusion_steps
        y = torch.ones(batch_size, num_steps, 1, 1, device=accelerator.device, dtype=inference_dtype)

        # sample from the model
        samples_batch = graph_dit_model.sample_batch(
            batch_id=i,
            batch_size=batch_size,
            y=y,
            keep_chain=1,
            number_chain_steps=num_steps,
            save_final=batch_size,
        )
        samples.append(samples_batch)

    # report samples
    print("Samples:")
    print(samples)

    # trainer = Trainer(
    #     gradient_clip_val=cfg.train.clip_grad,
    #     # accelerator="cpu",
    #     accelerator="gpu"
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else "cpu",
    #     devices=[cfg.general.gpu_number]
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else None,
    #     max_epochs=cfg.train.n_epochs,
    #     enable_checkpointing=False,
    #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
    #     val_check_interval=cfg.train.val_check_interval,
    #     strategy="ddp" if cfg.general.gpus > 1 else "auto",
    #     enable_progress_bar=cfg.general.enable_progress_bar,
    #     callbacks=[],
    #     reload_dataloaders_every_n_epochs=0,
    #     logger=[],
    # )

    # if not cfg.general.test_only:
    #     print("start testing fit method")
    #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
    #     if cfg.general.save_model:
    #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
    #     trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    test()
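
# Usage sketch (an assumption, not part of the original script): launching this file
# directly runs `test()` above (the Accelerate loop), while `main()` is the
# PyTorch Lightning entry point. The override keys below are the ones referenced
# in this file and are resolved from ../configs/config.yaml; exact names and
# paths may differ in your setup.
#
#   python <this_script>.py train.n_epochs=10 general.name=debug_run
#   python <this_script>.py general.test_only=<run_dir>/checkpoints/last.ckpt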