update a small problem
This commit is contained in:
		| @@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename): | |||||||
|  |  | ||||||
|     node_name_list = [] |     node_name_list = [] | ||||||
|     node_count_list = [] |     node_count_list = [] | ||||||
|  |     node_name_list.append('*') | ||||||
|      |      | ||||||
|     for op_name in op_type: |     for op_name in op_type: | ||||||
|         node_name_list.append(op_name) |         node_name_list.append(op_name) | ||||||
|         node_count_list.append(0)  |         node_count_list.append(0)  | ||||||
|      |      | ||||||
|     node_name_list.append('*') |  | ||||||
|     node_count_list.append(0) |     node_count_list.append(0) | ||||||
|     n_nodes_per_graph = [0] * num_graph |     n_nodes_per_graph = [0] * num_graph | ||||||
|     edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |     edge_count_list = [0, 0]  | ||||||
|     valencies = [0] * (len(op_type) + 1) |     valencies = [0] * (len(op_type) + 1) | ||||||
|     transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) |     transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||||
|  |  | ||||||
| @@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename): | |||||||
|  |  | ||||||
|         for op in ops: |         for op in ops: | ||||||
|             node = op |             node = op | ||||||
|             if node == '*': |             # if node == '*': | ||||||
|                 node_count_list[-1] += 1 |             #     node_count_list[-1] += 1 | ||||||
|                 cur_node_count_arr[-1] += 1 |             #     cur_node_count_arr[-1] += 1 | ||||||
|             else: |             # else: | ||||||
|                 node_count_list[op_type[node]] += 1 |             node_count_list[node] += 1 | ||||||
|                 cur_node_count_arr[op_type[node]] += 1 |             cur_node_count_arr[node] += 1 | ||||||
|                 try: |             try: | ||||||
|                     valencies[int(op_type[node])] += 1 |                 valencies[node] += 1 | ||||||
|                 except: |             except: | ||||||
|                     print('int(op_type[node])', int(op_type[node])) |                 print('int(op_type[node])', int(node)) | ||||||
|          |          | ||||||
|         transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) |         transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||||
|         for i in range(n_node): |         for i in range(n_node): | ||||||
| @@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename): | |||||||
|                     continue |                     continue | ||||||
|                 start_node, end_node = i, j |                 start_node, end_node = i, j | ||||||
|                  |                  | ||||||
|                 start_index = op_type[ops[start_node]] |                 start_index = ops[start_node] | ||||||
|                 end_index = op_type[ops[end_node]] |                 end_index = ops[end_node] | ||||||
|                 bond_index = 1 |                 bond_index = 1 | ||||||
|                 edge_count_list[bond_index] += 2 |                 edge_count_list[bond_index] += 2 | ||||||
|                  |                  | ||||||
| @@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename): | |||||||
|  |  | ||||||
|         edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2 |         edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2 | ||||||
|         cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2 |         cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2 | ||||||
|         print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") |         # print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") | ||||||
|         cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2 |         cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2 | ||||||
|         transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1) |         transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1) | ||||||
|         assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0 |         assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0 | ||||||
| @@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename): | |||||||
|         'transition_E': transition_E.tolist(), |         'transition_E': transition_E.tolist(), | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     with open(f'{filename}.meta.json', 'w') as f: |     with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f: | ||||||
|         json.dump(meta_dict, f) |         json.dump(meta_dict, f) | ||||||
|      |      | ||||||
|     return meta_dict |     return meta_dict | ||||||
| @@ -683,15 +683,41 @@ class Dataset(InMemoryDataset): | |||||||
|             active_nodes = set() |             active_nodes = set() | ||||||
|             for i in range(len_data): |             for i in range(len_data): | ||||||
|                 arch_info = self.api.query_meta_info_by_index(i) |                 arch_info = self.api.query_meta_info_by_index(i) | ||||||
|  |                 results = self.api.query_by_index(i, 'cifar100') | ||||||
|                 nodes, edges = parse_architecture_string(arch_info.arch_str) |                 nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||||
|                 adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) |                 adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||||
|                 for op in ops: |                 for op in ops: | ||||||
|                     if op not in active_nodes: |                     if op not in active_nodes: | ||||||
|                         active_nodes.add(op) |                         active_nodes.add(op) | ||||||
|  |                  | ||||||
|                 graph_list.append({ |                 graph_list.append({ | ||||||
|                     "adj_matrix": adj_matrix, |                     "adj_matrix": adj_matrix, | ||||||
|                     "ops": ops, |                     "ops": ops, | ||||||
|                     "idx": i |                     "idx": i, | ||||||
|  |                     "train": [{ | ||||||
|  |                         "iepoch": result.get_train()['iepoch'], | ||||||
|  |                         "loss": result.get_train()['loss'], | ||||||
|  |                         "accuracy": result.get_train()['accuracy'], | ||||||
|  |                         "cur_time": result.get_train()['cur_time'], | ||||||
|  |                         "all_time": result.get_train()['all_time'], | ||||||
|  |                         "seed": seed, | ||||||
|  |                     }for seed, result in results.items()], | ||||||
|  |                     "valid": [{ | ||||||
|  |                         "iepoch": result.get_eval('x-valid')['iepoch'], | ||||||
|  |                         "loss": result.get_eval('x-valid')['loss'], | ||||||
|  |                         "accuracy": result.get_eval('x-valid')['accuracy'], | ||||||
|  |                         "cur_time": result.get_eval('x-valid')['cur_time'], | ||||||
|  |                         "all_time": result.get_eval('x-valid')['all_time'], | ||||||
|  |                         "seed": seed, | ||||||
|  |                     }for seed, result in results.items()], | ||||||
|  |                     "test": [{ | ||||||
|  |                         "iepoch": result.get_eval('x-test')['iepoch'], | ||||||
|  |                         "loss": result.get_eval('x-test')['loss'], | ||||||
|  |                         "accuracy": result.get_eval('x-test')['accuracy'], | ||||||
|  |                         "cur_time": result.get_eval('x-test')['cur_time'], | ||||||
|  |                         "all_time": result.get_eval('x-test')['all_time'], | ||||||
|  |                         "seed": seed, | ||||||
|  |                     }for seed, result in results.items()] | ||||||
|                 }) |                 }) | ||||||
|                 data = graph_to_graph_data((adj_matrix, ops))  |                 data = graph_to_graph_data((adj_matrix, ops))  | ||||||
|                 data_list.append(data) |                 data_list.append(data) | ||||||
| @@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|  |  | ||||||
|             adj_ops_pairs = [] |             adj_ops_pairs = [] | ||||||
|             for item in data: |             for item in data: | ||||||
|                 adj_matrix = np.array(item['adjacency_matrix']) |                 adj_matrix = np.array(item['adj_matrix']) | ||||||
|                 ops = item['operations'] |                 ops = item['ops'] | ||||||
|  |                 ops = [op_type[op] for op in ops] | ||||||
|                 adj_ops_pairs.append((adj_matrix, ops)) |                 adj_ops_pairs.append((adj_matrix, ops)) | ||||||
|              |              | ||||||
|             return adj_ops_pairs |             return adj_ops_pairs | ||||||
| @@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|             #         ops_type[op] = len(ops_type) |             #         ops_type[op] = len(ops_type) | ||||||
|             # len_ops.add(len(ops)) |             # len_ops.add(len(ops)) | ||||||
|             # graphs.append((adj_matrix, ops)) |             # graphs.append((adj_matrix, ops)) | ||||||
|         graphs = read_adj_ops_from_json(f'nasbench-201.meta.json') |         graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') | ||||||
|  |  | ||||||
|         # check first five graphs |         # check first five graphs | ||||||
|         for i in range(5): |         for i in range(5): | ||||||
| @@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index): | |||||||
|         'transition_E': tansition_E.tolist(), |         'transition_E': tansition_E.tolist(), | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     with open(f'{root}/{source_name}.meta.json', "w") as f: |     with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f: | ||||||
|         json.dump(meta_dict, f) |         json.dump(meta_dict, f) | ||||||
|      |      | ||||||
|     return meta_dict |     return meta_dict | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user