update_name

.gitignore (vendored, new file, 161 lines added)
@@ -0,0 +1,161 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+
+.DS_Store
+.idea/
+__pycache__/
+dgd/configs/__pycache__/
+egnn/__pycache__/
+equivariant_diffusion/__pycache__/
+outputs/
+archives/qm9/__pycache__/
+archives/qm9/data_utils/__pycache__/
+archives/qm9/data_utils/prepare/__pycache__/
+archives/qm9/property_prediction/__pycache__/
+archives/*
+.env
+results/*.ckpt
+results/qm9_molecules_h
+results/qm9_molecules_noh
+dgd/analysis/orca/orca
+results/*
+ggg_data/
+ggg_utils/
+saved_models
+src/analysis/orca/orca
+src/analysis/orca/tmp_XMYAR426.txt
+
+
+# New
+archive.zip
+logs/
+generated/
+data/processed/

@@ -1,9 +1,9 @@
-Inverse Molecular Design with Multi-Conditional Diffusion Guidance
+Graph Diffusion Transformer for Multi-Conditional Molecular Generation
 ================================================================

 Paper: https://arxiv.org/abs/2401.13858

-This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small molecule and polymer designs and generations. The denoising model architecture in `mcd/models` looks like:
+This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like:

 <div style="display: flex;" markdown="1">
       <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image">
@@ -16,7 +16,7 @@ All dependencies are specified in the `requirements.txt` file.

 This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1.

-For molecular generation evaluation, we should first install rdkit:
+For molecular generation evaluation, we should first install rdkit.

 Then `fcd_torch`: `pip install fcd_torch` (https://github.com/insilicomedicine/fcd_torch).
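
A quick way to confirm both evaluation dependencies resolved correctly is a short import-and-run check. This is a minimal sketch, not part of the repo; the SMILES strings are arbitrary examples, and the `FCD` usage follows the fcd_torch repository linked above:

```python
# Sanity check for the evaluation dependencies (rdkit, fcd_torch).
from rdkit import Chem
from fcd_torch import FCD

mol = Chem.MolFromSmiles("CCO")      # parse ethanol
assert mol is not None
print(Chem.MolToSmiles(mol))         # canonical SMILES round-trip

fcd = FCD(device="cpu", n_jobs=1)    # Fréchet ChemNet Distance scorer
# Distance between two sets of SMILES; near zero for identical sets.
print(fcd(["CCO", "CCC", "CCN"], ["CCO", "CCC", "CCN"]))
```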
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| general: | general: | ||||||
|     name: 'MCD' |     name: 'graph_dit' | ||||||
|     wandb: 'disabled'  |     wandb: 'disabled'  | ||||||
|     gpus: 1 |     gpus: 1 | ||||||
|     resume: null |     resume: null | ||||||
| @@ -14,11 +14,11 @@ general: | |||||||
|     final_model_samples_to_save: 20 |     final_model_samples_to_save: 20 | ||||||
|     final_model_chains_to_save: 1 |     final_model_chains_to_save: 1 | ||||||
|     enable_progress_bar: False |     enable_progress_bar: False | ||||||
|     save_model: False |     save_model: True | ||||||
| model: | model: | ||||||
|     type: 'discrete' |     type: 'discrete' | ||||||
|     transition: 'marginal'                   |     transition: 'marginal'                   | ||||||
|     model: 'MCD' |     model: 'graph_dit' | ||||||
|     diffusion_steps: 500 |     diffusion_steps: 500 | ||||||
|     diffusion_noise_schedule: 'cosine' |     diffusion_noise_schedule: 'cosine' | ||||||
|     guide_scale: 2 |     guide_scale: 2 | ||||||
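
Once Hydra composes this YAML into a `DictConfig`, the renamed fields read back as plain attribute paths. A minimal sketch using OmegaConf directly (the repo itself loads the file through `@hydra.main`, as the main.py hunks below show):

```python
# Sketch: the renamed config fields as seen from code.
from omegaconf import OmegaConf

cfg = OmegaConf.create("""
general:
    name: 'graph_dit'
    save_model: True
model:
    type: 'discrete'
    model: 'graph_dit'
    diffusion_steps: 500
""")
assert cfg.general.name == "graph_dit"
assert cfg.general.save_model is True
assert cfg.model.diffusion_steps == 500
```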
|   | |||||||
| @@ -12,7 +12,7 @@ from metrics.train_loss import TrainLossDiscrete | |||||||
| from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL | from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL | ||||||
| import utils | import utils | ||||||
| 
 | 
 | ||||||
| class MCD(pl.LightningModule): | class Graph_DiT(pl.LightningModule): | ||||||
|     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): |     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) |         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) | ||||||
| @@ -174,7 +174,6 @@ class MCD(pl.LightningModule): | |||||||
|     def validation_step(self, data, i): |     def validation_step(self, data, i): | ||||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] |         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() |         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() | ||||||
| 
 |  | ||||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) |         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||||
|         dense_data = dense_data.mask(node_mask) |         dense_data = dense_data.mask(node_mask) | ||||||
|         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) |         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||||
| @@ -281,7 +280,6 @@ class MCD(pl.LightningModule): | |||||||
|         chains_left_to_save = self.cfg.general.final_model_chains_to_save |         chains_left_to_save = self.cfg.general.final_model_chains_to_save | ||||||
| 
 | 
 | ||||||
|         samples, all_ys, batch_id = [], [], 0 |         samples, all_ys, batch_id = [], [], 0 | ||||||
| 
 |  | ||||||
|         test_y_collection = torch.cat(self.test_y_collection, dim=0) |         test_y_collection = torch.cat(self.test_y_collection, dim=0) | ||||||
|         num_examples = test_y_collection.size(0) |         num_examples = test_y_collection.size(0) | ||||||
|         if self.cfg.general.final_model_samples_to_generate > num_examples: |         if self.cfg.general.final_model_samples_to_generate > num_examples: | ||||||
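
The `validation_step` featurization above is compact but worth unpacking: atom indices are one-hot encoded over all 118 elements and then sliced down to the atom types actually present in the dataset, while bond labels are one-hot encoded over 5 classes. A self-contained sketch of that step (the `active_index` values here are hypothetical, not from the repo):

```python
import torch
import torch.nn.functional as F

atom_idx = torch.tensor([6, 7, 8])            # e.g. C, N, O as atomic-number indices
active_index = torch.tensor([1, 6, 7, 8, 9])  # hypothetical: element types seen in the dataset
data_x = F.one_hot(atom_idx, num_classes=118).float()[:, active_index]
print(data_x.shape)                           # torch.Size([3, 5])

edge_attr = torch.tensor([0, 1, 2])           # per-edge bond-type labels
data_edge_attr = F.one_hot(edge_attr, num_classes=5).float()
print(data_edge_attr.shape)                   # torch.Size([3, 5])
```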
@@ -9,7 +9,7 @@ from pytorch_lightning import Trainer

 import utils
 from datasets import dataset
-from diffusion_model import MCD
+from diffusion_model import Graph_DiT
 from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
 from metrics.molecular_metrics_sampling import SamplingMolecularMetrics

@@ -36,7 +36,7 @@ def get_resume(cfg, model_kwargs):
     name = cfg.general.name + "_resume"
     resume = cfg.general.test_only
     batch_size = cfg.train.batch_size
-    model = MCD.load_from_checkpoint(resume, **model_kwargs)
+    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)
     cfg = model.cfg
     cfg.general.test_only = resume
     cfg.general.name = name
@@ -54,7 +54,7 @@ def get_resume_adaptive(cfg, model_kwargs):
     resume_path = os.path.join(root_dir, cfg.general.resume)

     if cfg.model.type == "discrete":
-        model = MCD.load_from_checkpoint(
+        model = Graph_DiT.load_from_checkpoint(
             resume_path, **model_kwargs
         )
     else:
@@ -73,7 +73,7 @@ def get_resume_adaptive(cfg, model_kwargs):


 @hydra.main(
-    version_base="1.1", config_path="../configs", config_name="config_dev"
+    version_base="1.1", config_path="../configs", config_name="config"
 )
 def main(cfg: DictConfig):

@@ -106,7 +106,7 @@ def main(cfg: DictConfig):
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])

-    model = Graph_DiT(cfg=cfg, **model_kwargs)
+    model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
         accelerator="gpu"
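
Anyone resuming older runs should note that checkpoints are now restored through the new class name. A minimal sketch of the test-only resume path (the checkpoint path is a placeholder, and `model_kwargs` stands in for the constructor arguments named in the `__init__` signature above):

```python
# Sketch of resuming after the rename; path and kwargs are placeholders.
from diffusion_model import Graph_DiT

ckpt_path = "outputs/some_run/checkpoints/last.ckpt"  # hypothetical path
model_kwargs = {}  # dataset_infos, train_metrics, sampling_metrics, visualization_tools
model = Graph_DiT.load_from_checkpoint(ckpt_path, **model_kwargs)
model.cfg.general.test_only = ckpt_path
```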
@@ -44,7 +44,7 @@ class Denoiser(nn.Module):
             ]
         )

-        self.decoder = Decoder(
+        self.out_layer = OutLayer(
             max_n_nodes=max_n_nodes,
             hidden_size=hidden_size,
             atom_type=Xdim,
@@ -73,7 +73,7 @@ class Denoiser(nn.Module):

         for block in self.encoders :
             _constant_init(block.adaLN_modulation[0], 0)
-        _constant_init(self.decoder.adaLN_modulation[0], 0)
+        _constant_init(self.out_layer.adaLN_modulation[0], 0)

     def forward(self, x, e, node_mask, y, t, unconditioned):

@@ -99,7 +99,7 @@ class Denoiser(nn.Module):
             x = block(x, c, node_mask)

         # X: B * N * dx, E: B * N * N * de
-        X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
+        X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
         return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)


@@ -140,8 +140,8 @@ class SELayer(nn.Module):
         return x


-class Decoder(nn.Module):
-    # Structure Decoder
+class OutLayer(nn.Module):
+    # Structure Output Layer
     def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
         super().__init__()
         self.atom_type = atom_type
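
The rename from `Decoder` to `OutLayer` keeps the DiT-style adaptive LayerNorm zero initialization: the modulation MLP's linear layer starts at zero, so each block initially contributes an identity-like update, which stabilizes early training. A self-contained sketch of the pattern (the layer sizes and the linear sitting at index 0, mirroring `adaLN_modulation[0]` above, are assumptions, not the repo's exact module):

```python
import torch.nn as nn

def _constant_init(module: nn.Module, value: float) -> None:
    # Zero out weight and bias so the modulation starts as a no-op.
    if getattr(module, "weight", None) is not None:
        nn.init.constant_(module.weight, value)
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, value)

hidden = 64  # illustrative size
adaLN_modulation = nn.Sequential(nn.Linear(hidden, 6 * hidden), nn.SiLU())
_constant_init(adaLN_modulation[0], 0)
assert adaLN_modulation[0].weight.abs().sum().item() == 0.0
```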