add the idea of guidance
This commit is contained in:
		| @@ -8,6 +8,7 @@ import os | ||||
| import os.path as osp | ||||
| import pathlib | ||||
| import json | ||||
| import random | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| @@ -49,6 +50,9 @@ op_type = { | ||||
|     'none': 5, | ||||
|     'output': 6, | ||||
| } | ||||
|  | ||||
| num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|  | ||||
| class DataModule(AbstractDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         self.datadir = cfg.dataset.datadir | ||||
| @@ -676,6 +680,52 @@ class Dataset(InMemoryDataset): | ||||
|  | ||||
|         data_list = [] | ||||
|         len_data = len(self.api) | ||||
|         def check_valid_graph(nodes, edges): | ||||
|             if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: | ||||
|                 return False | ||||
|             if nodes[0] != 'input' or nodes[-1] != 'output': | ||||
|                 return False | ||||
|             for i in range(0, len(nodes)): | ||||
|                 if edges[i][i] == 1: | ||||
|                     return False | ||||
|             for i in range(1, len(nodes) - 1): | ||||
|                 if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output': | ||||
|                     return False | ||||
|             for i in range(0, len(nodes)): | ||||
|                 for j in range(i, len(nodes)): | ||||
|                     if edges[i, j] == 1 and nodes[j] == 'input': | ||||
|                         return False | ||||
|             for i in range(0, len(nodes)): | ||||
|                 for j in range(i, len(nodes)): | ||||
|                     if edges[i, j] == 1 and nodes[i] == 'output': | ||||
|                         return False | ||||
|             flag = 0 | ||||
|             for i in range(0,len(nodes)): | ||||
|                 if edges[i,-1] == 1: | ||||
|                     flag = 1 | ||||
|                     break | ||||
|             if flag == 0: return False | ||||
|             return True | ||||
|  | ||||
|         def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5): | ||||
|             nasbench_201_node_num = 8 | ||||
|             # random.seed(random_seed) | ||||
|             nodes_num = random.randint(min_nodes, max_nodes) | ||||
|             # print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}') | ||||
|             add_num = nodes_num - nasbench_201_node_num | ||||
|             # ori_nodes, ori_edges = parse_architecture_string(arch_str) | ||||
|             add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)] | ||||
|             # print(add_nodes) | ||||
|             nodes = ori_nodes[:-1] + add_nodes + ['output'] | ||||
|             edges = np.zeros((nodes_num , nodes_num)) | ||||
|             edges[:6, :6] = ori_edges[:6, :6] | ||||
|             edges[0:8, -1] = ori_edges[0:8 , -1] | ||||
|             for i in range(0, nodes_num): | ||||
|                 for j in range(max(7,i + 1), nodes_num): | ||||
|                     rand = random.random() | ||||
|                     if rand < random_ratio: | ||||
|                         edges[i, j] = 1 | ||||
|             return nodes, edges | ||||
|  | ||||
|         def graph_to_graph_data(graph): | ||||
|             ops = graph[1] | ||||
| @@ -746,6 +796,9 @@ class Dataset(InMemoryDataset): | ||||
|                 }) | ||||
|                 data = graph_to_graph_data((adj_matrix, ops))  | ||||
|                 data_list.append(data) | ||||
|  | ||||
|                 # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8,  random_ratio=0.5) | ||||
|                 # data_list.append(graph_to_graph_data((new_adj, new_ops))) | ||||
|                 pbar.update(1) | ||||
|          | ||||
|         for graph in graph_list: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user