Compare commits
	
		
			1 Commits
		
	
	
		
			trainer
			...
			83f9345028
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 83f9345028 | 
| @@ -677,7 +677,7 @@ class Dataset(InMemoryDataset): | |||||||
|  |  | ||||||
|     def process(self): |     def process(self): | ||||||
|         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|         self.api = API(source) |         # self.api = API(source) | ||||||
|  |  | ||||||
|         data_list = [] |         data_list = [] | ||||||
|         # len_data = len(self.api) |         # len_data = len(self.api) | ||||||
| @@ -710,6 +710,10 @@ class Dataset(InMemoryDataset): | |||||||
|             return True |             return True | ||||||
|  |  | ||||||
|         def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5): |         def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5): | ||||||
|  |             # print(ori_nodes) | ||||||
|  |             # print(ori_edges) | ||||||
|  |             ori_edges = np.array(ori_edges) | ||||||
|  |             # ori_nodes = np.array(ori_nodes) | ||||||
|             nasbench_201_node_num = 8 |             nasbench_201_node_num = 8 | ||||||
|             # random.seed(random_seed) |             # random.seed(random_seed) | ||||||
|             nodes_num = random.randint(min_nodes, max_nodes) |             nodes_num = random.randint(min_nodes, max_nodes) | ||||||
| @@ -727,12 +731,13 @@ class Dataset(InMemoryDataset): | |||||||
|                     rand = random.random() |                     rand = random.random() | ||||||
|                     if rand < random_ratio: |                     if rand < random_ratio: | ||||||
|                         edges[i, j] = 1 |                         edges[i, j] = 1 | ||||||
|             return nodes, edges |             return  edges,nodes | ||||||
|          |          | ||||||
|         def get_nasbench_201_val(idx): |         def get_nasbench_201_val(idx): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         def graph_to_graph_data(graph, idx): |         # def graph_to_graph_data(graph, idx): | ||||||
|  |         def graph_to_graph_data(graph): | ||||||
|             ops = graph[1] |             ops = graph[1] | ||||||
|             adj = graph[0] |             adj = graph[0] | ||||||
|             nodes = [] |             nodes = [] | ||||||
| @@ -753,58 +758,73 @@ class Dataset(InMemoryDataset): | |||||||
|             edge_index = torch.tensor(edges_list, dtype=torch.long).t() |             edge_index = torch.tensor(edges_list, dtype=torch.long).t() | ||||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) |             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||||
|             edge_attr = edge_type |             edge_attr = edge_type | ||||||
|             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) |             y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||||
|             y = get_nasbench_201_val(idx) |             # y = get_nasbench_201_val(idx) | ||||||
|             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) |             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||||
|             return data |             return data | ||||||
|         graph_list = [] |         graph_list = [] | ||||||
|  |  | ||||||
|         with tqdm(total = len_data) as pbar: |         with tqdm(total = len_data) as pbar: | ||||||
|             active_nodes = set() |             active_nodes = set() | ||||||
|             for i in range(len_data): |             file_path = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||||
|                 arch_info = self.api.query_meta_info_by_index(i) |             with open(file_path, 'r') as f: | ||||||
|                 results = self.api.query_by_index(i, 'cifar100') |                 graph_list = json.load(f) | ||||||
|  |             i = 0 | ||||||
|  |             for graph in graph_list: | ||||||
|  |                 # arch_info = self.api.query_meta_info_by_index(i) | ||||||
|  |                 # results = self.api.query_by_index(i, 'cifar100') | ||||||
|  |                 arch_info = graph['arch_str'] | ||||||
|  |                 # results =  | ||||||
|                 # nodes, edges = parse_architecture_string(arch_info.arch_str) |                 # nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||||
|                 ops, adj_matrix = parse_architecture_string(arch_info.arch_str) |                 # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4) | ||||||
|  |                 ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) | ||||||
|                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) |                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||||
|                 for op in ops: |                 for op in ops: | ||||||
|                     if op not in active_nodes: |                     if op not in active_nodes: | ||||||
|                         active_nodes.add(op) |                         active_nodes.add(op) | ||||||
|                  |                  | ||||||
|                 graph_list.append({ |  | ||||||
|                     "adj_matrix": adj_matrix, |  | ||||||
|                     "ops": ops, |  | ||||||
|                     "idx": i, |  | ||||||
|                     "train": [{ |  | ||||||
|                         "iepoch": result.get_train()['iepoch'], |  | ||||||
|                         "loss": result.get_train()['loss'], |  | ||||||
|                         "accuracy": result.get_train()['accuracy'], |  | ||||||
|                         "cur_time": result.get_train()['cur_time'], |  | ||||||
|                         "all_time": result.get_train()['all_time'], |  | ||||||
|                         "seed": seed, |  | ||||||
|                     }for seed, result in results.items()], |  | ||||||
|                     "valid": [{ |  | ||||||
|                         "iepoch": result.get_eval('x-valid')['iepoch'], |  | ||||||
|                         "loss": result.get_eval('x-valid')['loss'], |  | ||||||
|                         "accuracy": result.get_eval('x-valid')['accuracy'], |  | ||||||
|                         "cur_time": result.get_eval('x-valid')['cur_time'], |  | ||||||
|                         "all_time": result.get_eval('x-valid')['all_time'], |  | ||||||
|                         "seed": seed, |  | ||||||
|                     }for seed, result in results.items()], |  | ||||||
|                     "test": [{ |  | ||||||
|                         "iepoch": result.get_eval('x-test')['iepoch'], |  | ||||||
|                         "loss": result.get_eval('x-test')['loss'], |  | ||||||
|                         "accuracy": result.get_eval('x-test')['accuracy'], |  | ||||||
|                         "cur_time": result.get_eval('x-test')['cur_time'], |  | ||||||
|                         "all_time": result.get_eval('x-test')['all_time'], |  | ||||||
|                         "seed": seed, |  | ||||||
|                     }for seed, result in results.items()] |  | ||||||
|                 }) |  | ||||||
|                 data = graph_to_graph_data((adj_matrix, ops))  |                 data = graph_to_graph_data((adj_matrix, ops))  | ||||||
|  |                 if i < 3: | ||||||
|  |                     print(f"i={i}, data={data}") | ||||||
|  |                     with open(f'{i}.json', 'w') as f: | ||||||
|  |                         f.write(str(data.x)) | ||||||
|  |                         f.write(str(data.edge_index)) | ||||||
|  |                         f.write(str(data.edge_attr)) | ||||||
|                 data_list.append(data) |                 data_list.append(data) | ||||||
|  |  | ||||||
|                 # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8,  random_ratio=0.5) |                 new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=8,  random_ratio=0.5) | ||||||
|                 # data_list.append(graph_to_graph_data((new_adj, new_ops))) |                 data_list.append(graph_to_graph_data((new_adj, new_ops))) | ||||||
|  |                 | ||||||
|  |                 # graph_list.append({ | ||||||
|  |                 #     "adj_matrix": adj_matrix, | ||||||
|  |                 #     "ops": ops, | ||||||
|  |                 #     "arch_str": arch_info.arch_str, | ||||||
|  |                 #     "idx": i, | ||||||
|  |                 #     "train": [{ | ||||||
|  |                 #         "iepoch": result.get_train()['iepoch'], | ||||||
|  |                 #         "loss": result.get_train()['loss'], | ||||||
|  |                 #         "accuracy": result.get_train()['accuracy'], | ||||||
|  |                 #         "cur_time": result.get_train()['cur_time'], | ||||||
|  |                 #         "all_time": result.get_train()['all_time'], | ||||||
|  |                 #         "seed": seed, | ||||||
|  |                 #     }for seed, result in results.items()], | ||||||
|  |                 #     "valid": [{ | ||||||
|  |                 #         "iepoch": result.get_eval('x-valid')['iepoch'], | ||||||
|  |                 #         "loss": result.get_eval('x-valid')['loss'], | ||||||
|  |                 #         "accuracy": result.get_eval('x-valid')['accuracy'], | ||||||
|  |                 #         "cur_time": result.get_eval('x-valid')['cur_time'], | ||||||
|  |                 #         "all_time": result.get_eval('x-valid')['all_time'], | ||||||
|  |                 #         "seed": seed, | ||||||
|  |                 #     }for seed, result in results.items()], | ||||||
|  |                 #     "test": [{ | ||||||
|  |                 #         "iepoch": result.get_eval('x-test')['iepoch'], | ||||||
|  |                 #         "loss": result.get_eval('x-test')['loss'], | ||||||
|  |                 #         "accuracy": result.get_eval('x-test')['accuracy'], | ||||||
|  |                 #         "cur_time": result.get_eval('x-test')['cur_time'], | ||||||
|  |                 #         "all_time": result.get_eval('x-test')['all_time'], | ||||||
|  |                 #         "seed": seed, | ||||||
|  |                 #     }for seed, result in results.items()] | ||||||
|  |                 # }) | ||||||
|                 pbar.update(1) |                 pbar.update(1) | ||||||
|          |          | ||||||
|         for graph in graph_list: |         for graph in graph_list: | ||||||
| @@ -981,18 +1001,29 @@ class Dataset_origin(InMemoryDataset): | |||||||
|  |  | ||||||
|         torch.save(self.collate(data_list), self.processed_paths[0]) |         torch.save(self.collate(data_list), self.processed_paths[0]) | ||||||
|  |  | ||||||
| def parse_architecture_string(arch_str): | def parse_architecture_string(arch_str, padding=0): | ||||||
|     # print(arch_str) |     # print(arch_str) | ||||||
|     steps = arch_str.split('+') |     steps = arch_str.split('+') | ||||||
|     nodes = ['input']  # Start with input node |     nodes = ['input']  # Start with input node | ||||||
|     adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0], |     ori_adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|                         [0, 0, 0, 1, 0, 1 ,0 ,0], |                         [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||||
|                         [0, 0, 0, 0, 0, 0, 1, 0], |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|                         [0, 0, 0, 0, 0, 0, 1, 0], |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|                         [0, 0, 0, 0, 0, 0, 0, 1], |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|                         [0, 0, 0, 0, 0, 0, 0, 1], |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|                         [0, 0, 0, 0, 0, 0, 0, 1], |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|                         [0, 0, 0, 0, 0, 0, 0, 0]])  |                         # [0, 0, 0, 0, 0, 0, 0, 0]])  | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 0]]  | ||||||
|  |    # adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|  |     adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|  |                         [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                         # [0, 0, 0, 0, 0, 0, 0, 0]])  | ||||||
|  |                         [0, 0, 0, 0, 0, 0, 0, 0]]  | ||||||
|     steps = arch_str.split('+') |     steps = arch_str.split('+') | ||||||
|     steps_coding = ['0', '0', '1', '0', '1', '2'] |     steps_coding = ['0', '0', '1', '0', '1', '2'] | ||||||
|     cont = 0 |     cont = 0 | ||||||
| @@ -1003,8 +1034,22 @@ def parse_architecture_string(arch_str): | |||||||
|             assert idx == steps_coding[cont] |             assert idx == steps_coding[cont] | ||||||
|             cont += 1 |             cont += 1 | ||||||
|             nodes.append(n) |             nodes.append(n) | ||||||
|  |     ori_nodes = nodes.copy() | ||||||
|     nodes.append('output')  # Add output node |     nodes.append('output')  # Add output node | ||||||
|     return nodes, adj_mat |     if padding > 0: | ||||||
|  |         for i in range(padding): | ||||||
|  |             nodes.append('none') | ||||||
|  |         for adj_row in adj_mat: | ||||||
|  |             for i in range(padding): | ||||||
|  |                 adj_row.append(0) | ||||||
|  |         # adj_mat = np.append(adj_mat, np.zeros((padding, len(nodes)))) | ||||||
|  |         for i in range(padding): | ||||||
|  |             adj_mat.append([0] * len(nodes)) | ||||||
|  |     # print(nodes) | ||||||
|  |     # print(adj_mat) | ||||||
|  |     # print(len(adj_mat)) | ||||||
|  |  | ||||||
|  |     return nodes, adj_mat, ori_nodes, ori_adj_mat | ||||||
|  |  | ||||||
| def create_adj_matrix_and_ops(nodes, edges): | def create_adj_matrix_and_ops(nodes, edges): | ||||||
|     num_nodes = len(nodes) |     num_nodes = len(nodes) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user