Compare commits
4 Commits
63ca6c716e
...
trainer
Author | SHA1 | Date | |
---|---|---|---|
|
123cde9313 | ||
|
9360839a35 | ||
f75657ac3b | |||
be178bc5ee |
@@ -2,20 +2,26 @@ general:
|
|||||||
name: 'graph_dit'
|
name: 'graph_dit'
|
||||||
wandb: 'disabled'
|
wandb: 'disabled'
|
||||||
gpus: 1
|
gpus: 1
|
||||||
gpu_number: 2
|
gpu_number: 0
|
||||||
resume: null
|
resume: null
|
||||||
test_only: null
|
test_only: null
|
||||||
sample_every_val: 2500
|
sample_every_val: 2500
|
||||||
samples_to_generate: 512
|
samples_to_generate: 1000
|
||||||
samples_to_save: 3
|
samples_to_save: 3
|
||||||
chains_to_save: 1
|
chains_to_save: 1
|
||||||
log_every_steps: 50
|
log_every_steps: 50
|
||||||
number_chain_steps: 8
|
number_chain_steps: 8
|
||||||
final_model_samples_to_generate: 100
|
final_model_samples_to_generate: 1000
|
||||||
final_model_samples_to_save: 20
|
final_model_samples_to_save: 20
|
||||||
final_model_chains_to_save: 1
|
final_model_chains_to_save: 1
|
||||||
enable_progress_bar: False
|
enable_progress_bar: False
|
||||||
save_model: True
|
save_model: True
|
||||||
|
log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
|
||||||
|
number_checkpoint_limit: 3
|
||||||
|
type: 'Trainer'
|
||||||
|
nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
|
||||||
|
root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
|
||||||
model:
|
model:
|
||||||
type: 'discrete'
|
type: 'discrete'
|
||||||
transition: 'marginal'
|
transition: 'marginal'
|
||||||
@@ -32,7 +38,7 @@ model:
|
|||||||
ensure_connected: True
|
ensure_connected: True
|
||||||
train:
|
train:
|
||||||
# n_epochs: 5000
|
# n_epochs: 5000
|
||||||
n_epochs: 500
|
n_epochs: 10
|
||||||
batch_size: 1200
|
batch_size: 1200
|
||||||
lr: 0.0002
|
lr: 0.0002
|
||||||
clip_grad: null
|
clip_grad: null
|
||||||
@@ -41,8 +47,11 @@ train:
|
|||||||
seed: 0
|
seed: 0
|
||||||
val_check_interval: null
|
val_check_interval: null
|
||||||
check_val_every_n_epoch: 1
|
check_val_every_n_epoch: 1
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
dataset:
|
dataset:
|
||||||
datadir: 'data/'
|
datadir: 'data/'
|
||||||
task_name: 'nasbench-201'
|
task_name: 'nasbench-201'
|
||||||
guidance_target: 'nasbench-201'
|
guidance_target: 'nasbench-201'
|
||||||
pin_memory: False
|
pin_memory: False
|
||||||
|
ppo:
|
||||||
|
clip_param: 1
|
||||||
|
228
environment.yaml
Normal file
228
environment.yaml
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
name: graphdit
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=conda_forge
|
||||||
|
- _openmp_mutex=4.5=2_gnu
|
||||||
|
- asttokens=2.4.1=pyhd8ed1ab_0
|
||||||
|
- blas=1.0=mkl
|
||||||
|
- brotli-python=1.0.9=py39h6a678d5_8
|
||||||
|
- bzip2=1.0.8=h5eee18b_6
|
||||||
|
- ca-certificates=2024.7.2=h06a4308_0
|
||||||
|
- comm=0.2.2=pyhd8ed1ab_0
|
||||||
|
- debugpy=1.6.7=py39h6a678d5_0
|
||||||
|
- decorator=5.1.1=pyhd8ed1ab_0
|
||||||
|
- exceptiongroup=1.2.0=pyhd8ed1ab_2
|
||||||
|
- executing=2.0.1=pyhd8ed1ab_0
|
||||||
|
- ffmpeg=4.3=hf484d3e_0
|
||||||
|
- freetype=2.12.1=h4a9f257_0
|
||||||
|
- gmp=6.2.1=h295c915_3
|
||||||
|
- gmpy2=2.1.2=py39heeb90bb_0
|
||||||
|
- gnutls=3.6.15=he1e5248_0
|
||||||
|
- idna=3.7=py39h06a4308_0
|
||||||
|
- importlib-metadata=7.1.0=pyha770c72_0
|
||||||
|
- importlib_metadata=7.1.0=hd8ed1ab_0
|
||||||
|
- intel-openmp=2023.1.0=hdb19cb5_46306
|
||||||
|
- ipykernel=6.29.4=pyh3099207_0
|
||||||
|
- ipython=8.18.1=pyh707e725_3
|
||||||
|
- jedi=0.19.1=pyhd8ed1ab_0
|
||||||
|
- jinja2=3.1.4=py39h06a4308_0
|
||||||
|
- jpeg=9e=h5eee18b_1
|
||||||
|
- jupyter_client=8.6.2=pyhd8ed1ab_0
|
||||||
|
- jupyter_core=5.7.2=py39hf3d152e_0
|
||||||
|
- lame=3.100=h7b6447c_0
|
||||||
|
- lcms2=2.12=h3be6417_0
|
||||||
|
- ld_impl_linux-64=2.38=h1181459_1
|
||||||
|
- lerc=3.0=h295c915_0
|
||||||
|
- libdeflate=1.17=h5eee18b_1
|
||||||
|
- libffi=3.4.4=h6a678d5_1
|
||||||
|
- libgcc-ng=13.2.0=h77fa898_7
|
||||||
|
- libgomp=13.2.0=h77fa898_7
|
||||||
|
- libiconv=1.16=h5eee18b_3
|
||||||
|
- libidn2=2.3.4=h5eee18b_0
|
||||||
|
- libpng=1.6.39=h5eee18b_0
|
||||||
|
- libsodium=1.0.18=h36c2ea0_1
|
||||||
|
- libstdcxx-ng=11.2.0=h1234567_1
|
||||||
|
- libtasn1=4.19.0=h5eee18b_0
|
||||||
|
- libtiff=4.5.1=h6a678d5_0
|
||||||
|
- libunistring=0.9.10=h27cfd23_0
|
||||||
|
- libwebp-base=1.3.2=h5eee18b_0
|
||||||
|
- lz4-c=1.9.4=h6a678d5_1
|
||||||
|
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
|
||||||
|
- mkl=2023.1.0=h213fc3f_46344
|
||||||
|
- mkl-service=2.4.0=py39h5eee18b_1
|
||||||
|
- mkl_fft=1.3.8=py39h5eee18b_0
|
||||||
|
- mkl_random=1.2.4=py39hdb19cb5_0
|
||||||
|
- mpc=1.1.0=h10f8cd9_1
|
||||||
|
- mpfr=4.0.2=hb69a4c5_1
|
||||||
|
- mpmath=1.3.0=py39h06a4308_0
|
||||||
|
- ncurses=6.4=h6a678d5_0
|
||||||
|
- nest-asyncio=1.6.0=pyhd8ed1ab_0
|
||||||
|
- nettle=3.7.3=hbbd107a_1
|
||||||
|
- numpy-base=1.26.4=py39hb5e798b_0
|
||||||
|
- openh264=2.1.1=h4ff587b_0
|
||||||
|
- openjpeg=2.4.0=h9ca470c_2
|
||||||
|
- openssl=3.3.1=h4ab18f5_0
|
||||||
|
- packaging=24.0=pyhd8ed1ab_0
|
||||||
|
- parso=0.8.4=pyhd8ed1ab_0
|
||||||
|
- pexpect=4.9.0=pyhd8ed1ab_0
|
||||||
|
- pickleshare=0.7.5=py_1003
|
||||||
|
- pip=24.0=py39h06a4308_0
|
||||||
|
- platformdirs=4.2.2=pyhd8ed1ab_0
|
||||||
|
- prompt-toolkit=3.0.46=pyha770c72_0
|
||||||
|
- psutil=5.9.8=py39hd1e30aa_0
|
||||||
|
- ptyprocess=0.7.0=pyhd3deb0d_0
|
||||||
|
- pure_eval=0.2.2=pyhd8ed1ab_0
|
||||||
|
- pygments=2.18.0=pyhd8ed1ab_0
|
||||||
|
- pysocks=1.7.1=py39h06a4308_0
|
||||||
|
- python=3.9.19=h955ad1f_1
|
||||||
|
- python_abi=3.9=2_cp39
|
||||||
|
- pytorch-mutex=1.0=cpu
|
||||||
|
- pyzmq=25.1.2=py39h6a678d5_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- setuptools=69.5.1=py39h06a4308_0
|
||||||
|
- six=1.16.0=pyh6c4a22f_0
|
||||||
|
- sqlite=3.45.3=h5eee18b_0
|
||||||
|
- stack_data=0.6.2=pyhd8ed1ab_0
|
||||||
|
- sympy=1.12=py39h06a4308_0
|
||||||
|
- tbb=2021.8.0=hdb19cb5_0
|
||||||
|
- tk=8.6.14=h39e8969_0
|
||||||
|
- tornado=6.4.1=py39hd3abc70_0
|
||||||
|
- traitlets=5.14.3=pyhd8ed1ab_0
|
||||||
|
- typing_extensions=4.12.2=pyha770c72_0
|
||||||
|
- wcwidth=0.2.13=pyhd8ed1ab_0
|
||||||
|
- wheel=0.43.0=py39h06a4308_0
|
||||||
|
- xz=5.4.6=h5eee18b_1
|
||||||
|
- zeromq=4.3.5=h6a678d5_0
|
||||||
|
- zlib=1.2.13=h5eee18b_1
|
||||||
|
- zstd=1.5.5=hc292b87_2
|
||||||
|
- pip:
|
||||||
|
- absl-py==2.1.0
|
||||||
|
- accelerate==0.34.2
|
||||||
|
- aiohttp==3.9.5
|
||||||
|
- aiosignal==1.3.1
|
||||||
|
- antlr4-python3-runtime==4.9.3
|
||||||
|
- astunparse==1.6.3
|
||||||
|
- async-timeout==4.0.3
|
||||||
|
- attrs==23.2.0
|
||||||
|
- beautifulsoup4==4.12.3
|
||||||
|
- bleach==6.1.0
|
||||||
|
- certifi==2024.2.2
|
||||||
|
- charset-normalizer==3.1.0
|
||||||
|
- cmake==3.29.3
|
||||||
|
- contourpy==1.2.1
|
||||||
|
- cycler==0.12.1
|
||||||
|
- defusedxml==0.7.1
|
||||||
|
- fastjsonschema==2.19.1
|
||||||
|
- fcd-torch==1.0.7
|
||||||
|
- filelock==3.14.0
|
||||||
|
- flatbuffers==24.3.25
|
||||||
|
- fonttools==4.52.4
|
||||||
|
- frozenlist==1.4.1
|
||||||
|
- fsspec==2024.5.0
|
||||||
|
- gast==0.5.4
|
||||||
|
- google-pasta==0.2.0
|
||||||
|
- grpcio==1.64.1
|
||||||
|
- h5py==3.11.0
|
||||||
|
- huggingface-hub==0.24.6
|
||||||
|
- hydra-core==1.3.2
|
||||||
|
- imageio==2.26.0
|
||||||
|
- importlib-resources==6.4.0
|
||||||
|
- joblib==1.2.0
|
||||||
|
- jsonschema==4.22.0
|
||||||
|
- jsonschema-specifications==2023.12.1
|
||||||
|
- jupyterlab-pygments==0.3.0
|
||||||
|
- keras==3.3.3
|
||||||
|
- kiwisolver==1.4.5
|
||||||
|
- libclang==18.1.1
|
||||||
|
- lightning-utilities==0.11.2
|
||||||
|
- lit==18.1.6
|
||||||
|
- markdown==3.6
|
||||||
|
- markdown-it-py==3.0.0
|
||||||
|
- markupsafe==2.1.5
|
||||||
|
- matplotlib==3.7.0
|
||||||
|
- mdurl==0.1.2
|
||||||
|
- mini-moses==1.0
|
||||||
|
- mistune==3.0.2
|
||||||
|
- ml-dtypes==0.3.2
|
||||||
|
- multidict==6.0.5
|
||||||
|
- namex==0.0.8
|
||||||
|
- nas-bench-201==2.1
|
||||||
|
- nasbench==1.0
|
||||||
|
- nbclient==0.10.0
|
||||||
|
- nbconvert==7.16.4
|
||||||
|
- nbformat==5.10.4
|
||||||
|
- networkx==3.0
|
||||||
|
- numpy==1.24.2
|
||||||
|
- nvidia-cublas-cu11==11.10.3.66
|
||||||
|
- nvidia-cublas-cu12==12.1.3.1
|
||||||
|
- nvidia-cuda-cupti-cu11==11.7.101
|
||||||
|
- nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
- nvidia-cuda-nvrtc-cu11==11.7.99
|
||||||
|
- nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
- nvidia-cuda-runtime-cu11==11.7.99
|
||||||
|
- nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
- nvidia-cudnn-cu11==8.5.0.96
|
||||||
|
- nvidia-cudnn-cu12==8.9.2.26
|
||||||
|
- nvidia-cufft-cu11==10.9.0.58
|
||||||
|
- nvidia-cufft-cu12==11.0.2.54
|
||||||
|
- nvidia-curand-cu11==10.2.10.91
|
||||||
|
- nvidia-curand-cu12==10.3.2.106
|
||||||
|
- nvidia-cusolver-cu11==11.4.0.1
|
||||||
|
- nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
- nvidia-cusparse-cu11==11.7.4.91
|
||||||
|
- nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
- nvidia-nccl-cu11==2.14.3
|
||||||
|
- nvidia-nccl-cu12==2.20.5
|
||||||
|
- nvidia-nvjitlink-cu12==12.5.40
|
||||||
|
- nvidia-nvtx-cu11==11.7.91
|
||||||
|
- nvidia-nvtx-cu12==12.1.105
|
||||||
|
- omegaconf==2.3.0
|
||||||
|
- opt-einsum==3.3.0
|
||||||
|
- optree==0.11.0
|
||||||
|
- pandas==1.5.3
|
||||||
|
- pandocfilters==1.5.1
|
||||||
|
- pillow==10.3.0
|
||||||
|
- protobuf==3.20.3
|
||||||
|
- pyparsing==3.1.2
|
||||||
|
- python-dateutil==2.9.0.post0
|
||||||
|
- pytorch-lightning==2.0.1
|
||||||
|
- pytz==2024.1
|
||||||
|
- pyyaml==6.0.1
|
||||||
|
- rdkit==2023.9.4
|
||||||
|
- referencing==0.35.1
|
||||||
|
- requests==2.32.2
|
||||||
|
- rich==13.7.1
|
||||||
|
- rpds-py==0.18.1
|
||||||
|
- safetensors==0.4.5
|
||||||
|
- scikit-learn==1.2.1
|
||||||
|
- scipy==1.13.1
|
||||||
|
- seaborn==0.13.2
|
||||||
|
- simplejson==3.19.2
|
||||||
|
- soupsieve==2.5
|
||||||
|
- tensorboard==2.16.2
|
||||||
|
- tensorboard-data-server==0.7.2
|
||||||
|
- tensorflow==2.16.1
|
||||||
|
- tensorflow-io-gcs-filesystem==0.37.0
|
||||||
|
- termcolor==2.4.0
|
||||||
|
- threadpoolctl==3.5.0
|
||||||
|
- tinycss2==1.3.0
|
||||||
|
- torch==2.0.0
|
||||||
|
- torch-geometric==2.3.0
|
||||||
|
- torchaudio==2.0.1+rocm5.4.2
|
||||||
|
- torchmetrics==0.11.4
|
||||||
|
- torchvision==0.15.1
|
||||||
|
- tqdm==4.64.1
|
||||||
|
- triton==2.0.0
|
||||||
|
- typing-extensions==4.12.0
|
||||||
|
- tzdata==2024.1
|
||||||
|
- urllib3==2.2.1
|
||||||
|
- webencodings==0.5.1
|
||||||
|
- werkzeug==3.0.3
|
||||||
|
- wrapt==1.16.0
|
||||||
|
- yacs==0.1.8
|
||||||
|
- yarl==1.9.4
|
||||||
|
- zipp==3.19.0
|
||||||
|
prefix: /home/stud/hanzhang/anaconda3/envs/graphdit
|
@@ -54,7 +54,9 @@ class BasicGraphMetrics(object):
|
|||||||
covered_nodes = set()
|
covered_nodes = set()
|
||||||
direct_valid_count = 0
|
direct_valid_count = 0
|
||||||
print(f"generated number: {len(generated)}")
|
print(f"generated number: {len(generated)}")
|
||||||
|
print(f"generated: {generated}")
|
||||||
for graph in generated:
|
for graph in generated:
|
||||||
|
print(f"graph: {graph}")
|
||||||
node_types, edge_types = graph
|
node_types, edge_types = graph
|
||||||
direct_valid_flag = True
|
direct_valid_flag = True
|
||||||
direct_valid_count += 1
|
direct_valid_count += 1
|
||||||
|
@@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
|
|||||||
import utils as utils
|
import utils as utils
|
||||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||||
from diffusion.distributions import DistributionNodes
|
from diffusion.distributions import DistributionNodes
|
||||||
from naswot.score_networks import get_nasbench201_idx_score
|
|
||||||
from naswot import nasspace
|
from naswot import nasspace
|
||||||
from naswot import datasets as dt
|
from naswot import datasets as dt
|
||||||
|
|
||||||
@@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
|
|||||||
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||||
# except NameError:
|
# except NameError:
|
||||||
# base_path = pathlib.Path(os.getcwd()).parent[2]
|
# base_path = pathlib.Path(os.getcwd()).parent[2]
|
||||||
base_path = '/nfs/data3/hanzhang/nasbenchDiT'
|
# base_path = '/nfs/data3/hanzhang/nasbenchDiT'
|
||||||
|
base_path = os.path.join(self.cfg.general.root, "..")
|
||||||
|
|
||||||
root_path = os.path.join(base_path, self.datadir)
|
root_path = os.path.join(base_path, self.datadir)
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
|
|||||||
# Load the dataset to the memory
|
# Load the dataset to the memory
|
||||||
# Dataset has target property, root path, and transform
|
# Dataset has target property, root path, and transform
|
||||||
source = './NAS-Bench-201-v1_1-096897.pth'
|
source = './NAS-Bench-201-v1_1-096897.pth'
|
||||||
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
|
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
# self.api = dataset.api
|
# self.api = dataset.api
|
||||||
|
|
||||||
@@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
|
|||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return self.test_loader
|
return self.test_loader
|
||||||
|
|
||||||
def new_graphs_to_json(graphs, filename):
|
def new_graphs_to_json(graphs, filename, cfg):
|
||||||
source_name = "nasbench-201"
|
source_name = "nasbench-201"
|
||||||
num_graph = len(graphs)
|
num_graph = len(graphs)
|
||||||
|
|
||||||
@@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
'num_active_nodes': len(active_nodes),
|
'num_active_nodes': len(active_nodes),
|
||||||
'transition_E': transition_E.tolist(),
|
'transition_E': transition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
import os
|
||||||
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
||||||
|
with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
@@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
|
|||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
return meta_dict
|
return meta_dict
|
||||||
class Dataset(InMemoryDataset):
|
class Dataset(InMemoryDataset):
|
||||||
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
|
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
|
||||||
self.target_prop = target_prop
|
self.target_prop = target_prop
|
||||||
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
self.cfg = cfg
|
||||||
|
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
|
||||||
self.source = source
|
self.source = source
|
||||||
# self.api = API(source) # Initialize NAS-Bench-201 API
|
# self.api = API(source) # Initialize NAS-Bench-201 API
|
||||||
# print('API loaded')
|
# print('API loaded')
|
||||||
@@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
|
|||||||
return [f'{self.source}.pt']
|
return [f'{self.source}.pt']
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
source = self.cfg.general.nas_201
|
||||||
# self.api = API(source)
|
# self.api = API(source)
|
||||||
|
|
||||||
data_list = []
|
data_list = []
|
||||||
@@ -748,7 +753,8 @@ class Dataset(InMemoryDataset):
|
|||||||
return edges,nodes
|
return edges,nodes
|
||||||
|
|
||||||
|
|
||||||
def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
|
# def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
|
||||||
|
def graph_to_graph_data(graph, idx, args, device):
|
||||||
# def graph_to_graph_data(graph):
|
# def graph_to_graph_data(graph):
|
||||||
ops = graph[1]
|
ops = graph[1]
|
||||||
adj = graph[0]
|
adj = graph[0]
|
||||||
@@ -797,7 +803,7 @@ class Dataset(InMemoryDataset):
|
|||||||
args.batch_size = 128
|
args.batch_size = 128
|
||||||
args.GPU = '0'
|
args.GPU = '0'
|
||||||
args.dataset = 'cifar10'
|
args.dataset = 'cifar10'
|
||||||
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
args.api_loc = self.cfg.general.nas_201
|
||||||
args.data_loc = '../cifardata/'
|
args.data_loc = '../cifardata/'
|
||||||
args.seed = 777
|
args.seed = 777
|
||||||
args.init = ''
|
args.init = ''
|
||||||
@@ -812,11 +818,12 @@ class Dataset(InMemoryDataset):
|
|||||||
args.num_modules_per_stack = 3
|
args.num_modules_per_stack = 3
|
||||||
args.num_labels = 1
|
args.num_labels = 1
|
||||||
searchspace = nasspace.get_search_space(args)
|
searchspace = nasspace.get_search_space(args)
|
||||||
train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||||
self.swap_scores = []
|
self.swap_scores = []
|
||||||
import csv
|
import csv
|
||||||
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
||||||
with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
|
with open(self.cfg.general.swap_result, 'r') as f:
|
||||||
|
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
header = next(reader)
|
header = next(reader)
|
||||||
data = [row for row in reader]
|
data = [row for row in reader]
|
||||||
@@ -824,12 +831,15 @@ class Dataset(InMemoryDataset):
|
|||||||
device = torch.device('cuda:2')
|
device = torch.device('cuda:2')
|
||||||
with tqdm(total = len_data) as pbar:
|
with tqdm(total = len_data) as pbar:
|
||||||
active_nodes = set()
|
active_nodes = set()
|
||||||
file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
import os
|
||||||
|
# file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
||||||
|
file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
graph_list = json.load(f)
|
graph_list = json.load(f)
|
||||||
i = 0
|
i = 0
|
||||||
flex_graph_list = []
|
flex_graph_list = []
|
||||||
flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
|
# flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
|
||||||
|
flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json')
|
||||||
for graph in graph_list:
|
for graph in graph_list:
|
||||||
print(f'iterate every graph in graph_list, here is {i}')
|
print(f'iterate every graph in graph_list, here is {i}')
|
||||||
arch_info = graph['arch_str']
|
arch_info = graph['arch_str']
|
||||||
@@ -837,7 +847,8 @@ class Dataset(InMemoryDataset):
|
|||||||
for op in ops:
|
for op in ops:
|
||||||
if op not in active_nodes:
|
if op not in active_nodes:
|
||||||
active_nodes.add(op)
|
active_nodes.add(op)
|
||||||
data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
|
# data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
|
||||||
|
data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)
|
||||||
i += 1
|
i += 1
|
||||||
if data is None:
|
if data is None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
@@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
self.task = task_name
|
self.task = task_name
|
||||||
self.task_type = tasktype_dict.get(task_name, "regression")
|
self.task_type = tasktype_dict.get(task_name, "regression")
|
||||||
self.ensure_connected = cfg.model.ensure_connected
|
self.ensure_connected = cfg.model.ensure_connected
|
||||||
|
self.cfg = cfg
|
||||||
# self.api = dataset.api
|
# self.api = dataset.api
|
||||||
|
|
||||||
datadir = cfg.dataset.datadir
|
datadir = cfg.dataset.datadir
|
||||||
@@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
# len_ops.add(len(ops))
|
# len_ops.add(len(ops))
|
||||||
# graphs.append((adj_matrix, ops))
|
# graphs.append((adj_matrix, ops))
|
||||||
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
|
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
|
||||||
graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
||||||
|
graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))
|
||||||
|
|
||||||
# check first five graphs
|
# check first five graphs
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
print(f'graph {i} : {graphs[i]}')
|
print(f'graph {i} : {graphs[i]}')
|
||||||
# print(f'ops_type: {ops_type}')
|
# print(f'ops_type: {ops_type}')
|
||||||
|
|
||||||
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
|
meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
|
||||||
self.base_path = base_path
|
self.base_path = base_path
|
||||||
self.active_nodes = meta_dict['active_nodes']
|
self.active_nodes = meta_dict['active_nodes']
|
||||||
self.max_n_nodes = meta_dict['max_n_nodes']
|
self.max_n_nodes = meta_dict['max_n_nodes']
|
||||||
@@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
|
|||||||
'transition_E': tansition_E.tolist(),
|
'transition_E': tansition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
||||||
|
with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
|
dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
|
||||||
|
@@ -23,6 +23,9 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.test_only = cfg.general.test_only
|
self.test_only = cfg.general.test_only
|
||||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||||
|
|
||||||
|
from nas_201_api import NASBench201API as API
|
||||||
|
self.api = API(cfg.general.nas_201)
|
||||||
|
|
||||||
input_dims = dataset_infos.input_dims
|
input_dims = dataset_infos.input_dims
|
||||||
output_dims = dataset_infos.output_dims
|
output_dims = dataset_infos.output_dims
|
||||||
nodes_dist = dataset_infos.nodes_dist
|
nodes_dist = dataset_infos.nodes_dist
|
||||||
@@ -41,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.args.batch_size = 128
|
self.args.batch_size = 128
|
||||||
self.args.GPU = '0'
|
self.args.GPU = '0'
|
||||||
self.args.dataset = 'cifar10-valid'
|
self.args.dataset = 'cifar10-valid'
|
||||||
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
self.args.api_loc = cfg.general.nas_201
|
||||||
self.args.data_loc = '../cifardata/'
|
self.args.data_loc = '../cifardata/'
|
||||||
self.args.seed = 777
|
self.args.seed = 777
|
||||||
self.args.init = ''
|
self.args.init = ''
|
||||||
@@ -79,6 +82,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.node_dist = nodes_dist
|
self.node_dist = nodes_dist
|
||||||
self.active_index = active_index
|
self.active_index = active_index
|
||||||
self.dataset_info = dataset_infos
|
self.dataset_info = dataset_infos
|
||||||
|
self.cur_epoch = 0
|
||||||
|
|
||||||
self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
|
self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
|
||||||
|
|
||||||
@@ -162,6 +166,62 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
return pred
|
return pred
|
||||||
|
|
||||||
def training_step(self, data, i):
|
def training_step(self, data, i):
|
||||||
|
if self.cfg.general.type != 'accelerator' and self.current_epoch > self.cfg.train.n_epochs / 5 * 4:
|
||||||
|
samples_left_to_generate = self.cfg.general.samples_to_generate
|
||||||
|
samples_left_to_save = self.cfg.general.samples_to_save
|
||||||
|
chains_left_to_save = self.cfg.general.chains_to_save
|
||||||
|
|
||||||
|
samples, all_ys, batch_id = [], [], 0
|
||||||
|
|
||||||
|
def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
|
||||||
|
rewards = []
|
||||||
|
if reward_model == 'swap':
|
||||||
|
import csv
|
||||||
|
with open(self.cfg.general.swap_result, 'r') as f:
|
||||||
|
reader = csv.reader(f)
|
||||||
|
header = next(reader)
|
||||||
|
data = [row for row in reader]
|
||||||
|
swap_scores = [float(row[0]) for row in data]
|
||||||
|
for graph in graphs:
|
||||||
|
node_tensor = graph[0]
|
||||||
|
node = node_tensor.cpu().numpy().tolist()
|
||||||
|
|
||||||
|
def nodes_to_arch_str(nodes):
|
||||||
|
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
|
||||||
|
nodes_str = [num_to_op[node] for node in nodes]
|
||||||
|
arch_str = '|' + nodes_str[1] + '~0|+' + \
|
||||||
|
'|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
|
||||||
|
'|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
|
||||||
|
return arch_str
|
||||||
|
|
||||||
|
arch_str = nodes_to_arch_str(node)
|
||||||
|
reward = swap_scores[self.api.query_index_by_arch(arch_str)]
|
||||||
|
rewards.append(reward)
|
||||||
|
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
|
||||||
|
old_log_probs = None
|
||||||
|
|
||||||
|
bs = 1 * self.cfg.train.batch_size
|
||||||
|
to_generate = min(samples_left_to_generate, bs)
|
||||||
|
to_save = min(samples_left_to_save, bs)
|
||||||
|
chains_save = min(chains_left_to_save, bs)
|
||||||
|
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||||
|
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
|
||||||
|
|
||||||
|
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||||
|
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||||
|
# samples = samples + cur_sample
|
||||||
|
samples.append(cur_sample)
|
||||||
|
reward = graph_reward_fn(cur_sample, device=self.device)
|
||||||
|
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) #
|
||||||
|
if old_log_probs is None:
|
||||||
|
old_log_probs = log_probs.clone()
|
||||||
|
ratio = torch.exp(log_probs - old_log_probs)
|
||||||
|
print(f"ratio: {ratio.shape}, advantages: {advantages.shape}")
|
||||||
|
unclipped_loss = -advantages * ratio
|
||||||
|
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - self.cfg.ppo.clip_param, 1.0 + self.cfg.ppo.clip_param)
|
||||||
|
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
|
||||||
|
return {'loss': loss}
|
||||||
|
else:
|
||||||
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
|
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
|
||||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||||
|
|
||||||
@@ -196,14 +256,15 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
|
|
||||||
def on_train_epoch_start(self) -> None:
|
def on_train_epoch_start(self) -> None:
|
||||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||||
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
|
# if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||||
|
print("Starting train epoch {}/{}...".format(self.cur_epoch, self.cfg.train.n_epochs))
|
||||||
self.start_epoch_time = time.time()
|
self.start_epoch_time = time.time()
|
||||||
self.train_loss.reset()
|
self.train_loss.reset()
|
||||||
self.train_metrics.reset()
|
self.train_metrics.reset()
|
||||||
|
|
||||||
def on_train_epoch_end(self) -> None:
|
def on_train_epoch_end(self) -> None:
|
||||||
|
|
||||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||||
log = True
|
log = True
|
||||||
else:
|
else:
|
||||||
log = False
|
log = False
|
||||||
@@ -240,6 +301,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
||||||
|
|
||||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||||
|
# if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||||
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
||||||
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
|
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
|
||||||
with open("validation-metrics.csv", "a") as f:
|
with open("validation-metrics.csv", "a") as f:
|
||||||
@@ -283,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
num_examples = self.val_y_collection.size(0)
|
num_examples = self.val_y_collection.size(0)
|
||||||
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
||||||
all_ys.append(batch_y)
|
all_ys.append(batch_y)
|
||||||
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||||
save_final=to_save,
|
save_final=to_save,
|
||||||
keep_chain=chains_save,
|
keep_chain=chains_save,
|
||||||
number_chain_steps=self.number_chain_steps))
|
number_chain_steps=self.number_chain_steps)
|
||||||
|
samples.extend(cur_sample)
|
||||||
|
# samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||||
|
# save_final=to_save,
|
||||||
|
# keep_chain=chains_save,
|
||||||
|
# number_chain_steps=self.number_chain_steps))
|
||||||
ident += to_generate
|
ident += to_generate
|
||||||
start_index += to_generate
|
start_index += to_generate
|
||||||
|
|
||||||
@@ -336,7 +403,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
|
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
|
||||||
f"Test Edge type KL: {metrics[2] :.2f}")
|
f"Test Edge type KL: {metrics[2] :.2f}")
|
||||||
|
|
||||||
## final epcoh
|
## final epoch
|
||||||
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
||||||
samples_left_to_save = self.cfg.general.final_model_samples_to_save
|
samples_left_to_save = self.cfg.general.final_model_samples_to_save
|
||||||
chains_left_to_save = self.cfg.general.final_model_chains_to_save
|
chains_left_to_save = self.cfg.general.final_model_chains_to_save
|
||||||
@@ -359,9 +426,9 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||||
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
|
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
|
||||||
|
|
||||||
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||||
samples = samples + cur_sample
|
samples.extend(cur_sample)
|
||||||
|
|
||||||
all_ys.append(batch_y)
|
all_ys.append(batch_y)
|
||||||
batch_id += to_generate
|
batch_id += to_generate
|
||||||
@@ -601,6 +668,12 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
|
|
||||||
assert (E == torch.transpose(E, 1, 2)).all()
|
assert (E == torch.transpose(E, 1, 2)).all()
|
||||||
|
|
||||||
|
if self.cfg.general.type != 'accelerator':
|
||||||
|
if self.trainer.training or self.trainer.validating:
|
||||||
|
total_log_probs = torch.zeros([self.cfg.general.samples_to_generate, 10], device=self.device)
|
||||||
|
elif self.trainer.testing:
|
||||||
|
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate, 10], device=self.device)
|
||||||
|
|
||||||
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
|
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
|
||||||
for s_int in reversed(range(0, self.T)):
|
for s_int in reversed(range(0, self.T)):
|
||||||
s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
|
s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
|
||||||
@@ -609,21 +682,24 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
t_norm = t_array / self.T
|
t_norm = t_array / self.T
|
||||||
|
|
||||||
# Sample z_s
|
# Sample z_s
|
||||||
sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
|
sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
|
||||||
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
||||||
|
total_log_probs += log_probs
|
||||||
|
|
||||||
# Sample
|
# Sample
|
||||||
sampled_s = sampled_s.mask(node_mask, collapse=True)
|
sampled_s = sampled_s.mask(node_mask, collapse=True)
|
||||||
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
||||||
|
|
||||||
molecule_list = []
|
graph_list = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
n = n_nodes[i]
|
n = n_nodes[i]
|
||||||
atom_types = X[i, :n].cpu()
|
node_types = X[i, :n].cpu()
|
||||||
edge_types = E[i, :n, :n].cpu()
|
edge_types = E[i, :n, :n].cpu()
|
||||||
molecule_list.append([atom_types, edge_types])
|
graph_list.append((node_types , edge_types))
|
||||||
|
|
||||||
return molecule_list
|
total_log_probs = torch.sum(total_log_probs, dim=-1)
|
||||||
|
|
||||||
|
return graph_list, total_log_probs
|
||||||
|
|
||||||
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
|
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
|
||||||
"""Samples from zs ~ p(zs | zt). Only used during sampling.
|
"""Samples from zs ~ p(zs | zt). Only used during sampling.
|
||||||
@@ -675,6 +751,14 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
# with condition = P_t(A_{t-1} |A_t, y)
|
# with condition = P_t(A_{t-1} |A_t, y)
|
||||||
prob_X, prob_E, pred = get_prob(noisy_data)
|
prob_X, prob_E, pred = get_prob(noisy_data)
|
||||||
|
|
||||||
|
log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1)) # bs, n
|
||||||
|
log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1)) # bs, n, n
|
||||||
|
|
||||||
|
# Sum the log_prob across dimensions for total log_prob
|
||||||
|
log_prob_X = log_prob_X.sum(dim=-1)
|
||||||
|
log_prob_E = log_prob_E.sum(dim=(1, 2))
|
||||||
|
|
||||||
|
log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)
|
||||||
### Guidance
|
### Guidance
|
||||||
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
|
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
|
||||||
uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
|
uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
|
||||||
@@ -810,4 +894,4 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
||||||
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
||||||
|
|
||||||
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
|
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs
|
||||||
|
@@ -177,6 +177,66 @@ def test(cfg: DictConfig):
|
|||||||
os.chdir(cfg.general.resume.split("checkpoints")[0])
|
os.chdir(cfg.general.resume.split("checkpoints")[0])
|
||||||
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
|
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
|
||||||
model = Graph_DiT(cfg=cfg, **model_kwargs)
|
model = Graph_DiT(cfg=cfg, **model_kwargs)
|
||||||
|
|
||||||
|
if cfg.general.type == "accelerator":
|
||||||
|
graph_dit_model = model
|
||||||
|
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import set_seed, ProjectConfiguration
|
||||||
|
|
||||||
|
accelerator_config = ProjectConfiguration(
|
||||||
|
project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
|
||||||
|
automatic_checkpoint_naming=True,
|
||||||
|
total_limit=cfg.general.number_checkpoint_limit,
|
||||||
|
)
|
||||||
|
accelerator = Accelerator(
|
||||||
|
mixed_precision='no',
|
||||||
|
project_config=accelerator_config,
|
||||||
|
# gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
|
||||||
|
gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = graph_dit_model.configure_optimizers()
|
||||||
|
|
||||||
|
train_dataloader = datamodule.train_dataloader()
|
||||||
|
train_dataloader = accelerator.prepare(train_dataloader)
|
||||||
|
val_dataloader = datamodule.val_dataloader()
|
||||||
|
val_dataloader = accelerator.prepare(val_dataloader)
|
||||||
|
test_dataloader = datamodule.test_dataloader()
|
||||||
|
test_dataloader = accelerator.prepare(test_dataloader)
|
||||||
|
|
||||||
|
optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
|
||||||
|
|
||||||
|
# train_epoch
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
seed_everything(cfg.train.seed)
|
||||||
|
for epoch in range(cfg.train.n_epochs):
|
||||||
|
print(f"Epoch {epoch}")
|
||||||
|
graph_dit_model.train()
|
||||||
|
graph_dit_model.cur_epoch = epoch
|
||||||
|
graph_dit_model.on_train_epoch_start()
|
||||||
|
for batch in train_dataloader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = graph_dit_model.training_step(batch, epoch)['loss']
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
graph_dit_model.on_train_epoch_end()
|
||||||
|
for batch in val_dataloader:
|
||||||
|
if epoch % cfg.train.check_val_every_n_epoch == 0:
|
||||||
|
graph_dit_model.eval()
|
||||||
|
graph_dit_model.on_validation_epoch_start()
|
||||||
|
graph_dit_model.validation_step(batch, epoch)
|
||||||
|
graph_dit_model.on_validation_epoch_end()
|
||||||
|
|
||||||
|
# test_epoch
|
||||||
|
|
||||||
|
graph_dit_model.test()
|
||||||
|
graph_dit_model.on_test_epoch_start()
|
||||||
|
for batch in test_dataloader:
|
||||||
|
graph_dit_model.test_step(batch, epoch)
|
||||||
|
graph_dit_model.on_test_epoch_end()
|
||||||
|
|
||||||
|
elif cfg.general.type == "Trainer":
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
gradient_clip_val=cfg.train.clip_grad,
|
gradient_clip_val=cfg.train.clip_grad,
|
||||||
# accelerator="cpu",
|
# accelerator="cpu",
|
||||||
|
@@ -83,7 +83,8 @@ class TaskModel():
|
|||||||
return adj_ops_pairs
|
return adj_ops_pairs
|
||||||
def feature_from_adj_and_ops(adj, ops):
|
def feature_from_adj_and_ops(adj, ops):
|
||||||
return np.concatenate([adj.flatten(), ops])
|
return np.concatenate([adj.flatten(), ops])
|
||||||
filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
# filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
||||||
|
filename = '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
||||||
graphs = read_adj_ops_from_json(filename)
|
graphs = read_adj_ops_from_json(filename)
|
||||||
adjs = []
|
adjs = []
|
||||||
opss = []
|
opss = []
|
||||||
|
15626
graph_dit/swap_results.csv
Normal file
15626
graph_dit/swap_results.csv
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user