update a small problem
This commit is contained in:
		| @@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename): | ||||
|  | ||||
|     node_name_list = [] | ||||
|     node_count_list = [] | ||||
|     node_name_list.append('*') | ||||
|      | ||||
|     for op_name in op_type: | ||||
|         node_name_list.append(op_name) | ||||
|         node_count_list.append(0)  | ||||
|      | ||||
|     node_name_list.append('*') | ||||
|     node_count_list.append(0) | ||||
|     n_nodes_per_graph = [0] * num_graph | ||||
|     edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | ||||
|     edge_count_list = [0, 0]  | ||||
|     valencies = [0] * (len(op_type) + 1) | ||||
|     transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||
|  | ||||
| @@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename): | ||||
|  | ||||
|         for op in ops: | ||||
|             node = op | ||||
|             if node == '*': | ||||
|                 node_count_list[-1] += 1 | ||||
|                 cur_node_count_arr[-1] += 1 | ||||
|             else: | ||||
|                 node_count_list[op_type[node]] += 1 | ||||
|                 cur_node_count_arr[op_type[node]] += 1 | ||||
|                 try: | ||||
|                     valencies[int(op_type[node])] += 1 | ||||
|                 except: | ||||
|                     print('int(op_type[node])', int(op_type[node])) | ||||
|             # if node == '*': | ||||
|             #     node_count_list[-1] += 1 | ||||
|             #     cur_node_count_arr[-1] += 1 | ||||
|             # else: | ||||
|             node_count_list[node] += 1 | ||||
|             cur_node_count_arr[node] += 1 | ||||
|             try: | ||||
|                 valencies[node] += 1 | ||||
|             except: | ||||
|                 print('int(op_type[node])', int(node)) | ||||
|          | ||||
|         transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) | ||||
|         for i in range(n_node): | ||||
| @@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename): | ||||
|                     continue | ||||
|                 start_node, end_node = i, j | ||||
|                  | ||||
|                 start_index = op_type[ops[start_node]] | ||||
|                 end_index = op_type[ops[end_node]] | ||||
|                 start_index = ops[start_node] | ||||
|                 end_index = ops[end_node] | ||||
|                 bond_index = 1 | ||||
|                 edge_count_list[bond_index] += 2 | ||||
|                  | ||||
| @@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename): | ||||
|  | ||||
|         edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2 | ||||
|         cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2 | ||||
|         print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") | ||||
|         # print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") | ||||
|         cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2 | ||||
|         transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1) | ||||
|         assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0 | ||||
| @@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename): | ||||
|         'transition_E': transition_E.tolist(), | ||||
|     } | ||||
|  | ||||
|     with open(f'{filename}.meta.json', 'w') as f: | ||||
|     with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f: | ||||
|         json.dump(meta_dict, f) | ||||
|      | ||||
|     return meta_dict | ||||
| @@ -683,15 +683,41 @@ class Dataset(InMemoryDataset): | ||||
|             active_nodes = set() | ||||
|             for i in range(len_data): | ||||
|                 arch_info = self.api.query_meta_info_by_index(i) | ||||
|                 results = self.api.query_by_index(i, 'cifar100') | ||||
|                 nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||
|                 adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||
|                 for op in ops: | ||||
|                     if op not in active_nodes: | ||||
|                         active_nodes.add(op) | ||||
|                  | ||||
|                 graph_list.append({ | ||||
|                     "adj_matrix": adj_matrix, | ||||
|                     "ops": ops, | ||||
|                     "idx": i | ||||
|                     "idx": i, | ||||
|                     "train": [{ | ||||
|                         "iepoch": result.get_train()['iepoch'], | ||||
|                         "loss": result.get_train()['loss'], | ||||
|                         "accuracy": result.get_train()['accuracy'], | ||||
|                         "cur_time": result.get_train()['cur_time'], | ||||
|                         "all_time": result.get_train()['all_time'], | ||||
|                         "seed": seed, | ||||
|                     }for seed, result in results.items()], | ||||
|                     "valid": [{ | ||||
|                         "iepoch": result.get_eval('x-valid')['iepoch'], | ||||
|                         "loss": result.get_eval('x-valid')['loss'], | ||||
|                         "accuracy": result.get_eval('x-valid')['accuracy'], | ||||
|                         "cur_time": result.get_eval('x-valid')['cur_time'], | ||||
|                         "all_time": result.get_eval('x-valid')['all_time'], | ||||
|                         "seed": seed, | ||||
|                     }for seed, result in results.items()], | ||||
|                     "test": [{ | ||||
|                         "iepoch": result.get_eval('x-test')['iepoch'], | ||||
|                         "loss": result.get_eval('x-test')['loss'], | ||||
|                         "accuracy": result.get_eval('x-test')['accuracy'], | ||||
|                         "cur_time": result.get_eval('x-test')['cur_time'], | ||||
|                         "all_time": result.get_eval('x-test')['all_time'], | ||||
|                         "seed": seed, | ||||
|                     }for seed, result in results.items()] | ||||
|                 }) | ||||
|                 data = graph_to_graph_data((adj_matrix, ops))  | ||||
|                 data_list.append(data) | ||||
| @@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos): | ||||
|  | ||||
|             adj_ops_pairs = [] | ||||
|             for item in data: | ||||
|                 adj_matrix = np.array(item['adjacency_matrix']) | ||||
|                 ops = item['operations'] | ||||
|                 adj_matrix = np.array(item['adj_matrix']) | ||||
|                 ops = item['ops'] | ||||
|                 ops = [op_type[op] for op in ops] | ||||
|                 adj_ops_pairs.append((adj_matrix, ops)) | ||||
|              | ||||
|             return adj_ops_pairs | ||||
| @@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos): | ||||
|             #         ops_type[op] = len(ops_type) | ||||
|             # len_ops.add(len(ops)) | ||||
|             # graphs.append((adj_matrix, ops)) | ||||
|         graphs = read_adj_ops_from_json(f'nasbench-201.meta.json') | ||||
|         graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') | ||||
|  | ||||
|         # check first five graphs | ||||
|         for i in range(5): | ||||
| @@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index): | ||||
|         'transition_E': tansition_E.tolist(), | ||||
|         } | ||||
|  | ||||
|     with open(f'{root}/{source_name}.meta.json', "w") as f: | ||||
|     with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f: | ||||
|         json.dump(meta_dict, f) | ||||
|      | ||||
|     return meta_dict | ||||
|   | ||||
		Reference in New Issue
	
	Block a user