update a small problem
This commit is contained in:
parent
0fc6f6e686
commit
be8bb16f61
@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename):
|
||||
|
||||
node_name_list = []
|
||||
node_count_list = []
|
||||
node_name_list.append('*')
|
||||
|
||||
for op_name in op_type:
|
||||
node_name_list.append(op_name)
|
||||
node_count_list.append(0)
|
||||
|
||||
node_name_list.append('*')
|
||||
node_count_list.append(0)
|
||||
n_nodes_per_graph = [0] * num_graph
|
||||
edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
edge_count_list = [0, 0]
|
||||
valencies = [0] * (len(op_type) + 1)
|
||||
transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
||||
|
||||
@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename):
|
||||
|
||||
for op in ops:
|
||||
node = op
|
||||
if node == '*':
|
||||
node_count_list[-1] += 1
|
||||
cur_node_count_arr[-1] += 1
|
||||
else:
|
||||
node_count_list[op_type[node]] += 1
|
||||
cur_node_count_arr[op_type[node]] += 1
|
||||
try:
|
||||
valencies[int(op_type[node])] += 1
|
||||
except:
|
||||
print('int(op_type[node])', int(op_type[node]))
|
||||
# if node == '*':
|
||||
# node_count_list[-1] += 1
|
||||
# cur_node_count_arr[-1] += 1
|
||||
# else:
|
||||
node_count_list[node] += 1
|
||||
cur_node_count_arr[node] += 1
|
||||
try:
|
||||
valencies[node] += 1
|
||||
except:
|
||||
print('int(op_type[node])', int(node))
|
||||
|
||||
transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
||||
for i in range(n_node):
|
||||
@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename):
|
||||
continue
|
||||
start_node, end_node = i, j
|
||||
|
||||
start_index = op_type[ops[start_node]]
|
||||
end_index = op_type[ops[end_node]]
|
||||
start_index = ops[start_node]
|
||||
end_index = ops[end_node]
|
||||
bond_index = 1
|
||||
edge_count_list[bond_index] += 2
|
||||
|
||||
@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename):
|
||||
|
||||
edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
|
||||
cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
|
||||
print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
|
||||
# print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
|
||||
cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
|
||||
transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
|
||||
assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
|
||||
@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename):
|
||||
'transition_E': transition_E.tolist(),
|
||||
}
|
||||
|
||||
with open(f'{filename}.meta.json', 'w') as f:
|
||||
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
||||
json.dump(meta_dict, f)
|
||||
|
||||
return meta_dict
|
||||
@ -683,15 +683,41 @@ class Dataset(InMemoryDataset):
|
||||
active_nodes = set()
|
||||
for i in range(len_data):
|
||||
arch_info = self.api.query_meta_info_by_index(i)
|
||||
results = self.api.query_by_index(i, 'cifar100')
|
||||
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
||||
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
||||
for op in ops:
|
||||
if op not in active_nodes:
|
||||
active_nodes.add(op)
|
||||
|
||||
graph_list.append({
|
||||
"adj_matrix": adj_matrix,
|
||||
"ops": ops,
|
||||
"idx": i
|
||||
"idx": i,
|
||||
"train": [{
|
||||
"iepoch": result.get_train()['iepoch'],
|
||||
"loss": result.get_train()['loss'],
|
||||
"accuracy": result.get_train()['accuracy'],
|
||||
"cur_time": result.get_train()['cur_time'],
|
||||
"all_time": result.get_train()['all_time'],
|
||||
"seed": seed,
|
||||
}for seed, result in results.items()],
|
||||
"valid": [{
|
||||
"iepoch": result.get_eval('x-valid')['iepoch'],
|
||||
"loss": result.get_eval('x-valid')['loss'],
|
||||
"accuracy": result.get_eval('x-valid')['accuracy'],
|
||||
"cur_time": result.get_eval('x-valid')['cur_time'],
|
||||
"all_time": result.get_eval('x-valid')['all_time'],
|
||||
"seed": seed,
|
||||
}for seed, result in results.items()],
|
||||
"test": [{
|
||||
"iepoch": result.get_eval('x-test')['iepoch'],
|
||||
"loss": result.get_eval('x-test')['loss'],
|
||||
"accuracy": result.get_eval('x-test')['accuracy'],
|
||||
"cur_time": result.get_eval('x-test')['cur_time'],
|
||||
"all_time": result.get_eval('x-test')['all_time'],
|
||||
"seed": seed,
|
||||
}for seed, result in results.items()]
|
||||
})
|
||||
data = graph_to_graph_data((adj_matrix, ops))
|
||||
data_list.append(data)
|
||||
@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos):
|
||||
|
||||
adj_ops_pairs = []
|
||||
for item in data:
|
||||
adj_matrix = np.array(item['adjacency_matrix'])
|
||||
ops = item['operations']
|
||||
adj_matrix = np.array(item['adj_matrix'])
|
||||
ops = item['ops']
|
||||
ops = [op_type[op] for op in ops]
|
||||
adj_ops_pairs.append((adj_matrix, ops))
|
||||
|
||||
return adj_ops_pairs
|
||||
@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos):
|
||||
# ops_type[op] = len(ops_type)
|
||||
# len_ops.add(len(ops))
|
||||
# graphs.append((adj_matrix, ops))
|
||||
graphs = read_adj_ops_from_json(f'nasbench-201.meta.json')
|
||||
graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
||||
|
||||
# check first five graphs
|
||||
for i in range(5):
|
||||
@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index):
|
||||
'transition_E': tansition_E.tolist(),
|
||||
}
|
||||
|
||||
with open(f'{root}/{source_name}.meta.json', "w") as f:
|
||||
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
||||
json.dump(meta_dict, f)
|
||||
|
||||
return meta_dict
|
||||
|
Loading…
Reference in New Issue
Block a user