4 Commits

Author      SHA1        Message                                                Date
xmuhanma    123cde9313  add swap cifar10 results property_metric path update  2024-09-22 15:55:10 +02:00
xmuhanma    9360839a35  add config path                                        2024-09-22 15:47:51 +02:00
mhz         f75657ac3b  add the environment yaml                               2024-09-20 00:06:09 +02:00
mhz         be178bc5ee  use trainer but has bugs                               2024-09-19 14:11:19 +02:00
10 changed files with 16649 additions and 16228 deletions


@@ -2,20 +2,26 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
-  gpu_number: 2
+  gpu_number: 0
   resume: null
   test_only: null
   sample_every_val: 2500
-  samples_to_generate: 512
+  samples_to_generate: 1000
   samples_to_save: 3
   chains_to_save: 1
   log_every_steps: 50
   number_chain_steps: 8
-  final_model_samples_to_generate: 100
+  final_model_samples_to_generate: 1000
   final_model_samples_to_save: 20
   final_model_chains_to_save: 1
   enable_progress_bar: False
   save_model: True
+  log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
+  number_checkpoint_limit: 3
+  type: 'Trainer'
+  nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+  swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
+  root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
 model:
   type: 'discrete'
   transition: 'marginal'
@@ -32,7 +38,7 @@ model:
   ensure_connected: True
 train:
   # n_epochs: 5000
-  n_epochs: 500
+  n_epochs: 10
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
@@ -41,8 +47,11 @@ train:
   seed: 0
   val_check_interval: null
   check_val_every_n_epoch: 1
+  gradient_accumulation_steps: 1
 dataset:
   datadir: 'data/'
   task_name: 'nasbench-201'
   guidance_target: 'nasbench-201'
   pin_memory: False
+ppo:
+  clip_param: 1
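
Note: the six new general keys centralize machine-specific state. log_dir and number_checkpoint_limit feed the Accelerate checkpointing added below, type switches between the Lightning Trainer and the hand-rolled accelerator loop, and nas_201, swap_result, and root replace the hard-coded /nfs paths throughout the dataset code. A minimal sketch of how downstream code reads these keys, assuming the repo's Hydra/OmegaConf-style DictConfig loading (the literal load path here is illustrative, not from the diff):

    import os
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/config.yaml')            # illustrative path
    assert cfg.general.type in ('Trainer', 'accelerator')  # selects the training loop
    api_path = cfg.general.nas_201                         # NAS-Bench-201 checkpoint (.pth)
    meta_json = os.path.join(cfg.general.root, 'nasbench-201-meta.json')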

environment.yaml (new file, 228 lines)

@@ -0,0 +1,228 @@
name: graphdit
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- asttokens=2.4.1=pyhd8ed1ab_0
- blas=1.0=mkl
- brotli-python=1.0.9=py39h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- comm=0.2.2=pyhd8ed1ab_0
- debugpy=1.6.7=py39h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- executing=2.0.1=pyhd8ed1ab_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py39heeb90bb_0
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py39h06a4308_0
- importlib-metadata=7.1.0=pyha770c72_0
- importlib_metadata=7.1.0=hd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- ipykernel=6.29.4=pyh3099207_0
- ipython=8.18.1=pyh707e725_3
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.4=py39h06a4308_0
- jpeg=9e=h5eee18b_1
- jupyter_client=8.6.2=pyhd8ed1ab_0
- jupyter_core=5.7.2=py39hf3d152e_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=13.2.0=h77fa898_7
- libgomp=13.2.0=h77fa898_7
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libpng=1.6.39=h5eee18b_0
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libwebp-base=1.3.2=h5eee18b_0
- lz4-c=1.9.4=h6a678d5_1
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py39h5eee18b_1
- mkl_fft=1.3.8=py39h5eee18b_0
- mkl_random=1.2.4=py39hdb19cb5_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py39h06a4308_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.26.4=py39hb5e798b_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h9ca470c_2
- openssl=3.3.1=h4ab18f5_0
- packaging=24.0=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pip=24.0=py39h06a4308_0
- platformdirs=4.2.2=pyhd8ed1ab_0
- prompt-toolkit=3.0.46=pyha770c72_0
- psutil=5.9.8=py39hd1e30aa_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.18.0=pyhd8ed1ab_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.19=h955ad1f_1
- python_abi=3.9=2_cp39
- pytorch-mutex=1.0=cpu
- pyzmq=25.1.2=py39h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=69.5.1=py39h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.45.3=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.12=py39h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- tornado=6.4.1=py39hd3abc70_0
- traitlets=5.14.3=pyhd8ed1ab_0
- typing_extensions=4.12.2=pyha770c72_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.43.0=py39h06a4308_0
- xz=5.4.6=h5eee18b_1
- zeromq=4.3.5=h6a678d5_0
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.5=hc292b87_2
- pip:
- absl-py==2.1.0
- accelerate==0.34.2
- aiohttp==3.9.5
- aiosignal==1.3.1
- antlr4-python3-runtime==4.9.3
- astunparse==1.6.3
- async-timeout==4.0.3
- attrs==23.2.0
- beautifulsoup4==4.12.3
- bleach==6.1.0
- certifi==2024.2.2
- charset-normalizer==3.1.0
- cmake==3.29.3
- contourpy==1.2.1
- cycler==0.12.1
- defusedxml==0.7.1
- fastjsonschema==2.19.1
- fcd-torch==1.0.7
- filelock==3.14.0
- flatbuffers==24.3.25
- fonttools==4.52.4
- frozenlist==1.4.1
- fsspec==2024.5.0
- gast==0.5.4
- google-pasta==0.2.0
- grpcio==1.64.1
- h5py==3.11.0
- huggingface-hub==0.24.6
- hydra-core==1.3.2
- imageio==2.26.0
- importlib-resources==6.4.0
- joblib==1.2.0
- jsonschema==4.22.0
- jsonschema-specifications==2023.12.1
- jupyterlab-pygments==0.3.0
- keras==3.3.3
- kiwisolver==1.4.5
- libclang==18.1.1
- lightning-utilities==0.11.2
- lit==18.1.6
- markdown==3.6
- markdown-it-py==3.0.0
- markupsafe==2.1.5
- matplotlib==3.7.0
- mdurl==0.1.2
- mini-moses==1.0
- mistune==3.0.2
- ml-dtypes==0.3.2
- multidict==6.0.5
- namex==0.0.8
- nas-bench-201==2.1
- nasbench==1.0
- nbclient==0.10.0
- nbconvert==7.16.4
- nbformat==5.10.4
- networkx==3.0
- numpy==1.24.2
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu11==10.9.0.58
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu11==10.2.10.91
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu11==2.14.3
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.40
- nvidia-nvtx-cu11==11.7.91
- nvidia-nvtx-cu12==12.1.105
- omegaconf==2.3.0
- opt-einsum==3.3.0
- optree==0.11.0
- pandas==1.5.3
- pandocfilters==1.5.1
- pillow==10.3.0
- protobuf==3.20.3
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- pytorch-lightning==2.0.1
- pytz==2024.1
- pyyaml==6.0.1
- rdkit==2023.9.4
- referencing==0.35.1
- requests==2.32.2
- rich==13.7.1
- rpds-py==0.18.1
- safetensors==0.4.5
- scikit-learn==1.2.1
- scipy==1.13.1
- seaborn==0.13.2
- simplejson==3.19.2
- soupsieve==2.5
- tensorboard==2.16.2
- tensorboard-data-server==0.7.2
- tensorflow==2.16.1
- tensorflow-io-gcs-filesystem==0.37.0
- termcolor==2.4.0
- threadpoolctl==3.5.0
- tinycss2==1.3.0
- torch==2.0.0
- torch-geometric==2.3.0
- torchaudio==2.0.1+rocm5.4.2
- torchmetrics==0.11.4
- torchvision==0.15.1
- tqdm==4.64.1
- triton==2.0.0
- typing-extensions==4.12.0
- tzdata==2024.1
- urllib3==2.2.1
- webencodings==0.5.1
- werkzeug==3.0.3
- wrapt==1.16.0
- yacs==0.1.8
- yarl==1.9.4
- zipp==3.19.0
prefix: /home/stud/hanzhang/anaconda3/envs/graphdit
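
To recreate this environment, the standard command is conda env create -f environment.yaml; conda takes the environment name from the name: graphdit field, while the trailing prefix: line is machine-specific and is ignored when the file is used on another host.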


@@ -54,7 +54,9 @@ class BasicGraphMetrics(object):
         covered_nodes = set()
         direct_valid_count = 0
         print(f"generated number: {len(generated)}")
+        print(f"generated: {generated}")
         for graph in generated:
+            print(f"graph: {graph}")
             node_types, edge_types = graph
             direct_valid_flag = True
             direct_valid_count += 1


@@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-from naswot.score_networks import get_nasbench201_idx_score
 from naswot import nasspace
 from naswot import datasets as dt
@@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
         # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         #     base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        # base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        base_path = os.path.join(self.cfg.general.root, "..")
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
@@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         source = './NAS-Bench-201-v1_1-096897.pth'
-        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
+        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
         self.dataset = dataset
         # self.api = dataset.api
@@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader

-def new_graphs_to_json(graphs, filename):
+def new_graphs_to_json(graphs, filename, cfg):
     source_name = "nasbench-201"
     num_graph = len(graphs)
@@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
         'num_active_nodes': len(active_nodes),
         'transition_E': transition_E.tolist(),
     }
+    import os
-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    with open(os.path.join(cfg.general.root, 'nasbench-201-meta.json'), 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict
@@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
     json.dump(meta_dict, f)
     return meta_dict

 class Dataset(InMemoryDataset):
-    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
         self.target_prop = target_prop
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.cfg = cfg
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
         self.source = source
         # self.api = API(source) # Initialize NAS-Bench-201 API
         # print('API loaded')
@@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
         return [f'{self.source}.pt']

     def process(self):
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = self.cfg.general.nas_201
         # self.api = API(source)
         data_list = []
@@ -748,7 +753,8 @@
             return edges, nodes

-        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        # def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        def graph_to_graph_data(graph, idx, args, device):
        # def graph_to_graph_data(graph):
             ops = graph[1]
             adj = graph[0]
@@ -797,7 +803,7 @@
         args.batch_size = 128
         args.GPU = '0'
         args.dataset = 'cifar10'
-        args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        args.api_loc = self.cfg.general.nas_201
         args.data_loc = '../cifardata/'
         args.seed = 777
         args.init = ''
@@ -812,11 +818,12 @@
         args.num_modules_per_stack = 3
         args.num_labels = 1
         searchspace = nasspace.get_search_space(args)
-        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
+        # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
         self.swap_scores = []
         import csv
         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
-        with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
+        with open(self.cfg.general.swap_result, 'r') as f:
+        # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
             reader = csv.reader(f)
             header = next(reader)
             data = [row for row in reader]
@@ -824,12 +831,15 @@
         device = torch.device('cuda:2')
         with tqdm(total=len_data) as pbar:
             active_nodes = set()
+            import os
-            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            # file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
             with open(file_path, 'r') as f:
                 graph_list = json.load(f)
             i = 0
             flex_graph_list = []
-            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            # flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            flex_graph_path = os.path.join(self.cfg.general.root, 'flex-nasbench201-graph.json')
             for graph in graph_list:
                 print(f'iterate every graph in graph_list, here is {i}')
                 arch_info = graph['arch_str']
@@ -837,7 +847,8 @@
                 for op in ops:
                     if op not in active_nodes:
                         active_nodes.add(op)
-                data = graph_to_graph_data((adj_matrix, ops), idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                # data = graph_to_graph_data((adj_matrix, ops), idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                data = graph_to_graph_data((adj_matrix, ops), idx=i, args=args, device=device)
                 i += 1
                 if data is None:
                     pbar.update(1)
@@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
         self.task = task_name
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
+        self.cfg = cfg
         # self.api = dataset.api

         datadir = cfg.dataset.datadir
@@ -1182,14 +1194,15 @@
         #     len_ops.add(len(ops))
         #     graphs.append((adj_matrix, ops))
         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
-        graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))

         # check first five graphs
         for i in range(5):
             print(f'graph {i} : {graphs[i]}')
         # print(f'ops_type: {ops_type}')

-        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
+        meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
         self.base_path = base_path
         self.active_nodes = meta_dict['active_nodes']
         self.max_n_nodes = meta_dict['max_n_nodes']
@@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
         'transition_E': tansition_E.tolist(),
     }

-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
         json.dump(meta_dict, f)
     return meta_dict

 if __name__ == "__main__":
-    dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
+    dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
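
These changes route every hard-coded /nfs path through cfg.general.*. A minimal sketch of the resulting SWAP-score lookup, with the CSV layout assumed from the reader code above (a header row, then one float score per row in column 0) and cfg loaded as in the config sketch earlier:

    import csv
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/config.yaml')     # illustrative config location
    with open(cfg.general.swap_result, 'r') as f:   # .../graph_dit/swap_results.csv
        reader = csv.reader(f)
        next(reader)                                # skip the header row
        swap_scores = [float(row[0]) for row in reader]
    # swap_scores[i] is then the proxy score for NAS-Bench-201 architecture index i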


@@ -23,6 +23,9 @@ class Graph_DiT(pl.LightningModule):
         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)

+        from nas_201_api import NASBench201API as API
+        self.api = API(cfg.general.nas_201)
+
         input_dims = dataset_infos.input_dims
         output_dims = dataset_infos.output_dims
         nodes_dist = dataset_infos.nodes_dist
@@ -41,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
         self.args.batch_size = 128
         self.args.GPU = '0'
         self.args.dataset = 'cifar10-valid'
-        self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.args.api_loc = cfg.general.nas_201
         self.args.data_loc = '../cifardata/'
         self.args.seed = 777
         self.args.init = ''
@@ -79,6 +82,7 @@ class Graph_DiT(pl.LightningModule):
         self.node_dist = nodes_dist
         self.active_index = active_index
         self.dataset_info = dataset_infos
+        self.cur_epoch = 0
         self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
@@ -162,6 +166,62 @@
         return pred

     def training_step(self, data, i):
+        if self.cfg.general.type != 'accelerator' and self.current_epoch > self.cfg.train.n_epochs / 5 * 4:
+            samples_left_to_generate = self.cfg.general.samples_to_generate
+            samples_left_to_save = self.cfg.general.samples_to_save
+            chains_left_to_save = self.cfg.general.chains_to_save
+            samples, all_ys, batch_id = [], [], 0
+
+            def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
+                rewards = []
+                if reward_model == 'swap':
+                    import csv
+                    with open(self.cfg.general.swap_result, 'r') as f:
+                        reader = csv.reader(f)
+                        header = next(reader)
+                        data = [row for row in reader]
+                        swap_scores = [float(row[0]) for row in data]
+                    for graph in graphs:
+                        node_tensor = graph[0]
+                        node = node_tensor.cpu().numpy().tolist()
+
+                        def nodes_to_arch_str(nodes):
+                            num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+                            nodes_str = [num_to_op[node] for node in nodes]
+                            arch_str = '|' + nodes_str[1] + '~0|+' + \
+                                       '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' + \
+                                       '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
+                            return arch_str
+
+                        arch_str = nodes_to_arch_str(node)
+                        reward = swap_scores[self.api.query_index_by_arch(arch_str)]
+                        rewards.append(reward)
+                return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+
+            old_log_probs = None
+            bs = 1 * self.cfg.train.batch_size
+            to_generate = min(samples_left_to_generate, bs)
+            to_save = min(samples_left_to_save, bs)
+            chains_save = min(chains_left_to_save, bs)
+            # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+            batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
+
+            cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                      keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
+            # samples = samples + cur_sample
+            samples.append(cur_sample)
+            reward = graph_reward_fn(cur_sample, device=self.device)
+            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+            if old_log_probs is None:
+                old_log_probs = log_probs.clone()
+            ratio = torch.exp(log_probs - old_log_probs)
+            print(f"ratio: {ratio.shape}, advantages: {advantages.shape}")
+            unclipped_loss = -advantages * ratio
+            clipped_loss = -advantages * torch.clamp(ratio, 1.0 - self.cfg.ppo.clip_param, 1.0 + self.cfg.ppo.clip_param)
+            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+            return {'loss': loss}
+        else:
         data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
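
This new branch turns late-epoch training into a PPO-style update: sampled graphs are scored with a SWAP reward looked up through the NAS-Bench-201 API, advantages are whitened, and the clipped surrogate objective is minimized. A self-contained sketch of that surrogate with dummy values (clip_param mirrors ppo.clip_param in the config above), followed by a worked example of the arch-string encoding:

    import torch

    # Clipped PPO surrogate, as in training_step (dummy tensors).
    log_probs = torch.tensor([-1.2, -0.8, -1.5], requires_grad=True)
    old_log_probs = log_probs.detach()           # snapshot of the old policy; first update: ratio == 1
    reward = torch.tensor([0.3, 0.9, 0.1])
    advantages = (reward - reward.mean()) / (reward.std() + 1e-6)

    clip_param = 1.0                             # cfg.ppo.clip_param in the diff
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    loss = torch.max(unclipped, clipped).mean()  # take the pessimistic (larger) surrogate
    loss.backward()

Worked example of nodes_to_arch_str (illustrative node sequence): nodes [0, 2, 4, 1, 5, 3, 2, 6] map through num_to_op to ['input', 'nor_conv_3x3', 'skip_connect', 'nor_conv_1x1', 'none', 'avg_pool_3x3', 'nor_conv_3x3', 'output'] and are encoded as |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|, which api.query_index_by_arch() resolves to an index into swap_scores. One caveat: the sketch detaches old_log_probs, whereas the diff sets old_log_probs = log_probs.clone() within the same step; clone stays on the autograd graph, so the ratio is identically exp(0) and contributes no gradient, which is plausibly one of the issues the commit message "use trainer but has bugs" refers to.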
@@ -196,14 +256,15 @@
     def on_train_epoch_start(self) -> None:
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
-            print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
+        # if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
+            print("Starting train epoch {}/{}...".format(self.cur_epoch, self.cfg.train.n_epochs))
         self.start_epoch_time = time.time()
         self.train_loss.reset()
         self.train_metrics.reset()

     def on_train_epoch_end(self) -> None:
-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
@@ -240,6 +301,7 @@
                    self.val_X_logp.compute(), self.val_E_logp.compute()]

         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        # if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
                   f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
         with open("validation-metrics.csv", "a") as f:
@@ -283,10 +345,15 @@
             num_examples = self.val_y_collection.size(0)
             batch_y = self.val_y_collection[start_index:start_index + to_generate]
             all_ys.append(batch_y)
-            samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
-                                             save_final=to_save,
-                                             keep_chain=chains_save,
-                                             number_chain_steps=self.number_chain_steps))
+            cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+                                                     save_final=to_save,
+                                                     keep_chain=chains_save,
+                                                     number_chain_steps=self.number_chain_steps)
+            samples.extend(cur_sample)
+            # samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+            #                                  save_final=to_save,
+            #                                  keep_chain=chains_save,
+            #                                  number_chain_steps=self.number_chain_steps))
             ident += to_generate
             start_index += to_generate
@@ -336,7 +403,7 @@
         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
               f"Test Edge type KL: {metrics[2] :.2f}")

-        ## final epcoh
+        ## final epoch
         samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
         samples_left_to_save = self.cfg.general.final_model_samples_to_save
         chains_left_to_save = self.cfg.general.final_model_chains_to_save
@@ -359,9 +426,9 @@
             # batch_y = test_y_collection[batch_id : batch_id + to_generate]
             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
-            cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                           keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
-            samples = samples + cur_sample
+            cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                      keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
+            samples.extend(cur_sample)
             all_ys.append(batch_y)
             batch_id += to_generate
@@ -601,6 +668,12 @@
         assert (E == torch.transpose(E, 1, 2)).all()

+        if self.cfg.general.type != 'accelerator':
+            if self.trainer.training or self.trainer.validating:
+                total_log_probs = torch.zeros([self.cfg.general.samples_to_generate, 10], device=self.device)
+            elif self.trainer.testing:
+                total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate, 10], device=self.device)
+
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
             s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
@@ -609,21 +682,24 @@
             t_norm = t_array / self.T

             # Sample z_s
-            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
+            sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            total_log_probs += log_probs

         # Sample
         sampled_s = sampled_s.mask(node_mask, collapse=True)
         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

-        molecule_list = []
+        graph_list = []
         for i in range(batch_size):
             n = n_nodes[i]
-            atom_types = X[i, :n].cpu()
+            node_types = X[i, :n].cpu()
             edge_types = E[i, :n, :n].cpu()
-            molecule_list.append([atom_types, edge_types])
-        return molecule_list
+            graph_list.append((node_types, edge_types))
+
+        total_log_probs = torch.sum(total_log_probs, dim=-1)
+        return graph_list, total_log_probs

     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
         """Samples from zs ~ p(zs | zt). Only used during sampling.
@@ -675,6 +751,14 @@
         # with condition = P_t(A_{t-1} |A_t, y)
         prob_X, prob_E, pred = get_prob(noisy_data)

+        log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n
+        log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n
+        # Sum the log_prob across dimensions for total log_prob
+        log_prob_X = log_prob_X.sum(dim=-1)
+        log_prob_E = log_prob_E.sum(dim=(1, 2))
+        log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)
+
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
@@ -810,4 +894,4 @@
         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)

-        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
+        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs
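
sample_batch now also returns per-sample log-probabilities accumulated across the reverse diffusion chain; each step's contribution is computed in sample_p_zs_given_zt by gathering the probability of the category that was actually sampled. A standalone sketch of that gather with dummy shapes (bs=2, n=3, 5 node classes):

    import torch

    prob_X = torch.softmax(torch.randn(2, 3, 5), dim=-1)              # p(x_s | x_t) per node
    X_t = torch.randint(0, 5, (2, 3, 1))                              # categories actually sampled
    log_prob_X = torch.log(torch.gather(prob_X, -1, X_t).squeeze(-1)) # (2, 3): log p of each node
    step_log_prob = log_prob_X.sum(dim=-1)                            # (2,): joint log p per graph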


@@ -177,6 +177,66 @@ def test(cfg: DictConfig):
         os.chdir(cfg.general.resume.split("checkpoints")[0])
     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     model = Graph_DiT(cfg=cfg, **model_kwargs)
+    if cfg.general.type == "accelerator":
+        graph_dit_model = model
+
+        from accelerate import Accelerator
+        from accelerate.utils import set_seed, ProjectConfiguration
+
+        accelerator_config = ProjectConfiguration(
+            project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
+            automatic_checkpoint_naming=True,
+            total_limit=cfg.general.number_checkpoint_limit,
+        )
+        accelerator = Accelerator(
+            mixed_precision='no',
+            project_config=accelerator_config,
+            # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
+            gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+        )
+
+        optimizer = graph_dit_model.configure_optimizers()
+
+        train_dataloader = datamodule.train_dataloader()
+        train_dataloader = accelerator.prepare(train_dataloader)
+        val_dataloader = datamodule.val_dataloader()
+        val_dataloader = accelerator.prepare(val_dataloader)
+        test_dataloader = datamodule.test_dataloader()
+        test_dataloader = accelerator.prepare(test_dataloader)
+
+        optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
+
+        # train_epoch
+        from pytorch_lightning import seed_everything
+        seed_everything(cfg.train.seed)
+        for epoch in range(cfg.train.n_epochs):
+            print(f"Epoch {epoch}")
+            graph_dit_model.train()
+            graph_dit_model.cur_epoch = epoch
+            graph_dit_model.on_train_epoch_start()
+            for batch in train_dataloader:
+                optimizer.zero_grad()
+                loss = graph_dit_model.training_step(batch, epoch)['loss']
+                accelerator.backward(loss)
+                optimizer.step()
+            graph_dit_model.on_train_epoch_end()
+
+            for batch in val_dataloader:
+                if epoch % cfg.train.check_val_every_n_epoch == 0:
+                    graph_dit_model.eval()
+                    graph_dit_model.on_validation_epoch_start()
+                    graph_dit_model.validation_step(batch, epoch)
+                    graph_dit_model.on_validation_epoch_end()
+
+        # test_epoch
+        graph_dit_model.test()
+        graph_dit_model.on_test_epoch_start()
+        for batch in test_dataloader:
+            graph_dit_model.test_step(batch, epoch)
+        graph_dit_model.on_test_epoch_end()
+    elif cfg.general.type == "Trainer":
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
         # accelerator="cpu",
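
The accelerator branch replays the Lightning hook sequence by hand (on_train_epoch_start, training_step, on_train_epoch_end, plus the validation and test hooks), which is what lets the PPO training_step above run outside the Trainer. One caveat worth flagging: gradient_accumulation_steps is passed to Accelerator, but the loop steps the optimizer on every batch; Accelerate's documented pattern wraps the step in accelerator.accumulate. A fragment of the inner loop rewritten that way, reusing accelerator, graph_dit_model, optimizer, train_dataloader, and epoch from the block above:

    for batch in train_dataloader:
        with accelerator.accumulate(graph_dit_model):   # sync grads only every N accumulated steps
            loss = graph_dit_model.training_step(batch, epoch)['loss']
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()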


@@ -83,7 +83,8 @@ class TaskModel():
         return adj_ops_pairs

     def feature_from_adj_and_ops(adj, ops):
         return np.concatenate([adj.flatten(), ops])

-    filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+    # filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+    filename = '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/nasbench-201-graph.json'
     graphs = read_adj_ops_from_json(filename)
     adjs = []
     opss = []

graph_dit/swap_results.csv (new file, 15626 lines): diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because one or more lines are too long