Compare commits

5 commits · comparing nasbench ... dd91da7921

| SHA1 |
|---|
| dd91da7921 |
| 586e354971 |
| e376f38dcb |
| 7149b49a39 |
| 83f9345028 |
.gitignore (vendored, +6)

```diff
@@ -1,3 +1,9 @@
+cifardata/
+NAS-Bench-201-*
+*.csv.gz
+*.meta.json
+*.pt
+*.zip
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
```
2401.13858v2 (2).pdf (new file, binary): Binary file not shown.
environment.yml (new file, +20)

```diff
@@ -0,0 +1,20 @@
+name: graphdit39
+dependencies:
+  - python=3.9
+  - fcd_torch==1.0.7
+  - hydra-core==1.3.2
+  - imageio==2.26.0
+  - joblib==1.2.0
+  - matplotlib==3.7.0
+  - mini_moses==1.0
+  - networkx==3.0
+  - numpy==1.24.2
+  - omegaconf==2.3.0
+  - pandas==1.5.3
+  - pytorch_lightning==2.0.1
+  - rdkit==2023.9.4
+  - scikit_learn==1.2.1
+  - torch==2.0.0
+  - torch_geometric==2.3.0
+  - torchmetrics==0.11.4
+  - tqdm==4.64.1
```
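The pins above drive everything below; a quick runtime check that the key ones actually resolved is sketched here (assumes the canonical dashed distribution names; not part of this PR):

```python
# Minimal sanity check for the pinned environment.
import importlib.metadata as md

pins = {
    "torch": "2.0.0",
    "torch-geometric": "2.3.0",
    "pytorch-lightning": "2.0.1",
    "numpy": "1.24.2",
}

for pkg, expected in pins.items():
    installed = md.version(pkg)
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{pkg}=={installed}: {status}")
```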
```diff
@@ -70,7 +70,7 @@ class DataModule(AbstractDataModule):
         # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         #     base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/home/stud/hanzhang/nasbenchDiT'
+        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
```
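Both the old and new values hardcode a machine-specific absolute path. A hypothetical refactor (not part of this PR; the `NASBENCHDIT_BASE` variable name is made up) would keep the new NFS default while allowing an override:

```python
import os

# Hypothetical: keep the NFS default, but let other machines override it.
def resolve_base_path() -> str:
    return os.environ.get("NASBENCHDIT_BASE", "/nfs/data3/hanzhang/nasbenchDiT")

print(resolve_base_path())
```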
```diff
@@ -408,6 +408,7 @@ def new_graphs_to_json(graphs, filename):
         adj = graph[0]

         n_node = len(ops)
+        print(n_node)
         n_edge = len(ops)
         n_node_list.append(n_node)
         n_edge_list.append(n_edge)
@@ -489,7 +490,7 @@ def new_graphs_to_json(graphs, filename):
         'transition_E': transition_E.tolist(),
     }

-    with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
         json.dump(meta_dict, f)

     return meta_dict
@@ -655,7 +656,7 @@ def graphs_to_json(graphs, filename):
 class Dataset(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
-        source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
         # self.api = API(source) # Initialize NAS-Bench-201 API
         # print('API loaded')
@@ -676,8 +677,8 @@ class Dataset(InMemoryDataset):
         return [f'{self.source}.pt']

     def process(self):
-        source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
-        self.api = API(source)
+        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        # self.api = API(source)

         data_list = []
         # len_data = len(self.api)
```
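For reference, the `API` object being commented out here is the NAS-Bench-201 query interface. Loading it looks roughly like this: a sketch assuming the `nas_201_api` package and the checkpoint path used throughout this PR; the query methods shown are the ones the old code called:

```python
from nas_201_api import NASBench201API as API

# Load the benchmark file this PR points at (slow; the .pth is ~2 GB).
api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')

print(len(api))                              # number of architectures (15625)
print(api.arch(0))                           # architecture string of index 0
info = api.query_meta_info_by_index(0)       # ArchResults for index 0
results = api.query_by_index(0, 'cifar100')  # per-seed results dict
```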
```diff
@@ -710,14 +711,24 @@ class Dataset(InMemoryDataset):
             return True

         def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8, random_ratio=0.5):
+            # print(ori_nodes)
+            # print(ori_edges)
+
+            ori_edges = np.array(ori_edges)
+            # ori_nodes = np.array(ori_nodes)
             nasbench_201_node_num = 8
             # random.seed(random_seed)
             nodes_num = random.randint(min_nodes, max_nodes)
             # print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num}, random_seed: {random_seed}, random_ratio: {random_ratio}')
             add_num = nodes_num - nasbench_201_node_num
             # ori_nodes, ori_edges = parse_architecture_string(arch_str)
-            add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)]
+            add_nodes = []
+            print(f'add_num: {add_num}')
+            for i in range(add_num):
+                add_nodes.append(random.choice(num_to_op[1:-1]))
             # print(add_nodes)
+            print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
+            print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
             nodes = ori_nodes[:-1] + add_nodes + ['output']
             edges = np.zeros((nodes_num, nodes_num))
             edges[:6, :6] = ori_edges[:6, :6]
@@ -727,12 +738,18 @@ class Dataset(InMemoryDataset):
                     rand = random.random()
                     if rand < random_ratio:
                         edges[i, j] = 1
-            return nodes, edges
+            if nodes_num < max_nodes:
+                edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)), 'constant', constant_values=0)
+                while len(nodes) < max_nodes:
+                    nodes.append('none')
+            print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
+            return edges, nodes

         def get_nasbench_201_val(idx):
             pass

-        def graph_to_graph_data(graph, idx):
+        # def graph_to_graph_data(graph, idx):
+        def graph_to_graph_data(graph):
             ops = graph[1]
             adj = graph[0]
             nodes = []
```
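The helper above grows a sampled cell to a random size and then pads everything out to `max_nodes`. A self-contained illustration of just that padding step (toy 4-node cell, labels assumed for the example):

```python
import numpy as np

# A 4-node toy cell padded to max_nodes=6 with zero rows/columns in the
# adjacency and 'none' filler nodes, mirroring the np.pad branch above.
nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'output']
max_nodes = 6
edges = np.zeros((len(nodes), len(nodes)))
edges[0, 1] = edges[1, 2] = edges[2, 3] = 1

pad = max_nodes - len(nodes)
edges = np.pad(edges, ((0, pad), (0, pad)), 'constant', constant_values=0)
while len(nodes) < max_nodes:
    nodes.append('none')

print(edges.shape, nodes)  # (6, 6) [..., 'none', 'none']
```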
```diff
@@ -753,58 +770,95 @@ class Dataset(InMemoryDataset):
             edge_index = torch.tensor(edges_list, dtype=torch.long).t()
             edge_type = torch.tensor(edge_type, dtype=torch.long)
             edge_attr = edge_type
-            # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
-            y = get_nasbench_201_val(idx)
+            y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
+            # y = get_nasbench_201_val(idx)
             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
             return data

         graph_list = []

         with tqdm(total = len_data) as pbar:
             active_nodes = set()
-            for i in range(len_data):
-                arch_info = self.api.query_meta_info_by_index(i)
-                results = self.api.query_by_index(i, 'cifar100')
+            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            with open(file_path, 'r') as f:
+                graph_list = json.load(f)
+            i = 0
+            flex_graph_list = []
+            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            for graph in graph_list:
+                # arch_info = self.api.query_meta_info_by_index(i)
+                # results = self.api.query_by_index(i, 'cifar100')
+                arch_info = graph['arch_str']
+                # results =
                 # nodes, edges = parse_architecture_string(arch_info.arch_str)
-                ops, adj_matrix = parse_architecture_string(arch_info.arch_str)
+                # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4)
+                ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4)
                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
                 for op in ops:
                     if op not in active_nodes:
                         active_nodes.add(op)

-                graph_list.append({
-                    "adj_matrix": adj_matrix,
-                    "ops": ops,
-                    "idx": i,
-                    "train": [{
-                        "iepoch": result.get_train()['iepoch'],
-                        "loss": result.get_train()['loss'],
-                        "accuracy": result.get_train()['accuracy'],
-                        "cur_time": result.get_train()['cur_time'],
-                        "all_time": result.get_train()['all_time'],
-                        "seed": seed,
-                    }for seed, result in results.items()],
-                    "valid": [{
-                        "iepoch": result.get_eval('x-valid')['iepoch'],
-                        "loss": result.get_eval('x-valid')['loss'],
-                        "accuracy": result.get_eval('x-valid')['accuracy'],
-                        "cur_time": result.get_eval('x-valid')['cur_time'],
-                        "all_time": result.get_eval('x-valid')['all_time'],
-                        "seed": seed,
-                    }for seed, result in results.items()],
-                    "test": [{
-                        "iepoch": result.get_eval('x-test')['iepoch'],
-                        "loss": result.get_eval('x-test')['loss'],
-                        "accuracy": result.get_eval('x-test')['accuracy'],
-                        "cur_time": result.get_eval('x-test')['cur_time'],
-                        "all_time": result.get_eval('x-test')['all_time'],
-                        "seed": seed,
-                    }for seed, result in results.items()]
-                })
-                data = graph_to_graph_data((adj_matrix, ops))
+                data = graph_to_graph_data((adj_matrix, ops))
+                # with open(flex_graph_path, 'a') as f:
+                #     flex_graph = {
+                #         'adj_matrix': adj_matrix,
+                #         'ops': ops,
+                #     }
+                #     json.dump(flex_graph, f)
+                flex_graph_list.append({
+                    'adj_matrix': adj_matrix,
+                    'ops': ops,
+                })
                 if i < 3:
                     print(f"i={i}, data={data}")
                     with open(f'{i}.json', 'w') as f:
                         f.write(str(data.x))
                         f.write(str(data.edge_index))
                         f.write(str(data.edge_attr))
                 data_list.append(data)

+                # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8, random_ratio=0.5)
+                # data_list.append(graph_to_graph_data((new_adj, new_ops)))
+                new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5)
+                flex_graph_list.append({
+                    'adj_matrix': new_adj.tolist(),
+                    'ops': new_ops,
+                })
+                # with open(flex_graph_path, 'w') as f:
+                #     flex_graph = {
+                #         'adj_matrix': new_adj.tolist(),
+                #         'ops': new_ops,
+                #     }
+                #     json.dump(flex_graph, f)
+                data_list.append(graph_to_graph_data((new_adj, new_ops)))

+                # graph_list.append({
+                #     "adj_matrix": adj_matrix,
+                #     "ops": ops,
+                #     "arch_str": arch_info.arch_str,
+                #     "idx": i,
+                #     "train": [{
+                #         "iepoch": result.get_train()['iepoch'],
+                #         "loss": result.get_train()['loss'],
+                #         "accuracy": result.get_train()['accuracy'],
+                #         "cur_time": result.get_train()['cur_time'],
+                #         "all_time": result.get_train()['all_time'],
+                #         "seed": seed,
+                #     }for seed, result in results.items()],
+                #     "valid": [{
+                #         "iepoch": result.get_eval('x-valid')['iepoch'],
+                #         "loss": result.get_eval('x-valid')['loss'],
+                #         "accuracy": result.get_eval('x-valid')['accuracy'],
+                #         "cur_time": result.get_eval('x-valid')['cur_time'],
+                #         "all_time": result.get_eval('x-valid')['all_time'],
+                #         "seed": seed,
+                #     }for seed, result in results.items()],
+                #     "test": [{
+                #         "iepoch": result.get_eval('x-test')['iepoch'],
+                #         "loss": result.get_eval('x-test')['loss'],
+                #         "accuracy": result.get_eval('x-test')['accuracy'],
+                #         "cur_time": result.get_eval('x-test')['cur_time'],
+                #         "all_time": result.get_eval('x-test')['all_time'],
+                #         "seed": seed,
+                #     }for seed, result in results.items()]
+                # })
                 pbar.update(1)

         for graph in graph_list:
@@ -818,6 +872,8 @@ class Dataset(InMemoryDataset):
             graph['ops'] = ops
         with open(f'nasbench-201-graph.json', 'w') as f:
             json.dump(graph_list, f)
+        with open(flex_graph_path, 'w') as f:
+            json.dump(flex_graph_list, f)

         torch.save(self.collate(data_list), self.processed_paths[0])
```
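For context, the `torch.save(self.collate(...))` call above is the standard PyTorch Geometric `InMemoryDataset` pattern: `process()` collates a list of `Data` objects into one storage tensor plus slice index, and later instantiations just load the cached file. A minimal sketch against the pinned torch_geometric==2.3.0 (toy graphs, hypothetical class name):

```python
import torch
from torch_geometric.data import Data, InMemoryDataset

class TinyDataset(InMemoryDataset):
    """Minimal sketch of the collate/save/load pattern used above."""
    def __init__(self, root):
        super().__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # Two toy one-node graphs; the real code builds these from NAS-Bench-201.
        data_list = [Data(x=torch.zeros(1, 4)), Data(x=torch.ones(1, 4))]
        torch.save(self.collate(data_list), self.processed_paths[0])

ds = TinyDataset('/tmp/tiny_demo')
print(len(ds))  # 2
```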
```diff
@@ -981,18 +1037,29 @@ class Dataset_origin(InMemoryDataset):

     torch.save(self.collate(data_list), self.processed_paths[0])

-def parse_architecture_string(arch_str):
-    steps = arch_str.split('+')
+def parse_architecture_string(arch_str, padding=0):
+    # print(arch_str)
     nodes = ['input'] # Start with input node
-    adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
+    ori_adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
                    [0, 0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 0, 1],
                    [0, 0, 0, 0, 0, 0, 0, 1],
                    [0, 0, 0, 0, 0, 0, 0, 1],
-                   [0, 0, 0, 0, 0, 0, 0, 0]])
+                   # [0, 0, 0, 0, 0, 0, 0, 0]])
+                   [0, 0, 0, 0, 0, 0, 0, 0]]
+    # adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
+    adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
+               [0, 0, 0, 1, 0, 1, 0, 0],
+               [0, 0, 0, 0, 0, 0, 1, 0],
+               [0, 0, 0, 0, 0, 0, 1, 0],
+               [0, 0, 0, 0, 0, 0, 0, 1],
+               [0, 0, 0, 0, 0, 0, 0, 1],
+               [0, 0, 0, 0, 0, 0, 0, 1],
+               # [0, 0, 0, 0, 0, 0, 0, 0]])
+               [0, 0, 0, 0, 0, 0, 0, 0]]
+    steps = arch_str.split('+')
     steps_coding = ['0', '0', '1', '0', '1', '2']
     cont = 0
@@ -1004,7 +1071,21 @@ def parse_architecture_string(arch_str):
             cont += 1
         nodes.append(n)
     nodes.append('output') # Add output node
-    return nodes, adj_mat
+    ori_nodes = nodes.copy()
+    if padding > 0:
+        for i in range(padding):
+            nodes.append('none')
+        for adj_row in adj_mat:
+            for i in range(padding):
+                adj_row.append(0)
+        # adj_mat = np.append(adj_mat, np.zeros((padding, len(nodes))))
+        for i in range(padding):
+            adj_mat.append([0] * len(nodes))
+    # print(nodes)
+    # print(adj_mat)
+    # print(len(adj_mat))
+    # print(f'len(ori_nodes): {len(ori_nodes)}, len(nodes): {len(nodes)}')
+    return nodes, adj_mat, ori_nodes, ori_adj_mat

def create_adj_matrix_and_ops(nodes, edges):
    num_nodes = len(nodes)
```
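NAS-Bench-201 architecture strings have the `|op~input_idx|+...` form, and with `padding=4` the new signature returns both the padded 12-node graph and the original 8-node one. A hypothetical quick check, assuming the function above is importable:

```python
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'

nodes, adj, ori_nodes, ori_adj = parse_architecture_string(arch_str, padding=4)
print(len(ori_nodes), len(nodes))     # 8 12   (input + 6 ops + output, then 4 'none' pads)
print(len(ori_adj), len(ori_adj[0]))  # 8 8
print(len(adj), len(adj[0]))          # 12 12
```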
```diff
@@ -1046,6 +1127,7 @@ class DataInfos(AbstractDatasetInfos):

        adj_ops_pairs = []
        for item in data:
+           print(item)
            adj_matrix = np.array(item['adj_matrix'])
            ops = item['ops']
            ops = [op_type[op] for op in ops]
@@ -1066,12 +1148,12 @@ class DataInfos(AbstractDatasetInfos):
            # ops_type[op] = len(ops_type)
            # len_ops.add(len(ops))
            # graphs.append((adj_matrix, ops))
-       graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+       graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')

        # check first five graphs
        for i in range(5):
            print(f'graph {i} : {graphs[i]}')
-       print(f'ops_type: {ops_type}')
+       # print(f'ops_type: {ops_type}')

        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
        self.base_path = base_path
@@ -1280,11 +1362,11 @@ def compute_meta(root, source_name, train_index, test_index):
        'transition_E': tansition_E.tolist(),
    }

-   with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+   with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
        json.dump(meta_dict, f)

    return meta_dict


if __name__ == "__main__":
-   dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
+   dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
```
```diff
@@ -3,9 +3,9 @@ import torch.nn.functional as F
 import pytorch_lightning as pl
 import time
 import os
-from naswot.score_networks import get_nasbench201_nodes_score
-from naswot import nasspace
-from naswot import datasets
+# from naswot.score_networks import get_nasbench201_nodes_score
+# from naswot import nasspace
+# from naswot import datasets
 from models.transformer import Denoiser
 from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition

@@ -41,7 +41,7 @@ class Graph_DiT(pl.LightningModule):
         self.args.batch_size = 128
         self.args.GPU = '0'
         self.args.dataset = 'cifar10-valid'
-        self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.args.data_loc = '../cifardata/'
         self.args.seed = 777
         self.args.init = ''
@@ -59,10 +59,10 @@ class Graph_DiT(pl.LightningModule):
         if 'valid' in self.args.dataset:
             self.args.dataset = self.args.dataset.replace('-valid', '')
         print('graph_dit starts to get searchspace of nasbench201')
-        self.searchspace = nasspace.get_search_space(self.args)
+        # self.searchspace = nasspace.get_search_space(self.args)
         print('searchspace of nasbench201 is obtained')
         print('graphdit starts to get train_loader')
-        self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
+        # self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
         print('train_loader is obtained')

         self.cfg = cfg
```
```diff
@@ -162,7 +162,7 @@ class Graph_DiT(pl.LightningModule):
         return pred

     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
@@ -222,7 +222,7 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask, collapse=False)
@@ -315,7 +315,7 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
```
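These three identical changes bump `num_classes` from 8 to 12, presumably because node indices can now reach 11 with the flexible 12-node graphs: `F.one_hot` raises at runtime if any index is >= `num_classes`. A quick illustration:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 2, 11])          # node labels, now up to index 11
print(F.one_hot(x, num_classes=12).shape)  # torch.Size([3, 12])
# F.one_hot(x, num_classes=8) would raise: class values must be < num_classes
```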
```diff
@@ -686,120 +686,120 @@ class Graph_DiT(pl.LightningModule):
         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

-        # sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
+        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

         # sample multiple times and get the best score arch...

-        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-        op_type = {
-            'input': 0,
-            'nor_conv_1x1': 1,
-            'nor_conv_3x3': 2,
-            'avg_pool_3x3': 3,
-            'skip_connect': 4,
-            'none': 5,
-            'output': 6,
-        }
-        def check_valid_graph(nodes, edges):
-            nodes = [num_to_op[i] for i in nodes]
-            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
-                return False
-            if nodes[0] != 'input' or nodes[-1] != 'output':
-                return False
-            for i in range(0, len(nodes)):
-                if edges[i][i] == 1:
-                    return False
-            for i in range(1, len(nodes) - 1):
-                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
-                    return False
-            for i in range(0, len(nodes)):
-                for j in range(i, len(nodes)):
-                    if edges[i, j] == 1 and nodes[j] == 'input':
-                        return False
-            for i in range(0, len(nodes)):
-                for j in range(i, len(nodes)):
-                    if edges[i, j] == 1 and nodes[i] == 'output':
-                        return False
-            flag = 0
-            for i in range(0, len(nodes)):
-                if edges[i, -1] == 1:
-                    flag = 1
-                    break
-            if flag == 0: return False
-            return True
+        # num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+        # op_type = {
+        #     'input': 0,
+        #     'nor_conv_1x1': 1,
+        #     'nor_conv_3x3': 2,
+        #     'avg_pool_3x3': 3,
+        #     'skip_connect': 4,
+        #     'none': 5,
+        #     'output': 6,
+        # }
+        # def check_valid_graph(nodes, edges):
+        #     nodes = [num_to_op[i] for i in nodes]
+        #     if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
+        #         return False
+        #     if nodes[0] != 'input' or nodes[-1] != 'output':
+        #         return False
+        #     for i in range(0, len(nodes)):
+        #         if edges[i][i] == 1:
+        #             return False
+        #     for i in range(1, len(nodes) - 1):
+        #         if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
+        #             return False
+        #     for i in range(0, len(nodes)):
+        #         for j in range(i, len(nodes)):
+        #             if edges[i, j] == 1 and nodes[j] == 'input':
+        #                 return False
+        #     for i in range(0, len(nodes)):
+        #         for j in range(i, len(nodes)):
+        #             if edges[i, j] == 1 and nodes[i] == 'output':
+        #                 return False
+        #     flag = 0
+        #     for i in range(0, len(nodes)):
+        #         if edges[i, -1] == 1:
+        #             flag = 1
+        #             break
+        #     if flag == 0: return False
+        #     return True

-        class Args:
-            pass
+        # class Args:
+        #     pass

-        def get_score(sampled_s):
-            x_list = sampled_s.X.unbind(dim=0)
-            e_list = sampled_s.E.unbind(dim=0)
-            valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
-            from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
-            score = []
+        # def get_score(sampled_s):
+        #     x_list = sampled_s.X.unbind(dim=0)
+        #     e_list = sampled_s.E.unbind(dim=0)
+        #     valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
+        #     from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
+        #     score = []

-            for i in range(len(x_list)):
-                if valid_rlt[i]:
-                    nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
-                    # edges = e_list[i].cpu().numpy()
-                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), args=self.args))
-                else:
-                    score.append(-1)
-            return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+        #     for i in range(len(x_list)):
+        #         if valid_rlt[i]:
+        #             nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
+        #             # edges = e_list[i].cpu().numpy()
+        #             score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), args=self.args))
+        #         else:
+        #             score.append(-1)
+        #     return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)

-        sample_num = 10
-        best_arch = None
-        best_score_int = -1e8
-        score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
+        # sample_num = 10
+        # best_arch = None
+        # best_score_int = -1e8
+        # score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8

-        for i in range(sample_num):
-            sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
-            score = get_score(sampled_s)
-            print(f'score: {score}')
-            print(f'score.shape: {score.shape}')
-            print(f'torch.sum(score): {torch.sum(score)}')
-            sum_score = torch.sum(score)
-            print(f'sum_score: {sum_score}')
-            if sum_score > best_score_int:
-                best_score_int = sum_score
-                best_score = score
-                best_arch = sampled_s
+        # for i in range(sample_num):
+        #     sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
+        #     score = get_score(sampled_s)
+        #     print(f'score: {score}')
+        #     print(f'score.shape: {score.shape}')
+        #     print(f'torch.sum(score): {torch.sum(score)}')
+        #     sum_score = torch.sum(score)
+        #     print(f'sum_score: {sum_score}')
+        #     if sum_score > best_score_int:
+        #         best_score_int = sum_score
+        #         best_score = score
+        #         best_arch = sampled_s

         # print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}')

         # best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

         # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
         # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
-        print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
+        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
+        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
+        # print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2

-        print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
-        X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
-        E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
-        print(f'X_s: {X_s}, E_s: {E_s}')
+        # print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
+        # X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
+        # E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
+        # print(f'X_s: {X_s}, E_s: {E_s}')

-        # NASWOT score
-        target_score = torch.ones(100, requires_grad=True) * 2000.0
-        target_score = target_score.to(X_s.device)
+        # # NASWOT score
+        # target_score = torch.ones(100, requires_grad=True) * 2000.0
+        # target_score = target_score.to(X_s.device)

-        # compute loss mse(cur_score - target_score)
-        mse_loss = torch.nn.MSELoss()
-        print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
-        print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
-        loss = mse_loss(best_score, target_score)
-        loss.backward(retain_graph=True)
+        # # compute loss mse(cur_score - target_score)
+        # mse_loss = torch.nn.MSELoss()
+        # print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+        # print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
+        # loss = mse_loss(best_score, target_score)
+        # loss.backward(retain_graph=True)

         # loss backward = gradient

         # get prob.X, prob_E gradient
-        x_grad = pred.X.grad
-        e_grad = pred.E.grad
+        # x_grad = pred.X.grad
+        # e_grad = pred.E.grad

-        beta_ratio = 0.5
-        # x_current = pred.X - beta_ratio * x_grad
-        # e_current = pred.E - beta_ratio * e_grad
-        E_s = pred.X - beta_ratio * x_grad
-        X_s = pred.E - beta_ratio * e_grad
+        # beta_ratio = 0.5
+        # # x_current = pred.X - beta_ratio * x_grad
+        # # e_current = pred.E - beta_ratio * e_grad
+        # E_s = pred.X - beta_ratio * x_grad
+        # X_s = pred.E - beta_ratio * e_grad

         # update prob.X prob_E with using gradient
```
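Taken together, the now-disabled block implemented a best-of-k guidance step: sample discrete features k times, NASWOT-score each candidate, keep the best batch, then push `pred.X`/`pred.E` along an MSE gradient toward a target score. A self-contained restatement of just the best-of-k pattern, with toy stand-ins for the sampler and scorer (not the repository's actual implementations):

```python
import torch

def sample_best(sample_fn, score_fn, k=10):
    """Draw k candidate batches, keep the one with the highest total score."""
    best, best_total = None, float('-inf')
    for _ in range(k):
        cand = sample_fn()
        total = score_fn(cand).sum().item()
        if total > best_total:
            best, best_total = cand, total
    return best

# Toy stand-ins for diffusion_utils.sample_discrete_features / get_score.
best = sample_best(lambda: torch.randint(0, 7, (4, 8)),
                   lambda g: (g > 0).float().sum(dim=-1))
print(best.shape)  # torch.Size([4, 8])
```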
graph_dit/flex-nasbench201-graph.json (new file, +1)

File diff suppressed because one or more lines are too long
graph_dit/nasbench-201-meta.json (new file, +1)

```diff
@@ -0,0 +1 @@
+{"source": "nasbench-201", "num_graph": 31250, "n_nodes_per_graph": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "max_n_nodes": 12, "max_n_edges": 12, "node_type_list": [0.08333333333333333, 0.12076, 0.121096, 0.12054933333333333, 0.120808, 0.35012, 0.08333333333333333, 0.0], "edge_type_list": [0.7757650537496, 0.22423494625039994], "valencies": [0.08333333333333333, 0.12076, 0.121096, 0.12054933333333333, 0.120808, 0.35012, 0.08333333333333333, 0.0], "active_nodes": ["*", "input", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3", "skip_connect", "none"], "num_active_nodes": 7, "transition_E": [[[1.0, 0.0], [0.4991939935961135, 0.5008060064038865], [0.5003633480874677, 0.4996366519125322], [0.49935849223554396, 0.500641507764456], [0.5018652186389422, 0.49813478136105777], [0.8275181842415934, 0.17248181575840665], [0.752416, 0.247584], [1.0, 0.0]], [[0.4991939935961135, 0.5008060064038865], [0.6929218744736044, 0.3070781255263956], [0.6891703482219348, 0.3108296517780652], [0.6909309288988885, 0.3090690711011114], [0.6876641745807234, 0.3123358254192766], [0.8913831706500085, 0.10861682934999148], [0.3948327260682345, 0.6051672739317655], [1.0, 0.0]], [[0.5003633480874677, 0.4996366519125322], [0.6891703482219348, 0.3108296517780652], [0.6877141129844832, 0.3122858870155169], [0.6899900524354673, 0.3100099475645327], [0.6869198878799577, 0.3130801121200423], [0.8910209102091021, 0.1089790897908979], [0.39503644491422785, 0.6049635550857722], [1.0, 0.0]], [[0.49935849223554396, 0.500641507764456], [0.6909309288988885, 0.3090690711011114], [0.6899900524354673, 0.3100099475645327], [0.6918940854215279, 0.30810591457847214], [0.6933245431647987, 0.30667545683520125], [0.8933821584543675, 0.10661784154563249], [0.3977348139627483, 0.6022651860372517], [1.0, 0.0]], [[0.5018652186389422, 0.49813478136105777], [0.6876641745807234, 0.3123358254192766], [0.6869198878799577, 0.3130801121200423], [0.6933245431647987, 0.30667545683520125], [0.6879391891891892, 0.31206081081081083], [0.8921497860953153, 0.10785021390468477], [0.39730260689137586, 0.6026973931086241], [1.0, 0.0]], [[0.8275181842415934, 0.17248181575840665], [0.8913831706500085, 0.10861682934999148], [0.8910209102091021, 0.1089790897908979], [0.8933821584543675, 0.10661784154563249], [0.8921497860953153, 0.10785021390468477], [0.9634043948311156, 0.03659560516888434], [0.79138581057923, 0.20861418942077004], [1.0, 0.0]], [[0.752416, 0.247584], [0.3948327260682345, 0.6051672739317655], [0.39503644491422785, 0.6049635550857722], [0.3977348139627483, 0.6022651860372517], [0.39730260689137586, 0.6026973931086241], [0.79138581057923, 0.20861418942077004], [1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]]}
```

File diff suppressed because one or more lines are too long
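A quick sanity check of the committed metadata; note num_graph is 31250, twice the 15625 NAS-Bench-201 cells, consistent with one flexible variant generated per original graph:

```python
import json

with open('graph_dit/nasbench-201-meta.json') as f:
    meta = json.load(f)

print(meta['num_graph'])    # 31250
print(meta['max_n_nodes'])  # 12
print(meta['active_nodes']) # ['*', 'input', 'nor_conv_1x1', ...]
# transition_E is an 8x8 table of per-node-type-pair [no-edge, edge] probabilities.
print(len(meta['transition_E']), len(meta['transition_E'][0]))  # 8 8
```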