update a small problem
This commit is contained in:
		@@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename):
 | 
			
		||||
 | 
			
		||||
    node_name_list = []
 | 
			
		||||
    node_count_list = []
 | 
			
		||||
    node_name_list.append('*')
 | 
			
		||||
    
 | 
			
		||||
    for op_name in op_type:
 | 
			
		||||
        node_name_list.append(op_name)
 | 
			
		||||
        node_count_list.append(0) 
 | 
			
		||||
    
 | 
			
		||||
    node_name_list.append('*')
 | 
			
		||||
    node_count_list.append(0)
 | 
			
		||||
    n_nodes_per_graph = [0] * num_graph
 | 
			
		||||
    edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 | 
			
		||||
    edge_count_list = [0, 0] 
 | 
			
		||||
    valencies = [0] * (len(op_type) + 1)
 | 
			
		||||
    transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
 | 
			
		||||
 | 
			
		||||
@@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename):
 | 
			
		||||
 | 
			
		||||
        for op in ops:
 | 
			
		||||
            node = op
 | 
			
		||||
            if node == '*':
 | 
			
		||||
                node_count_list[-1] += 1
 | 
			
		||||
                cur_node_count_arr[-1] += 1
 | 
			
		||||
            else:
 | 
			
		||||
                node_count_list[op_type[node]] += 1
 | 
			
		||||
                cur_node_count_arr[op_type[node]] += 1
 | 
			
		||||
                try:
 | 
			
		||||
                    valencies[int(op_type[node])] += 1
 | 
			
		||||
                except:
 | 
			
		||||
                    print('int(op_type[node])', int(op_type[node]))
 | 
			
		||||
            # if node == '*':
 | 
			
		||||
            #     node_count_list[-1] += 1
 | 
			
		||||
            #     cur_node_count_arr[-1] += 1
 | 
			
		||||
            # else:
 | 
			
		||||
            node_count_list[node] += 1
 | 
			
		||||
            cur_node_count_arr[node] += 1
 | 
			
		||||
            try:
 | 
			
		||||
                valencies[node] += 1
 | 
			
		||||
            except:
 | 
			
		||||
                print('int(op_type[node])', int(node))
 | 
			
		||||
        
 | 
			
		||||
        transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
 | 
			
		||||
        for i in range(n_node):
 | 
			
		||||
@@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename):
 | 
			
		||||
                    continue
 | 
			
		||||
                start_node, end_node = i, j
 | 
			
		||||
                
 | 
			
		||||
                start_index = op_type[ops[start_node]]
 | 
			
		||||
                end_index = op_type[ops[end_node]]
 | 
			
		||||
                start_index = ops[start_node]
 | 
			
		||||
                end_index = ops[end_node]
 | 
			
		||||
                bond_index = 1
 | 
			
		||||
                edge_count_list[bond_index] += 2
 | 
			
		||||
                
 | 
			
		||||
@@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename):
 | 
			
		||||
 | 
			
		||||
        edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
 | 
			
		||||
        cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
 | 
			
		||||
        print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
 | 
			
		||||
        # print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
 | 
			
		||||
        cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
 | 
			
		||||
        transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
 | 
			
		||||
        assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
 | 
			
		||||
@@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename):
 | 
			
		||||
        'transition_E': transition_E.tolist(),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    with open(f'{filename}.meta.json', 'w') as f:
 | 
			
		||||
    with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
 | 
			
		||||
        json.dump(meta_dict, f)
 | 
			
		||||
    
 | 
			
		||||
    return meta_dict
 | 
			
		||||
@@ -683,15 +683,41 @@ class Dataset(InMemoryDataset):
 | 
			
		||||
            active_nodes = set()
 | 
			
		||||
            for i in range(len_data):
 | 
			
		||||
                arch_info = self.api.query_meta_info_by_index(i)
 | 
			
		||||
                results = self.api.query_by_index(i, 'cifar100')
 | 
			
		||||
                nodes, edges = parse_architecture_string(arch_info.arch_str)
 | 
			
		||||
                adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
 | 
			
		||||
                for op in ops:
 | 
			
		||||
                    if op not in active_nodes:
 | 
			
		||||
                        active_nodes.add(op)
 | 
			
		||||
                
 | 
			
		||||
                graph_list.append({
 | 
			
		||||
                    "adj_matrix": adj_matrix,
 | 
			
		||||
                    "ops": ops,
 | 
			
		||||
                    "idx": i
 | 
			
		||||
                    "idx": i,
 | 
			
		||||
                    "train": [{
 | 
			
		||||
                        "iepoch": result.get_train()['iepoch'],
 | 
			
		||||
                        "loss": result.get_train()['loss'],
 | 
			
		||||
                        "accuracy": result.get_train()['accuracy'],
 | 
			
		||||
                        "cur_time": result.get_train()['cur_time'],
 | 
			
		||||
                        "all_time": result.get_train()['all_time'],
 | 
			
		||||
                        "seed": seed,
 | 
			
		||||
                    }for seed, result in results.items()],
 | 
			
		||||
                    "valid": [{
 | 
			
		||||
                        "iepoch": result.get_eval('x-valid')['iepoch'],
 | 
			
		||||
                        "loss": result.get_eval('x-valid')['loss'],
 | 
			
		||||
                        "accuracy": result.get_eval('x-valid')['accuracy'],
 | 
			
		||||
                        "cur_time": result.get_eval('x-valid')['cur_time'],
 | 
			
		||||
                        "all_time": result.get_eval('x-valid')['all_time'],
 | 
			
		||||
                        "seed": seed,
 | 
			
		||||
                    }for seed, result in results.items()],
 | 
			
		||||
                    "test": [{
 | 
			
		||||
                        "iepoch": result.get_eval('x-test')['iepoch'],
 | 
			
		||||
                        "loss": result.get_eval('x-test')['loss'],
 | 
			
		||||
                        "accuracy": result.get_eval('x-test')['accuracy'],
 | 
			
		||||
                        "cur_time": result.get_eval('x-test')['cur_time'],
 | 
			
		||||
                        "all_time": result.get_eval('x-test')['all_time'],
 | 
			
		||||
                        "seed": seed,
 | 
			
		||||
                    }for seed, result in results.items()]
 | 
			
		||||
                })
 | 
			
		||||
                data = graph_to_graph_data((adj_matrix, ops)) 
 | 
			
		||||
                data_list.append(data)
 | 
			
		||||
@@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos):
 | 
			
		||||
 | 
			
		||||
            adj_ops_pairs = []
 | 
			
		||||
            for item in data:
 | 
			
		||||
                adj_matrix = np.array(item['adjacency_matrix'])
 | 
			
		||||
                ops = item['operations']
 | 
			
		||||
                adj_matrix = np.array(item['adj_matrix'])
 | 
			
		||||
                ops = item['ops']
 | 
			
		||||
                ops = [op_type[op] for op in ops]
 | 
			
		||||
                adj_ops_pairs.append((adj_matrix, ops))
 | 
			
		||||
            
 | 
			
		||||
            return adj_ops_pairs
 | 
			
		||||
@@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos):
 | 
			
		||||
            #         ops_type[op] = len(ops_type)
 | 
			
		||||
            # len_ops.add(len(ops))
 | 
			
		||||
            # graphs.append((adj_matrix, ops))
 | 
			
		||||
        graphs = read_adj_ops_from_json(f'nasbench-201.meta.json')
 | 
			
		||||
        graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
 | 
			
		||||
 | 
			
		||||
        # check first five graphs
 | 
			
		||||
        for i in range(5):
 | 
			
		||||
@@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index):
 | 
			
		||||
        'transition_E': tansition_E.tolist(),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    with open(f'{root}/{source_name}.meta.json', "w") as f:
 | 
			
		||||
    with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
 | 
			
		||||
        json.dump(meta_dict, f)
 | 
			
		||||
    
 | 
			
		||||
    return meta_dict
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user