From 2c00828630fce3202b3f1e482b53f0b3c7911b6c Mon Sep 17 00:00:00 2001 From: gang liu Date: Sat, 25 May 2024 15:32:36 -0400 Subject: [PATCH] update_name --- .gitignore | 161 ++++++++++++++++++ README.md | 6 +- configs/config.yaml | 6 +- {mcd => graph_dit}/__init__.py | 0 {mcd => graph_dit}/analysis/__init__.py | 0 .../analysis/rdkit_functions.py | 0 {mcd => graph_dit}/analysis/visualization.py | 0 {mcd => graph_dit}/datasets/__init__.py | 0 .../datasets/abstract_dataset.py | 0 {mcd => graph_dit}/datasets/dataset.py | 0 {mcd => graph_dit}/diffusion/__init__.py | 0 .../diffusion/diffusion_utils.py | 0 {mcd => graph_dit}/diffusion/distributions.py | 0 .../diffusion/noise_schedule.py | 0 {mcd => graph_dit}/diffusion_model.py | 4 +- {mcd => graph_dit}/main.py | 10 +- {mcd => graph_dit}/metrics/__init__.py | 0 .../metrics/abstract_metrics.py | 0 {mcd => graph_dit}/metrics/fpscores.pkl.gz | Bin .../metrics/molecular_metrics_sampling.py | 0 .../metrics/molecular_metrics_train.py | 0 {mcd => graph_dit}/metrics/property_metric.py | 0 {mcd => graph_dit}/metrics/train_loss.py | 0 {mcd => graph_dit}/models/__init__.py | 0 {mcd => graph_dit}/models/conditions.py | 0 {mcd => graph_dit}/models/layers.py | 0 {mcd => graph_dit}/models/transformer.py | 10 +- {mcd => graph_dit}/utils.py | 0 28 files changed, 178 insertions(+), 19 deletions(-) create mode 100644 .gitignore rename {mcd => graph_dit}/__init__.py (100%) rename {mcd => graph_dit}/analysis/__init__.py (100%) rename {mcd => graph_dit}/analysis/rdkit_functions.py (100%) rename {mcd => graph_dit}/analysis/visualization.py (100%) rename {mcd => graph_dit}/datasets/__init__.py (100%) rename {mcd => graph_dit}/datasets/abstract_dataset.py (100%) rename {mcd => graph_dit}/datasets/dataset.py (100%) rename {mcd => graph_dit}/diffusion/__init__.py (100%) rename {mcd => graph_dit}/diffusion/diffusion_utils.py (100%) rename {mcd => graph_dit}/diffusion/distributions.py (100%) rename {mcd => graph_dit}/diffusion/noise_schedule.py (100%) rename {mcd => graph_dit}/diffusion_model.py (97%) rename {mcd => graph_dit}/main.py (95%) rename {mcd => graph_dit}/metrics/__init__.py (100%) rename {mcd => graph_dit}/metrics/abstract_metrics.py (100%) rename {mcd => graph_dit}/metrics/fpscores.pkl.gz (100%) rename {mcd => graph_dit}/metrics/molecular_metrics_sampling.py (100%) rename {mcd => graph_dit}/metrics/molecular_metrics_train.py (100%) rename {mcd => graph_dit}/metrics/property_metric.py (100%) rename {mcd => graph_dit}/metrics/train_loss.py (100%) rename {mcd => graph_dit}/models/__init__.py (100%) rename {mcd => graph_dit}/models/conditions.py (100%) rename {mcd => graph_dit}/models/layers.py (100%) rename {mcd => graph_dit}/models/transformer.py (96%) rename {mcd => graph_dit}/utils.py (100%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..53db6cc --- /dev/null +++ b/.gitignore @@ -0,0 +1,161 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + + +.DS_Store +.idea/ +__pycache__/ +dgd/configs/__pycache__/ +egnn/__pycache__/ +equivariant_diffusion/__pycache__/ +outputs/ +archives/qm9/__pycache__/ +archives/qm9/data_utils/__pycache__/ +archives/qm9/data_utils/prepare/__pycache__/ +archives/qm9/property_prediction/__pycache__/ +archives/* +.env +results/*.ckpt +results/qm9_molecules_h +results/qm9_molecules_noh +dgd/analysis/orca/orca +results/* +ggg_data/ +ggg_utils/ +saved_models +src/analysis/orca/orca +src/analysis/orca/tmp_XMYAR426.txt + + +# New +archive.zip +logs/ +generated/ +data/processed/ \ No newline at end of file diff --git a/README.md b/README.md index c1394b5..913b559 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ -Inverse Molecular Design with Multi-Conditional Diffusion Guidance +Graph Diffusion Transformer for Multi-Conditional Molecular Generation ================================================================ Paper: https://arxiv.org/abs/2401.13858 -This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small molecule and polymer designs and generations. The denoising model architecture in `mcd/models` looks like: +This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like:
Description of the first image @@ -16,7 +16,7 @@ All dependencies are specified in the `requirements.txt` file. This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1. -For molecular generation evaluation, we should first install rdkit: +For molecular generation evaluation, we should first install rdkit. Then `fcd_torch`: `pip install fcd_torch` (https://github.com/insilicomedicine/fcd_torch). diff --git a/configs/config.yaml b/configs/config.yaml index 33dcb95..e3fade7 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,5 +1,5 @@ general: - name: 'MCD' + name: 'graph_dit' wandb: 'disabled' gpus: 1 resume: null @@ -14,11 +14,11 @@ general: final_model_samples_to_save: 20 final_model_chains_to_save: 1 enable_progress_bar: False - save_model: False + save_model: True model: type: 'discrete' transition: 'marginal' - model: 'MCD' + model: 'graph_dit' diffusion_steps: 500 diffusion_noise_schedule: 'cosine' guide_scale: 2 diff --git a/mcd/__init__.py b/graph_dit/__init__.py similarity index 100% rename from mcd/__init__.py rename to graph_dit/__init__.py diff --git a/mcd/analysis/__init__.py b/graph_dit/analysis/__init__.py similarity index 100% rename from mcd/analysis/__init__.py rename to graph_dit/analysis/__init__.py diff --git a/mcd/analysis/rdkit_functions.py b/graph_dit/analysis/rdkit_functions.py similarity index 100% rename from mcd/analysis/rdkit_functions.py rename to graph_dit/analysis/rdkit_functions.py diff --git a/mcd/analysis/visualization.py b/graph_dit/analysis/visualization.py similarity index 100% rename from mcd/analysis/visualization.py rename to graph_dit/analysis/visualization.py diff --git a/mcd/datasets/__init__.py b/graph_dit/datasets/__init__.py similarity index 100% rename from mcd/datasets/__init__.py rename to graph_dit/datasets/__init__.py diff --git a/mcd/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py similarity index 100% rename from mcd/datasets/abstract_dataset.py rename to graph_dit/datasets/abstract_dataset.py diff --git a/mcd/datasets/dataset.py b/graph_dit/datasets/dataset.py similarity index 100% rename from mcd/datasets/dataset.py rename to graph_dit/datasets/dataset.py diff --git a/mcd/diffusion/__init__.py b/graph_dit/diffusion/__init__.py similarity index 100% rename from mcd/diffusion/__init__.py rename to graph_dit/diffusion/__init__.py diff --git a/mcd/diffusion/diffusion_utils.py b/graph_dit/diffusion/diffusion_utils.py similarity index 100% rename from mcd/diffusion/diffusion_utils.py rename to graph_dit/diffusion/diffusion_utils.py diff --git a/mcd/diffusion/distributions.py b/graph_dit/diffusion/distributions.py similarity index 100% rename from mcd/diffusion/distributions.py rename to graph_dit/diffusion/distributions.py diff --git a/mcd/diffusion/noise_schedule.py b/graph_dit/diffusion/noise_schedule.py similarity index 100% rename from mcd/diffusion/noise_schedule.py rename to graph_dit/diffusion/noise_schedule.py diff --git a/mcd/diffusion_model.py b/graph_dit/diffusion_model.py similarity index 97% rename from mcd/diffusion_model.py rename to graph_dit/diffusion_model.py index 90d4e8b..4a0c9a6 100644 --- a/mcd/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -12,7 +12,7 @@ from metrics.train_loss import TrainLossDiscrete from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL import utils -class MCD(pl.LightningModule): +class Graph_DiT(pl.LightningModule): def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): super().__init__() self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) @@ -174,7 +174,6 @@ class MCD(pl.LightningModule): def validation_step(self, data, i): data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() - dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) dense_data = dense_data.mask(node_mask) noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) @@ -281,7 +280,6 @@ class MCD(pl.LightningModule): chains_left_to_save = self.cfg.general.final_model_chains_to_save samples, all_ys, batch_id = [], [], 0 - test_y_collection = torch.cat(self.test_y_collection, dim=0) num_examples = test_y_collection.size(0) if self.cfg.general.final_model_samples_to_generate > num_examples: diff --git a/mcd/main.py b/graph_dit/main.py similarity index 95% rename from mcd/main.py rename to graph_dit/main.py index 433377d..fb4f4ea 100644 --- a/mcd/main.py +++ b/graph_dit/main.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer import utils from datasets import dataset -from diffusion_model import MCD +from diffusion_model import Graph_DiT from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete from metrics.molecular_metrics_sampling import SamplingMolecularMetrics @@ -36,7 +36,7 @@ def get_resume(cfg, model_kwargs): name = cfg.general.name + "_resume" resume = cfg.general.test_only batch_size = cfg.train.batch_size - model = MCD.load_from_checkpoint(resume, **model_kwargs) + model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs) cfg = model.cfg cfg.general.test_only = resume cfg.general.name = name @@ -54,7 +54,7 @@ def get_resume_adaptive(cfg, model_kwargs): resume_path = os.path.join(root_dir, cfg.general.resume) if cfg.model.type == "discrete": - model = MCD.load_from_checkpoint( + model = Graph_DiT.load_from_checkpoint( resume_path, **model_kwargs ) else: @@ -73,7 +73,7 @@ def get_resume_adaptive(cfg, model_kwargs): @hydra.main( - version_base="1.1", config_path="../configs", config_name="config_dev" + version_base="1.1", config_path="../configs", config_name="config" ) def main(cfg: DictConfig): @@ -106,7 +106,7 @@ def main(cfg: DictConfig): cfg, _ = get_resume_adaptive(cfg, model_kwargs) os.chdir(cfg.general.resume.split("checkpoints")[0]) - model = MCD(cfg=cfg, **model_kwargs) + model = Graph_DiT(cfg=cfg, **model_kwargs) trainer = Trainer( gradient_clip_val=cfg.train.clip_grad, accelerator="gpu" diff --git a/mcd/metrics/__init__.py b/graph_dit/metrics/__init__.py similarity index 100% rename from mcd/metrics/__init__.py rename to graph_dit/metrics/__init__.py diff --git a/mcd/metrics/abstract_metrics.py b/graph_dit/metrics/abstract_metrics.py similarity index 100% rename from mcd/metrics/abstract_metrics.py rename to graph_dit/metrics/abstract_metrics.py diff --git a/mcd/metrics/fpscores.pkl.gz b/graph_dit/metrics/fpscores.pkl.gz similarity index 100% rename from mcd/metrics/fpscores.pkl.gz rename to graph_dit/metrics/fpscores.pkl.gz diff --git a/mcd/metrics/molecular_metrics_sampling.py b/graph_dit/metrics/molecular_metrics_sampling.py similarity index 100% rename from mcd/metrics/molecular_metrics_sampling.py rename to graph_dit/metrics/molecular_metrics_sampling.py diff --git a/mcd/metrics/molecular_metrics_train.py b/graph_dit/metrics/molecular_metrics_train.py similarity index 100% rename from mcd/metrics/molecular_metrics_train.py rename to graph_dit/metrics/molecular_metrics_train.py diff --git a/mcd/metrics/property_metric.py b/graph_dit/metrics/property_metric.py similarity index 100% rename from mcd/metrics/property_metric.py rename to graph_dit/metrics/property_metric.py diff --git a/mcd/metrics/train_loss.py b/graph_dit/metrics/train_loss.py similarity index 100% rename from mcd/metrics/train_loss.py rename to graph_dit/metrics/train_loss.py diff --git a/mcd/models/__init__.py b/graph_dit/models/__init__.py similarity index 100% rename from mcd/models/__init__.py rename to graph_dit/models/__init__.py diff --git a/mcd/models/conditions.py b/graph_dit/models/conditions.py similarity index 100% rename from mcd/models/conditions.py rename to graph_dit/models/conditions.py diff --git a/mcd/models/layers.py b/graph_dit/models/layers.py similarity index 100% rename from mcd/models/layers.py rename to graph_dit/models/layers.py diff --git a/mcd/models/transformer.py b/graph_dit/models/transformer.py similarity index 96% rename from mcd/models/transformer.py rename to graph_dit/models/transformer.py index e9e95d1..a568191 100644 --- a/mcd/models/transformer.py +++ b/graph_dit/models/transformer.py @@ -44,7 +44,7 @@ class Denoiser(nn.Module): ] ) - self.decoder = Decoder( + self.out_layer = OutLayer( max_n_nodes=max_n_nodes, hidden_size=hidden_size, atom_type=Xdim, @@ -73,7 +73,7 @@ class Denoiser(nn.Module): for block in self.encoders : _constant_init(block.adaLN_modulation[0], 0) - _constant_init(self.decoder.adaLN_modulation[0], 0) + _constant_init(self.out_layer.adaLN_modulation[0], 0) def forward(self, x, e, node_mask, y, t, unconditioned): @@ -99,7 +99,7 @@ class Denoiser(nn.Module): x = block(x, c, node_mask) # X: B * N * dx, E: B * N * N * de - X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask) + X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask) return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask) @@ -140,8 +140,8 @@ class SELayer(nn.Module): return x -class Decoder(nn.Module): - # Structure Decoder +class OutLayer(nn.Module): + # Structure Output Layer def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None): super().__init__() self.atom_type = atom_type diff --git a/mcd/utils.py b/graph_dit/utils.py similarity index 100% rename from mcd/utils.py rename to graph_dit/utils.py