update a small problem

This commit is contained in:
mhz 2024-06-30 19:41:31 +02:00
parent 0fc6f6e686
commit be8bb16f61

View File

@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename):
node_name_list = []
node_count_list = []
node_name_list.append('*')
for op_name in op_type:
node_name_list.append(op_name)
node_count_list.append(0)
node_name_list.append('*')
node_count_list.append(0)
n_nodes_per_graph = [0] * num_graph
edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
edge_count_list = [0, 0]
valencies = [0] * (len(op_type) + 1)
transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename):
for op in ops:
node = op
if node == '*':
node_count_list[-1] += 1
cur_node_count_arr[-1] += 1
else:
node_count_list[op_type[node]] += 1
cur_node_count_arr[op_type[node]] += 1
try:
valencies[int(op_type[node])] += 1
except:
print('int(op_type[node])', int(op_type[node]))
# if node == '*':
# node_count_list[-1] += 1
# cur_node_count_arr[-1] += 1
# else:
node_count_list[node] += 1
cur_node_count_arr[node] += 1
try:
valencies[node] += 1
except:
print('int(op_type[node])', int(node))
transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
for i in range(n_node):
@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename):
continue
start_node, end_node = i, j
start_index = op_type[ops[start_node]]
end_index = op_type[ops[end_node]]
start_index = ops[start_node]
end_index = ops[end_node]
bond_index = 1
edge_count_list[bond_index] += 2
@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename):
edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
# print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename):
'transition_E': transition_E.tolist(),
}
with open(f'{filename}.meta.json', 'w') as f:
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
json.dump(meta_dict, f)
return meta_dict
@ -683,15 +683,41 @@ class Dataset(InMemoryDataset):
active_nodes = set()
for i in range(len_data):
arch_info = self.api.query_meta_info_by_index(i)
results = self.api.query_by_index(i, 'cifar100')
nodes, edges = parse_architecture_string(arch_info.arch_str)
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
for op in ops:
if op not in active_nodes:
active_nodes.add(op)
graph_list.append({
"adj_matrix": adj_matrix,
"ops": ops,
"idx": i
"idx": i,
"train": [{
"iepoch": result.get_train()['iepoch'],
"loss": result.get_train()['loss'],
"accuracy": result.get_train()['accuracy'],
"cur_time": result.get_train()['cur_time'],
"all_time": result.get_train()['all_time'],
"seed": seed,
}for seed, result in results.items()],
"valid": [{
"iepoch": result.get_eval('x-valid')['iepoch'],
"loss": result.get_eval('x-valid')['loss'],
"accuracy": result.get_eval('x-valid')['accuracy'],
"cur_time": result.get_eval('x-valid')['cur_time'],
"all_time": result.get_eval('x-valid')['all_time'],
"seed": seed,
}for seed, result in results.items()],
"test": [{
"iepoch": result.get_eval('x-test')['iepoch'],
"loss": result.get_eval('x-test')['loss'],
"accuracy": result.get_eval('x-test')['accuracy'],
"cur_time": result.get_eval('x-test')['cur_time'],
"all_time": result.get_eval('x-test')['all_time'],
"seed": seed,
}for seed, result in results.items()]
})
data = graph_to_graph_data((adj_matrix, ops))
data_list.append(data)
@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos):
adj_ops_pairs = []
for item in data:
adj_matrix = np.array(item['adjacency_matrix'])
ops = item['operations']
adj_matrix = np.array(item['adj_matrix'])
ops = item['ops']
ops = [op_type[op] for op in ops]
adj_ops_pairs.append((adj_matrix, ops))
return adj_ops_pairs
@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos):
# ops_type[op] = len(ops_type)
# len_ops.add(len(ops))
# graphs.append((adj_matrix, ops))
graphs = read_adj_ops_from_json(f'nasbench-201.meta.json')
graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
# check first five graphs
for i in range(5):
@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index):
'transition_E': tansition_E.tolist(),
}
with open(f'{root}/{source_name}.meta.json', "w") as f:
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
json.dump(meta_dict, f)
return meta_dict