4 Commits

Author    SHA1        Message                                                Date
xmuhanma  123cde9313  add swap cifar10 results property_metric path update  2024-09-22 15:55:10 +02:00
xmuhanma  9360839a35  add config path                                        2024-09-22 15:47:51 +02:00
mhz       f75657ac3b  add the environment yaml                               2024-09-20 00:06:09 +02:00
mhz       be178bc5ee  use trainer but has bugs                               2024-09-19 14:11:19 +02:00
10 changed files with 16649 additions and 16228 deletions

View File

@@ -2,20 +2,26 @@ general:
name: 'graph_dit'
wandb: 'disabled'
gpus: 1
gpu_number: 2
gpu_number: 0
resume: null
test_only: null
sample_every_val: 2500
samples_to_generate: 512
samples_to_generate: 1000
samples_to_save: 3
chains_to_save: 1
log_every_steps: 50
number_chain_steps: 8
final_model_samples_to_generate: 100
final_model_samples_to_generate: 1000
final_model_samples_to_save: 20
final_model_chains_to_save: 1
enable_progress_bar: False
save_model: True
log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
number_checkpoint_limit: 3
type: 'Trainer'
nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
model:
type: 'discrete'
transition: 'marginal'
@@ -32,7 +38,7 @@ model:
ensure_connected: True
train:
# n_epochs: 5000
n_epochs: 500
n_epochs: 10
batch_size: 1200
lr: 0.0002
clip_grad: null
@@ -41,8 +47,11 @@ train:
seed: 0
val_check_interval: null
check_val_every_n_epoch: 1
gradient_accumulation_steps: 1
dataset:
datadir: 'data/'
task_name: 'nasbench-201'
guidance_target: 'nasbench-201'
pin_memory: False
ppo:
clip_param: 1
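
The new general.root, general.nas_201, and general.swap_result keys replace the hard-coded /nfs/... paths throughout the code. A minimal sketch of how they are consumed, assuming the Hydra/OmegaConf config above is saved as config.yaml (hypothetical file name):

    # Minimal sketch, assuming the config above is saved as config.yaml.
    import os
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")                  # hypothetical file name
    meta_path = os.path.join(cfg.general.root, "nasbench-201-meta.json")
    api_path = cfg.general.nas_201                       # NAS-Bench-201 .pth checkpoint
    swap_csv = cfg.general.swap_result                   # precomputed SWAP scores (CSV)
    print(meta_path, api_path, swap_csv)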

environment.yaml (Normal file, 228 additions)
View File

@@ -0,0 +1,228 @@
name: graphdit
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- asttokens=2.4.1=pyhd8ed1ab_0
- blas=1.0=mkl
- brotli-python=1.0.9=py39h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- comm=0.2.2=pyhd8ed1ab_0
- debugpy=1.6.7=py39h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- executing=2.0.1=pyhd8ed1ab_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py39heeb90bb_0
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py39h06a4308_0
- importlib-metadata=7.1.0=pyha770c72_0
- importlib_metadata=7.1.0=hd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- ipykernel=6.29.4=pyh3099207_0
- ipython=8.18.1=pyh707e725_3
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.4=py39h06a4308_0
- jpeg=9e=h5eee18b_1
- jupyter_client=8.6.2=pyhd8ed1ab_0
- jupyter_core=5.7.2=py39hf3d152e_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=13.2.0=h77fa898_7
- libgomp=13.2.0=h77fa898_7
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libpng=1.6.39=h5eee18b_0
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libwebp-base=1.3.2=h5eee18b_0
- lz4-c=1.9.4=h6a678d5_1
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py39h5eee18b_1
- mkl_fft=1.3.8=py39h5eee18b_0
- mkl_random=1.2.4=py39hdb19cb5_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py39h06a4308_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.26.4=py39hb5e798b_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h9ca470c_2
- openssl=3.3.1=h4ab18f5_0
- packaging=24.0=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pip=24.0=py39h06a4308_0
- platformdirs=4.2.2=pyhd8ed1ab_0
- prompt-toolkit=3.0.46=pyha770c72_0
- psutil=5.9.8=py39hd1e30aa_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.18.0=pyhd8ed1ab_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.19=h955ad1f_1
- python_abi=3.9=2_cp39
- pytorch-mutex=1.0=cpu
- pyzmq=25.1.2=py39h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=69.5.1=py39h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.45.3=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.12=py39h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- tornado=6.4.1=py39hd3abc70_0
- traitlets=5.14.3=pyhd8ed1ab_0
- typing_extensions=4.12.2=pyha770c72_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.43.0=py39h06a4308_0
- xz=5.4.6=h5eee18b_1
- zeromq=4.3.5=h6a678d5_0
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.5=hc292b87_2
- pip:
- absl-py==2.1.0
- accelerate==0.34.2
- aiohttp==3.9.5
- aiosignal==1.3.1
- antlr4-python3-runtime==4.9.3
- astunparse==1.6.3
- async-timeout==4.0.3
- attrs==23.2.0
- beautifulsoup4==4.12.3
- bleach==6.1.0
- certifi==2024.2.2
- charset-normalizer==3.1.0
- cmake==3.29.3
- contourpy==1.2.1
- cycler==0.12.1
- defusedxml==0.7.1
- fastjsonschema==2.19.1
- fcd-torch==1.0.7
- filelock==3.14.0
- flatbuffers==24.3.25
- fonttools==4.52.4
- frozenlist==1.4.1
- fsspec==2024.5.0
- gast==0.5.4
- google-pasta==0.2.0
- grpcio==1.64.1
- h5py==3.11.0
- huggingface-hub==0.24.6
- hydra-core==1.3.2
- imageio==2.26.0
- importlib-resources==6.4.0
- joblib==1.2.0
- jsonschema==4.22.0
- jsonschema-specifications==2023.12.1
- jupyterlab-pygments==0.3.0
- keras==3.3.3
- kiwisolver==1.4.5
- libclang==18.1.1
- lightning-utilities==0.11.2
- lit==18.1.6
- markdown==3.6
- markdown-it-py==3.0.0
- markupsafe==2.1.5
- matplotlib==3.7.0
- mdurl==0.1.2
- mini-moses==1.0
- mistune==3.0.2
- ml-dtypes==0.3.2
- multidict==6.0.5
- namex==0.0.8
- nas-bench-201==2.1
- nasbench==1.0
- nbclient==0.10.0
- nbconvert==7.16.4
- nbformat==5.10.4
- networkx==3.0
- numpy==1.24.2
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu11==10.9.0.58
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu11==10.2.10.91
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu11==2.14.3
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.40
- nvidia-nvtx-cu11==11.7.91
- nvidia-nvtx-cu12==12.1.105
- omegaconf==2.3.0
- opt-einsum==3.3.0
- optree==0.11.0
- pandas==1.5.3
- pandocfilters==1.5.1
- pillow==10.3.0
- protobuf==3.20.3
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- pytorch-lightning==2.0.1
- pytz==2024.1
- pyyaml==6.0.1
- rdkit==2023.9.4
- referencing==0.35.1
- requests==2.32.2
- rich==13.7.1
- rpds-py==0.18.1
- safetensors==0.4.5
- scikit-learn==1.2.1
- scipy==1.13.1
- seaborn==0.13.2
- simplejson==3.19.2
- soupsieve==2.5
- tensorboard==2.16.2
- tensorboard-data-server==0.7.2
- tensorflow==2.16.1
- tensorflow-io-gcs-filesystem==0.37.0
- termcolor==2.4.0
- threadpoolctl==3.5.0
- tinycss2==1.3.0
- torch==2.0.0
- torch-geometric==2.3.0
- torchaudio==2.0.1+rocm5.4.2
- torchmetrics==0.11.4
- torchvision==0.15.1
- tqdm==4.64.1
- triton==2.0.0
- typing-extensions==4.12.0
- tzdata==2024.1
- urllib3==2.2.1
- webencodings==0.5.1
- werkzeug==3.0.3
- wrapt==1.16.0
- yacs==0.1.8
- yarl==1.9.4
- zipp==3.19.0
prefix: /home/stud/hanzhang/anaconda3/envs/graphdit
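
Recreating the environment should only need conda env create -f environment.yaml followed by conda activate graphdit; the machine-specific prefix: line above is informational and is replaced when the environment is created under a different conda install.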

View File

@@ -54,7 +54,9 @@ class BasicGraphMetrics(object):
covered_nodes = set()
direct_valid_count = 0
print(f"generated number: {len(generated)}")
print(f"generated: {generated}")
for graph in generated:
print(f"graph: {graph}")
node_types, edge_types = graph
direct_valid_flag = True
direct_valid_count += 1
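
For reference, each element of generated unpacked above is a (node_types, edge_types) tensor pair, matching what sample_batch returns in this change set. A toy illustration with dummy tensors (shapes illustrative only):

    # Dummy (node_types, edge_types) pairs: 7 NAS-Bench-201 node categories,
    # binary edge types, as used elsewhere in this diff.
    import torch

    generated = [(torch.randint(0, 7, (8,)), torch.randint(0, 2, (8, 8)))]
    for node_types, edge_types in generated:
        print(node_types.shape, edge_types.shape)  # torch.Size([8]) torch.Size([8, 8])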

View File

@@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
from naswot.score_networks import get_nasbench201_idx_score
from naswot import nasspace
from naswot import datasets as dt
@@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
# except NameError:
# base_path = pathlib.Path(os.getcwd()).parent[2]
base_path = '/nfs/data3/hanzhang/nasbenchDiT'
# base_path = '/nfs/data3/hanzhang/nasbenchDiT'
base_path = os.path.join(self.cfg.general.root, "..")
root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path
@@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
# Load the dataset to the memory
# Dataset has target property, root path, and transform
source = './NAS-Bench-201-v1_1-096897.pth'
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
self.dataset = dataset
# self.api = dataset.api
@@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
def test_dataloader(self):
return self.test_loader
def new_graphs_to_json(graphs, filename):
def new_graphs_to_json(graphs, filename, cfg):
source_name = "nasbench-201"
num_graph = len(graphs)
@@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
'num_active_nodes': len(active_nodes),
'transition_E': transition_E.tolist(),
}
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
import os
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f:
json.dump(meta_dict, f)
return meta_dict
@@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
json.dump(meta_dict, f)
return meta_dict
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
self.target_prop = target_prop
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.cfg = cfg
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
self.source = source
# self.api = API(source) # Initialize NAS-Bench-201 API
# print('API loaded')
@@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
return [f'{self.source}.pt']
def process(self):
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
source = self.cfg.general.nas_201
# self.api = API(source)
data_list = []
@@ -748,7 +753,8 @@ class Dataset(InMemoryDataset):
return edges,nodes
def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
# def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
def graph_to_graph_data(graph, idx, args, device):
# def graph_to_graph_data(graph):
ops = graph[1]
adj = graph[0]
@@ -797,7 +803,7 @@ class Dataset(InMemoryDataset):
args.batch_size = 128
args.GPU = '0'
args.dataset = 'cifar10'
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
args.api_loc = self.cfg.general.nas_201
args.data_loc = '../cifardata/'
args.seed = 777
args.init = ''
@@ -812,11 +818,12 @@ class Dataset(InMemoryDataset):
args.num_modules_per_stack = 3
args.num_labels = 1
searchspace = nasspace.get_search_space(args)
train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
self.swap_scores = []
import csv
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
with open(self.cfg.general.swap_result, 'r') as f:
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
reader = csv.reader(f)
header = next(reader)
data = [row for row in reader]
@@ -824,12 +831,15 @@ class Dataset(InMemoryDataset):
device = torch.device('cuda:2')
with tqdm(total = len_data) as pbar:
active_nodes = set()
file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
import os
# file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
with open(file_path, 'r') as f:
graph_list = json.load(f)
i = 0
flex_graph_list = []
flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
# flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json')
for graph in graph_list:
print(f'iterate every graph in graph_list, here is {i}')
arch_info = graph['arch_str']
@@ -837,7 +847,8 @@ class Dataset(InMemoryDataset):
for op in ops:
if op not in active_nodes:
active_nodes.add(op)
data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
# data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)
i += 1
if data is None:
pbar.update(1)
@@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
self.task = task_name
self.task_type = tasktype_dict.get(task_name, "regression")
self.ensure_connected = cfg.model.ensure_connected
self.cfg = cfg
# self.api = dataset.api
datadir = cfg.dataset.datadir
@@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos):
# len_ops.add(len(ops))
# graphs.append((adj_matrix, ops))
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))
# check first five graphs
for i in range(5):
print(f'graph {i} : {graphs[i]}')
# print(f'ops_type: {ops_type}')
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
self.base_path = base_path
self.active_nodes = meta_dict['active_nodes']
self.max_n_nodes = meta_dict['max_n_nodes']
@@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
'transition_E': tansition_E.tolist(),
}
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
json.dump(meta_dict, f)
return meta_dict
if __name__ == "__main__":
dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
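
Dataset.process now reads precomputed SWAP scores from cfg.general.swap_result instead of a hard-coded CSV. A self-contained sketch of that loader, assuming (as the code above does) a header row and one float score in the first column, indexed by NAS-Bench-201 architecture index:

    # Minimal sketch of the SWAP-score loader used in Dataset.process above.
    import csv

    def load_swap_scores(path):
        with open(path, newline="") as f:
            reader = csv.reader(f)
            next(reader)                    # skip the header row
            return [float(row[0]) for row in reader]

    # scores = load_swap_scores(cfg.general.swap_result)
    # scores[idx] is the SWAP score of architecture idx in NAS-Bench-201.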

View File

@@ -23,6 +23,9 @@ class Graph_DiT(pl.LightningModule):
self.test_only = cfg.general.test_only
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
from nas_201_api import NASBench201API as API
self.api = API(cfg.general.nas_201)
input_dims = dataset_infos.input_dims
output_dims = dataset_infos.output_dims
nodes_dist = dataset_infos.nodes_dist
@@ -41,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
self.args.batch_size = 128
self.args.GPU = '0'
self.args.dataset = 'cifar10-valid'
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.args.api_loc = cfg.general.nas_201
self.args.data_loc = '../cifardata/'
self.args.seed = 777
self.args.init = ''
@@ -79,6 +82,7 @@ class Graph_DiT(pl.LightningModule):
self.node_dist = nodes_dist
self.active_index = active_index
self.dataset_info = dataset_infos
self.cur_epoch = 0
self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
@@ -162,25 +166,81 @@ class Graph_DiT(pl.LightningModule):
return pred
def training_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
if self.cfg.general.type != 'accelerator' and self.current_epoch > self.cfg.train.n_epochs / 5 * 4:
samples_left_to_generate = self.cfg.general.samples_to_generate
samples_left_to_save = self.cfg.general.samples_to_save
chains_left_to_save = self.cfg.general.chains_to_save
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
X, E = dense_data.X, dense_data.E
noisy_data = self.apply_noise(X, E, data.y, node_mask)
pred = self.forward(noisy_data)
loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
samples, all_ys, batch_id = [], [], 0
def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
rewards = []
if reward_model == 'swap':
import csv
with open(self.cfg.general.swap_result, 'r') as f:
reader = csv.reader(f)
header = next(reader)
data = [row for row in reader]
swap_scores = [float(row[0]) for row in data]
for graph in graphs:
node_tensor = graph[0]
node = node_tensor.cpu().numpy().tolist()
def nodes_to_arch_str(nodes):
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
nodes_str = [num_to_op[node] for node in nodes]
arch_str = '|' + nodes_str[1] + '~0|+' + \
'|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
'|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
return arch_str
arch_str = nodes_to_arch_str(node)
reward = swap_scores[self.api.query_index_by_arch(arch_str)]
rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
old_log_probs = None
bs = 1 * self.cfg.train.batch_size
to_generate = min(samples_left_to_generate, bs)
to_save = min(samples_left_to_save, bs)
chains_save = min(chains_left_to_save, bs)
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
# samples = samples + cur_sample
samples.append(cur_sample)
reward = graph_reward_fn(cur_sample, device=self.device)
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) #
if old_log_probs is None:
old_log_probs = log_probs.clone()
ratio = torch.exp(log_probs - old_log_probs)
print(f"ratio: {ratio.shape}, advantages: {advantages.shape}")
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - self.cfg.ppo.clip_param, 1.0 + self.cfg.ppo.clip_param)
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
return {'loss': loss}
else:
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
X, E = dense_data.X, dense_data.E
noisy_data = self.apply_noise(X, E, data.y, node_mask)
pred = self.forward(noisy_data)
loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
log=i % self.log_every_steps == 0)
# print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
log=i % self.log_every_steps == 0)
# print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
log=i % self.log_every_steps == 0)
self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
print(f"training loss: {loss}")
with open("training-loss.csv", "a") as f:
f.write(f"{loss}, {i}\n")
return {'loss': loss}
self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
print(f"training loss: {loss}")
with open("training-loss.csv", "a") as f:
f.write(f"{loss}, {i}\n")
return {'loss': loss}
def configure_optimizers(self):
@@ -196,14 +256,15 @@ class Graph_DiT(pl.LightningModule):
def on_train_epoch_start(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
# if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
print("Starting train epoch {}/{}...".format(self.cur_epoch, self.cfg.train.n_epochs))
self.start_epoch_time = time.time()
self.train_loss.reset()
self.train_metrics.reset()
def on_train_epoch_end(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
log = True
else:
log = False
@@ -240,6 +301,7 @@ class Graph_DiT(pl.LightningModule):
self.val_X_logp.compute(), self.val_E_logp.compute()]
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
# if self.cur_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
with open("validation-metrics.csv", "a") as f:
@@ -283,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
num_examples = self.val_y_collection.size(0)
batch_y = self.val_y_collection[start_index:start_index + to_generate]
all_ys.append(batch_y)
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
save_final=to_save,
keep_chain=chains_save,
number_chain_steps=self.number_chain_steps))
number_chain_steps=self.number_chain_steps)
samples.extend(cur_sample)
# samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
# save_final=to_save,
# keep_chain=chains_save,
# number_chain_steps=self.number_chain_steps))
ident += to_generate
start_index += to_generate
@@ -336,7 +403,7 @@ class Graph_DiT(pl.LightningModule):
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
f"Test Edge type KL: {metrics[2] :.2f}")
## final epcoh
## final epoch
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
samples_left_to_save = self.cfg.general.final_model_samples_to_save
chains_left_to_save = self.cfg.general.final_model_chains_to_save
@@ -359,9 +426,9 @@ class Graph_DiT(pl.LightningModule):
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
samples = samples + cur_sample
samples.extend(cur_sample)
all_ys.append(batch_y)
batch_id += to_generate
@@ -601,6 +668,12 @@ class Graph_DiT(pl.LightningModule):
assert (E == torch.transpose(E, 1, 2)).all()
if self.cfg.general.type != 'accelerator':
if self.trainer.training or self.trainer.validating:
total_log_probs = torch.zeros([self.cfg.general.samples_to_generate, 10], device=self.device)
elif self.trainer.testing:
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate, 10], device=self.device)
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
for s_int in reversed(range(0, self.T)):
s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
@@ -609,21 +682,24 @@ class Graph_DiT(pl.LightningModule):
t_norm = t_array / self.T
# Sample z_s
sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
total_log_probs += log_probs
# Sample
sampled_s = sampled_s.mask(node_mask, collapse=True)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
molecule_list = []
graph_list = []
for i in range(batch_size):
n = n_nodes[i]
atom_types = X[i, :n].cpu()
node_types = X[i, :n].cpu()
edge_types = E[i, :n, :n].cpu()
molecule_list.append([atom_types, edge_types])
graph_list.append((node_types , edge_types))
return molecule_list
total_log_probs = torch.sum(total_log_probs, dim=-1)
return graph_list, total_log_probs
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
"""Samples from zs ~ p(zs | zt). Only used during sampling.
@@ -675,6 +751,14 @@ class Graph_DiT(pl.LightningModule):
# with condition = P_t(A_{t-1} |A_t, y)
prob_X, prob_E, pred = get_prob(noisy_data)
log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1)) # bs, n
log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1)) # bs, n, n
# Sum the log_prob across dimensions for total log_prob
log_prob_X = log_prob_X.sum(dim=-1)
log_prob_E = log_prob_E.sum(dim=(1, 2))
log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)
### Guidance
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
@@ -810,4 +894,4 @@ class Graph_DiT(pl.LightningModule):
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs
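
The new training_step branch above turns late-stage training into a PPO-style update: sampled architectures are encoded as NAS-Bench-201 arch strings of the form |op~0|+|op~0|op~1|+|op~0|op~1|op~2| (e.g. |nor_conv_1x1~0|+|skip_connect~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|skip_connect~2|), looked up in the SWAP table as a reward, and optimized with a clipped surrogate over the sampling log-probabilities. A standalone sketch of that objective (names are illustrative; clip_param mirrors cfg.ppo.clip_param):

    # Minimal sketch of the PPO clipped surrogate computed in training_step above.
    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
        # Normalize rewards into advantages (eps avoids division by zero).
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        ratio = torch.exp(log_probs - old_log_probs)   # importance ratio
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        return torch.max(unclipped, clipped).mean()    # pessimistic clipped bound

    # loss = ppo_clipped_loss(log_probs, old_log_probs, graph_reward_fn(samples))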

View File

@@ -177,32 +177,92 @@ def test(cfg: DictConfig):
os.chdir(cfg.general.resume.split("checkpoints")[0])
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
model = Graph_DiT(cfg=cfg, **model_kwargs)
trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad,
# accelerator="cpu",
accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu",
devices=[cfg.general.gpu_number]
if torch.cuda.is_available() and cfg.general.gpus > 0
else None,
max_epochs=cfg.train.n_epochs,
enable_checkpointing=False,
check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
val_check_interval=cfg.train.val_check_interval,
strategy="ddp" if cfg.general.gpus > 1 else "auto",
enable_progress_bar=cfg.general.enable_progress_bar,
callbacks=[],
reload_dataloaders_every_n_epochs=0,
logger=[],
)
if not cfg.general.test_only:
print("start testing fit method")
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
if cfg.general.save_model:
trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
trainer.test(model, datamodule=datamodule)
if cfg.general.type == "accelerator":
graph_dit_model = model
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
accelerator_config = ProjectConfiguration(
project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
automatic_checkpoint_naming=True,
total_limit=cfg.general.number_checkpoint_limit,
)
accelerator = Accelerator(
mixed_precision='no',
project_config=accelerator_config,
# gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
)
optimizer = graph_dit_model.configure_optimizers()
train_dataloader = datamodule.train_dataloader()
train_dataloader = accelerator.prepare(train_dataloader)
val_dataloader = datamodule.val_dataloader()
val_dataloader = accelerator.prepare(val_dataloader)
test_dataloader = datamodule.test_dataloader()
test_dataloader = accelerator.prepare(test_dataloader)
optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
# train_epoch
from pytorch_lightning import seed_everything
seed_everything(cfg.train.seed)
for epoch in range(cfg.train.n_epochs):
print(f"Epoch {epoch}")
graph_dit_model.train()
graph_dit_model.cur_epoch = epoch
graph_dit_model.on_train_epoch_start()
for batch in train_dataloader:
optimizer.zero_grad()
loss = graph_dit_model.training_step(batch, epoch)['loss']
accelerator.backward(loss)
optimizer.step()
graph_dit_model.on_train_epoch_end()
for batch in val_dataloader:
if epoch % cfg.train.check_val_every_n_epoch == 0:
graph_dit_model.eval()
graph_dit_model.on_validation_epoch_start()
graph_dit_model.validation_step(batch, epoch)
graph_dit_model.on_validation_epoch_end()
# test_epoch
graph_dit_model.test()
graph_dit_model.on_test_epoch_start()
for batch in test_dataloader:
graph_dit_model.test_step(batch, epoch)
graph_dit_model.on_test_epoch_end()
elif cfg.general.type == "Trainer":
trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad,
# accelerator="cpu",
accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu",
devices=[cfg.general.gpu_number]
if torch.cuda.is_available() and cfg.general.gpus > 0
else None,
max_epochs=cfg.train.n_epochs,
enable_checkpointing=False,
check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
val_check_interval=cfg.train.val_check_interval,
strategy="ddp" if cfg.general.gpus > 1 else "auto",
enable_progress_bar=cfg.general.enable_progress_bar,
callbacks=[],
reload_dataloaders_every_n_epochs=0,
logger=[],
)
if not cfg.general.test_only:
print("start testing fit method")
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
if cfg.general.save_model:
trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
trainer.test(model, datamodule=datamodule)
if __name__ == "__main__":
test()
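
The cfg.general.type switch selects between the Lightning Trainer and a hand-rolled accelerate loop. Stripped to its core, the manual path follows the standard Accelerate pattern; a runnable sketch with a dummy model (the real loop drives training_step and the epoch hooks as shown above):

    # Runnable sketch of the loop behind cfg.general.type == 'accelerator';
    # model, data, and epoch count here are dummies.
    import torch
    from accelerate import Accelerator

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    loader = torch.utils.data.DataLoader(torch.randn(32, 4), batch_size=8)

    accelerator = Accelerator(gradient_accumulation_steps=1)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    for epoch in range(2):
        for batch in loader:
            optimizer.zero_grad()
            loss = model(batch).pow(2).mean()  # stands in for training_step(...)['loss']
            accelerator.backward(loss)         # replaces loss.backward()
            optimizer.step()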

View File

@@ -83,7 +83,8 @@ class TaskModel():
return adj_ops_pairs
def feature_from_adj_and_ops(adj, ops):
return np.concatenate([adj.flatten(), ops])
filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
# filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
filename = '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/nasbench-201-graph.json'
graphs = read_adj_ops_from_json(filename)
adjs = []
opss = []
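
feature_from_adj_and_ops simply concatenates the flattened adjacency matrix with the op encoding; a toy example of the resulting feature (dummy 8-node cell):

    # Toy illustration of feature_from_adj_and_ops with a dummy 8-node cell.
    import numpy as np

    adj = np.zeros((8, 8), dtype=np.float32)   # adjacency matrix
    ops = np.arange(8, dtype=np.float32)       # integer-encoded op per node
    feature = np.concatenate([adj.flatten(), ops])
    assert feature.shape == (8 * 8 + 8,)       # 72-dimensional feature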

graph_dit/swap_results.csv (Normal file, 15626 additions)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long