Update the flex data code

mhz 2024-08-13 09:42:51 +02:00
parent 83f9345028
commit 7149b49a39
2 changed files with 152 additions and 115 deletions

View File

@@ -70,7 +70,7 @@ class DataModule(AbstractDataModule):
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
# except NameError:
# base_path = pathlib.Path(os.getcwd()).parent[2]
base_path = '/home/stud/hanzhang/nasbenchDiT'
base_path = '/nfs/data3/hanzhang/nasbenchDiT'
root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path
@@ -408,6 +408,7 @@ def new_graphs_to_json(graphs, filename):
adj = graph[0]
n_node = len(ops)
print(n_node)
n_edge = len(ops)
n_node_list.append(n_node)
n_edge_list.append(n_edge)
@@ -489,7 +490,7 @@ def new_graphs_to_json(graphs, filename):
'transition_E': transition_E.tolist(),
}
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
json.dump(meta_dict, f)
return meta_dict
@@ -655,7 +656,7 @@ def graphs_to_json(graphs, filename):
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source
# self.api = API(source) # Initialize NAS-Bench-201 API
# print('API loaded')
@@ -676,7 +677,7 @@ class Dataset(InMemoryDataset):
return [f'{self.source}.pt']
def process(self):
source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
# self.api = API(source)
data_list = []
@@ -712,6 +713,7 @@ class Dataset(InMemoryDataset):
def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5):
# print(ori_nodes)
# print(ori_edges)
ori_edges = np.array(ori_edges)
# ori_nodes = np.array(ori_nodes)
nasbench_201_node_num = 8
@@ -720,8 +722,13 @@ class Dataset(InMemoryDataset):
# print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}')
add_num = nodes_num - nasbench_201_node_num
# ori_nodes, ori_edges = parse_architecture_string(arch_str)
add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)]
add_nodes = []
print(f'add_num: {add_num}')
for i in range(add_num):
add_nodes.append(random.choice(num_to_op[1:-1]))
# print(add_nodes)
print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
nodes = ori_nodes[:-1] + add_nodes + ['output']
edges = np.zeros((nodes_num , nodes_num))
edges[:6, :6] = ori_edges[:6, :6]
@@ -731,6 +738,11 @@ class Dataset(InMemoryDataset):
rand = random.random()
if rand < random_ratio:
edges[i, j] = 1
if nodes_num < max_nodes:
edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)), 'constant',constant_values=0)
while len(nodes) < max_nodes:
nodes.append('none')
print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
return edges,nodes
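For orientation, here is a minimal self-contained sketch of the padding scheme this hunk implements. The op vocabulary matches the num_to_op list that appears later in this diff; how nodes_num is drawn and the exact bounds of the random-wiring loop are assumptions, since the hunk only shows the random test and the copy of the original 6x6 block.

import random
import numpy as np

NUM_TO_OP = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
             'skip_connect', 'none', 'output']

def flex_pad_sketch(ori_nodes, ori_edges, max_nodes=12, min_nodes=9, random_ratio=0.5):
    # Grow an 8-node NAS-Bench-201 cell to a random size in [min_nodes, max_nodes].
    ori_edges = np.array(ori_edges)
    nodes_num = random.randint(min_nodes, max_nodes)  # assumed draw; not shown in the hunk
    add_num = nodes_num - 8  # 8 = fixed NAS-Bench-201 node count
    # New intermediate ops go between the original body and the output node.
    add_nodes = [random.choice(NUM_TO_OP[1:-1]) for _ in range(add_num)]
    nodes = ori_nodes[:-1] + add_nodes + ['output']
    edges = np.zeros((nodes_num, nodes_num))
    edges[:6, :6] = ori_edges[:6, :6]  # keep the original cell wiring
    # Assumed loop bounds: randomly wire the upper triangle of the new region.
    for i in range(nodes_num):
        for j in range(max(i + 1, 6), nodes_num):
            if random.random() < random_ratio:
                edges[i, j] = 1
    # Pad with isolated 'none' nodes so every sample is max_nodes wide.
    if nodes_num < max_nodes:
        pad = max_nodes - nodes_num
        edges = np.pad(edges, ((0, pad), (0, pad)), 'constant', constant_values=0)
        nodes = nodes + ['none'] * pad
    return edges, nodes

cell_nodes = ['input'] + ['nor_conv_3x3'] * 6 + ['output']  # hypothetical 8-node cell
cell_adj = np.triu(np.ones((8, 8)), k=1)                    # hypothetical DAG wiring
edges, nodes = flex_pad_sketch(cell_nodes, cell_adj)
print(edges.shape, len(nodes))  # (12, 12) 12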
def get_nasbench_201_val(idx):
@@ -766,10 +778,12 @@ class Dataset(InMemoryDataset):
with tqdm(total = len_data) as pbar:
active_nodes = set()
file_path = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
with open(file_path, 'r') as f:
graph_list = json.load(f)
i = 0
flex_graph_list = []
flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
for graph in graph_list:
# arch_info = self.api.query_meta_info_by_index(i)
# results = self.api.query_by_index(i, 'cifar100')
@@ -784,6 +798,16 @@ class Dataset(InMemoryDataset):
active_nodes.add(op)
data = graph_to_graph_data((adj_matrix, ops))
# with open(flex_graph_path, 'a') as f:
# flex_graph = {
# 'adj_matrix': adj_matrix,
# 'ops': ops,
# }
# json.dump(flex_graph, f)
flex_graph_list.append({
'adj_matrix':adj_matrix,
'ops': ops,
})
if i < 3:
print(f"i={i}, data={data}")
with open(f'{i}.json', 'w') as f:
@@ -792,7 +816,17 @@ class Dataset(InMemoryDataset):
f.write(str(data.edge_attr))
data_list.append(data)
new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=8, random_ratio=0.5)
new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5)
flex_graph_list.append({
'adj_matrix':new_adj.tolist(),
'ops': new_ops,
})
# with open(flex_graph_path, 'w') as f:
# flex_graph = {
# 'adj_matrix': new_adj.tolist(),
# 'ops': new_ops,
# }
# json.dump(flex_graph, f)
data_list.append(graph_to_graph_data((new_adj, new_ops)))
# graph_list.append({
@@ -838,6 +872,8 @@ class Dataset(InMemoryDataset):
graph['ops'] = ops
with open(f'nasbench-201-graph.json', 'w') as f:
json.dump(graph_list, f)
with open(flex_graph_path, 'w') as f:
json.dump(flex_graph_list, f)
torch.save(self.collate(data_list), self.processed_paths[0])
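As written above, flex-nasbench201-graph.json holds a list of {'adj_matrix', 'ops'} records, serialized once at the end of processing rather than appended per graph. A read-back sketch, using the same path as the code:

import json
import numpy as np

with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') as f:
    flex_graphs = json.load(f)

adj = np.array(flex_graphs[0]['adj_matrix'])  # presumably 8x8 for original cells, 12x12 for flex-padded ones
ops = flex_graphs[0]['ops']                   # list of op names, padded with 'none' for flex graphs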
@@ -1034,8 +1070,8 @@ def parse_architecture_string(arch_str, padding=0):
assert idx == steps_coding[cont]
cont += 1
nodes.append(n)
ori_nodes = nodes.copy()
nodes.append('output') # Add output node
ori_nodes = nodes.copy()
if padding > 0:
for i in range(padding):
nodes.append('none')
@@ -1048,7 +1084,7 @@ def parse_architecture_string(arch_str, padding=0):
# print(nodes)
# print(adj_mat)
# print(len(adj_mat))
# print(f'len(ori_nodes): {len(ori_nodes)}, len(nodes): {len(nodes)}')
return nodes, adj_mat, ori_nodes, ori_adj_mat
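The moved ori_nodes = nodes.copy() now runs after 'output' is appended, so ori_nodes carries the trailing output node. That matters for generate_flex_adj_mat above, which strips it with ori_nodes[:-1] before splicing in the new ops. A toy illustration:

nodes = ['input', 'nor_conv_3x3', 'skip_connect']
nodes.append('output')
ori_nodes = nodes.copy()  # now includes 'output'
# generate_flex_adj_mat can safely drop the output node:
assert ori_nodes[:-1] == ['input', 'nor_conv_3x3', 'skip_connect']
# Before this change, ori_nodes lacked 'output', so [:-1] would have
# discarded a real operation ('skip_connect' here).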
def create_adj_matrix_and_ops(nodes, edges):
@@ -1091,6 +1127,7 @@ class DataInfos(AbstractDatasetInfos):
adj_ops_pairs = []
for item in data:
print(item)
adj_matrix = np.array(item['adj_matrix'])
ops = item['ops']
ops = [op_type[op] for op in ops]
@@ -1111,12 +1148,12 @@ class DataInfos(AbstractDatasetInfos):
# ops_type[op] = len(ops_type)
# len_ops.add(len(ops))
# graphs.append((adj_matrix, ops))
graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
# check first five graphs
for i in range(5):
print(f'graph {i} : {graphs[i]}')
print(f'ops_type: {ops_type}')
# print(f'ops_type: {ops_type}')
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
self.base_path = base_path
@@ -1325,11 +1362,11 @@ def compute_meta(root, source_name, train_index, test_index):
'transition_E': tansition_E.tolist(),
}
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
json.dump(meta_dict, f)
return meta_dict
if __name__ == "__main__":
dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)

View File

@@ -3,9 +3,9 @@ import torch.nn.functional as F
import pytorch_lightning as pl
import time
import os
from naswot.score_networks import get_nasbench201_nodes_score
from naswot import nasspace
from naswot import datasets
# from naswot.score_networks import get_nasbench201_nodes_score
# from naswot import nasspace
# from naswot import datasets
from models.transformer import Denoiser
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition
@@ -41,7 +41,7 @@ class Graph_DiT(pl.LightningModule):
self.args.batch_size = 128
self.args.GPU = '0'
self.args.dataset = 'cifar10-valid'
self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.args.data_loc = '../cifardata/'
self.args.seed = 777
self.args.init = ''
@@ -59,10 +59,10 @@ class Graph_DiT(pl.LightningModule):
if 'valid' in self.args.dataset:
self.args.dataset = self.args.dataset.replace('-valid', '')
print('graph_dit starts to get searchspace of nasbench201')
self.searchspace = nasspace.get_search_space(self.args)
# self.searchspace = nasspace.get_search_space(self.args)
print('searchspace of nasbench201 is obtained')
print('graphdit starts to get train_loader')
self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
# self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
print('train_loader is obtained')
self.cfg = cfg
@@ -162,7 +162,7 @@ class Graph_DiT(pl.LightningModule):
return pred
def training_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
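The one-hot width for node features grows from 8 to 12 here and again in validation_step and test_step below. The diff does not show the new node-type vocabulary, but the constraint is the usual one: every index in data.x must be strictly less than num_classes, and the flex padding above produces graphs of up to 12 nodes. A hypothetical check:

import torch
import torch.nn.functional as F

x = torch.tensor([0, 5, 11])  # hypothetical node-type indices after the flex change
print(F.one_hot(x, num_classes=12).shape)  # torch.Size([3, 12])
# F.one_hot(x, num_classes=8) would raise: index 11 is out of range.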
@@ -222,7 +222,7 @@ class Graph_DiT(pl.LightningModule):
@torch.no_grad()
def validation_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask, collapse=False)
@@ -315,7 +315,7 @@ class Graph_DiT(pl.LightningModule):
@torch.no_grad()
def test_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
@@ -686,120 +686,120 @@ class Graph_DiT(pl.LightningModule):
assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()
# sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
# sample multiple times and get the best score arch...
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
op_type = {
'input': 0,
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'none': 5,
'output': 6,
}
def check_valid_graph(nodes, edges):
nodes = [num_to_op[i] for i in nodes]
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
return False
if nodes[0] != 'input' or nodes[-1] != 'output':
return False
for i in range(0, len(nodes)):
if edges[i][i] == 1:
return False
for i in range(1, len(nodes) - 1):
if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
return False
for i in range(0, len(nodes)):
for j in range(i, len(nodes)):
if edges[i, j] == 1 and nodes[j] == 'input':
return False
for i in range(0, len(nodes)):
for j in range(i, len(nodes)):
if edges[i, j] == 1 and nodes[i] == 'output':
return False
flag = 0
for i in range(0,len(nodes)):
if edges[i,-1] == 1:
flag = 1
break
if flag == 0: return False
return True
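A quick usage sketch for check_valid_graph (this commit comments the checker out, along with the rest of the NASWOT guidance), using a hypothetical 4-node cell and the num_to_op indices defined just above:

import numpy as np

nodes = [0, 2, 4, 6]  # input, nor_conv_3x3, skip_connect, output
edges = np.array([[0, 1, 1, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]])
print(check_valid_graph(nodes, edges))  # True: DAG, no self-loops, ends at 'output'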
# num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
# op_type = {
# 'input': 0,
# 'nor_conv_1x1': 1,
# 'nor_conv_3x3': 2,
# 'avg_pool_3x3': 3,
# 'skip_connect': 4,
# 'none': 5,
# 'output': 6,
# }
# def check_valid_graph(nodes, edges):
# nodes = [num_to_op[i] for i in nodes]
# if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
# return False
# if nodes[0] != 'input' or nodes[-1] != 'output':
# return False
# for i in range(0, len(nodes)):
# if edges[i][i] == 1:
# return False
# for i in range(1, len(nodes) - 1):
# if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
# return False
# for i in range(0, len(nodes)):
# for j in range(i, len(nodes)):
# if edges[i, j] == 1 and nodes[j] == 'input':
# return False
# for i in range(0, len(nodes)):
# for j in range(i, len(nodes)):
# if edges[i, j] == 1 and nodes[i] == 'output':
# return False
# flag = 0
# for i in range(0,len(nodes)):
# if edges[i,-1] == 1:
# flag = 1
# break
# if flag == 0: return False
# return True
class Args:
pass
# class Args:
# pass
def get_score(sampled_s):
x_list = sampled_s.X.unbind(dim=0)
e_list = sampled_s.E.unbind(dim=0)
valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
score = []
# def get_score(sampled_s):
# x_list = sampled_s.X.unbind(dim=0)
# e_list = sampled_s.E.unbind(dim=0)
# valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
# from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
# score = []
for i in range(len(x_list)):
if valid_rlt[i]:
nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
# edges = e_list[i].cpu().numpy()
score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
else:
score.append(-1)
return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
# for i in range(len(x_list)):
# if valid_rlt[i]:
# nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
# # edges = e_list[i].cpu().numpy()
# score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
# else:
# score.append(-1)
# return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
sample_num = 10
best_arch = None
best_score_int = -1e8
score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
# sample_num = 10
# best_arch = None
# best_score_int = -1e8
# score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
for i in range(sample_num):
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
score = get_score(sampled_s)
print(f'score: {score}')
print(f'score.shape: {score.shape}')
print(f'torch.sum(score): {torch.sum(score)}')
sum_score = torch.sum(score)
print(f'sum_score: {sum_score}')
if sum_score > best_score_int:
best_score_int = sum_score
best_score = score
best_arch = sampled_s
# for i in range(sample_num):
# sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
# score = get_score(sampled_s)
# print(f'score: {score}')
# print(f'score.shape: {score.shape}')
# print(f'torch.sum(score): {torch.sum(score)}')
# sum_score = torch.sum(score)
# print(f'sum_score: {sum_score}')
# if sum_score > best_score_int:
# best_score_int = sum_score
# best_score = score
# best_arch = sampled_s
# print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}')
# best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
# X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
# E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
# print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
print(f'X_s: {X_s}, E_s: {E_s}')
# print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
# X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
# E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
# print(f'X_s: {X_s}, E_s: {E_s}')
# NASWOT score
target_score = torch.ones(100, requires_grad=True) * 2000.0
target_score = target_score.to(X_s.device)
# # NASWOT score
# target_score = torch.ones(100, requires_grad=True) * 2000.0
# target_score = target_score.to(X_s.device)
# compute loss mse(cur_score - target_score)
mse_loss = torch.nn.MSELoss()
print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
loss = mse_loss(best_score, target_score)
loss.backward(retain_graph=True)
# # compute loss mse(cur_score - target_score)
# mse_loss = torch.nn.MSELoss()
# print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
# print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
# loss = mse_loss(best_score, target_score)
# loss.backward(retain_graph=True)
# loss backward = gradient
# get prob.X, prob_E gradient
x_grad = pred.X.grad
e_grad = pred.E.grad
# x_grad = pred.X.grad
# e_grad = pred.E.grad
beta_ratio = 0.5
# x_current = pred.X - beta_ratio * x_grad
# e_current = pred.E - beta_ratio * e_grad
E_s = pred.X - beta_ratio * x_grad
X_s = pred.E - beta_ratio * e_grad
# beta_ratio = 0.5
# # x_current = pred.X - beta_ratio * x_grad
# # e_current = pred.E - beta_ratio * e_grad
# E_s = pred.X - beta_ratio * x_grad
# X_s = pred.E - beta_ratio * e_grad
# update prob.X prob_E with using gradient
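The now-disabled block sketches score-guided sampling: draw sample_num candidate batches, score each graph with NASWOT (invalid graphs get -1), keep the best batch, backprop an MSE toward a target score, and nudge the predicted logits along the gradient. Two caveats visible in the code: the discrete sampling breaks the gradient path (score is rebuilt as a fresh leaf tensor, so pred.X.grad would be None), and the final update assigns the X gradient to E_s and vice versa, which looks like a transposition slip. A condensed sketch of the guidance step with hypothetical, differentiable stand-ins:

import torch
import torch.nn.functional as F

pred_X = torch.randn(100, 8, 8, requires_grad=True)     # stand-in node logits
pred_E = torch.randn(100, 8, 8, 2, requires_grad=True)  # stand-in edge logits
score = pred_X.mean() + pred_E.mean()  # placeholder for a differentiable NASWOT score
target = torch.tensor(2000.0)

loss = F.mse_loss(score, target)
loss.backward()

beta_ratio = 0.5
X_s = pred_X - beta_ratio * pred_X.grad  # X pairs with the X gradient here,
E_s = pred_E - beta_ratio * pred_E.grad  # unlike the swapped assignment above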