add config path
This commit is contained in:
		| @@ -16,9 +16,12 @@ general: | |||||||
|     final_model_chains_to_save: 1 |     final_model_chains_to_save: 1 | ||||||
|     enable_progress_bar: False |     enable_progress_bar: False | ||||||
|     save_model: True |     save_model: True | ||||||
|     log_dir: '/nfs/data3/hanzhang/nasbenchDiT' |     log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT' | ||||||
|     number_checkpoint_limit: 3 |     number_checkpoint_limit: 3 | ||||||
|     type: 'Trainer' |     type: 'Trainer' | ||||||
|  |     nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|  |     swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv' | ||||||
|  |     root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/' | ||||||
| model: | model: | ||||||
|     type: 'discrete' |     type: 'discrete' | ||||||
|     transition: 'marginal'                   |     transition: 'marginal'                   | ||||||
|   | |||||||
| @@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split | |||||||
| import utils as utils | import utils as utils | ||||||
| from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | ||||||
| from diffusion.distributions import DistributionNodes | from diffusion.distributions import DistributionNodes | ||||||
| from naswot.score_networks import get_nasbench201_idx_score |  | ||||||
| from naswot import nasspace | from naswot import nasspace | ||||||
| from naswot import datasets as dt | from naswot import datasets as dt | ||||||
|  |  | ||||||
| @@ -72,7 +71,9 @@ class DataModule(AbstractDataModule): | |||||||
|         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] |         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||||
|         # except NameError: |         # except NameError: | ||||||
|         # base_path = pathlib.Path(os.getcwd()).parent[2] |         # base_path = pathlib.Path(os.getcwd()).parent[2] | ||||||
|         base_path = '/nfs/data3/hanzhang/nasbenchDiT' |         # base_path = '/nfs/data3/hanzhang/nasbenchDiT' | ||||||
|  |         base_path = os.path.join(self.cfg.general.root, "..") | ||||||
|  |  | ||||||
|         root_path = os.path.join(base_path, self.datadir) |         root_path = os.path.join(base_path, self.datadir) | ||||||
|         self.root_path = root_path |         self.root_path = root_path | ||||||
|  |  | ||||||
| @@ -84,7 +85,7 @@ class DataModule(AbstractDataModule): | |||||||
|         # Load the dataset to the memory |         # Load the dataset to the memory | ||||||
|         # Dataset has target property, root path, and transform |         # Dataset has target property, root path, and transform | ||||||
|         source = './NAS-Bench-201-v1_1-096897.pth' |         source = './NAS-Bench-201-v1_1-096897.pth' | ||||||
|         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) |         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg) | ||||||
|         self.dataset = dataset |         self.dataset = dataset | ||||||
|         # self.api = dataset.api |         # self.api = dataset.api | ||||||
|  |  | ||||||
| @@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule): | |||||||
|     def test_dataloader(self): |     def test_dataloader(self): | ||||||
|         return self.test_loader |         return self.test_loader | ||||||
|  |  | ||||||
| def new_graphs_to_json(graphs, filename): | def new_graphs_to_json(graphs, filename, cfg): | ||||||
|     source_name = "nasbench-201" |     source_name = "nasbench-201" | ||||||
|     num_graph = len(graphs) |     num_graph = len(graphs) | ||||||
|  |  | ||||||
| @@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename): | |||||||
|         'num_active_nodes': len(active_nodes), |         'num_active_nodes': len(active_nodes), | ||||||
|         'transition_E': transition_E.tolist(), |         'transition_E': transition_E.tolist(), | ||||||
|     } |     } | ||||||
|  |     import os | ||||||
|     with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f: |     # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f: | ||||||
|  |     with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f: | ||||||
|         json.dump(meta_dict, f) |         json.dump(meta_dict, f) | ||||||
|      |      | ||||||
|     return meta_dict |     return meta_dict | ||||||
| @@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename): | |||||||
|         json.dump(meta_dict, f) |         json.dump(meta_dict, f) | ||||||
|     return meta_dict |     return meta_dict | ||||||
| class Dataset(InMemoryDataset): | class Dataset(InMemoryDataset): | ||||||
|     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): |     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None): | ||||||
|         self.target_prop = target_prop |         self.target_prop = target_prop | ||||||
|         source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         self.cfg = cfg | ||||||
|  |         # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|  |         source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth') | ||||||
|         self.source = source |         self.source = source | ||||||
|         # self.api = API(source)  # Initialize NAS-Bench-201 API |         # self.api = API(source)  # Initialize NAS-Bench-201 API | ||||||
|         # print('API loaded') |         # print('API loaded') | ||||||
| @@ -679,7 +683,8 @@ class Dataset(InMemoryDataset): | |||||||
|         return [f'{self.source}.pt'] |         return [f'{self.source}.pt'] | ||||||
|  |  | ||||||
|     def process(self): |     def process(self): | ||||||
|         source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|  |         source = self.cfg.general.nas_201 | ||||||
|         # self.api = API(source) |         # self.api = API(source) | ||||||
|  |  | ||||||
|         data_list = [] |         data_list = [] | ||||||
| @@ -748,7 +753,8 @@ class Dataset(InMemoryDataset): | |||||||
|             return  edges,nodes |             return  edges,nodes | ||||||
|          |          | ||||||
|  |  | ||||||
|         def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device): |         # def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device): | ||||||
|  |         def graph_to_graph_data(graph, idx,  args, device): | ||||||
|         # def graph_to_graph_data(graph): |         # def graph_to_graph_data(graph): | ||||||
|             ops = graph[1] |             ops = graph[1] | ||||||
|             adj = graph[0] |             adj = graph[0] | ||||||
| @@ -797,7 +803,7 @@ class Dataset(InMemoryDataset): | |||||||
|         args.batch_size = 128 |         args.batch_size = 128 | ||||||
|         args.GPU = '0' |         args.GPU = '0' | ||||||
|         args.dataset = 'cifar10' |         args.dataset = 'cifar10' | ||||||
|         args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         args.api_loc = self.cfg.general.nas_201  | ||||||
|         args.data_loc = '../cifardata/' |         args.data_loc = '../cifardata/' | ||||||
|         args.seed = 777 |         args.seed = 777 | ||||||
|         args.init = '' |         args.init = '' | ||||||
| @@ -812,10 +818,11 @@ class Dataset(InMemoryDataset): | |||||||
|         args.num_modules_per_stack = 3 |         args.num_modules_per_stack = 3 | ||||||
|         args.num_labels = 1 |         args.num_labels = 1 | ||||||
|         searchspace = nasspace.get_search_space(args) |         searchspace = nasspace.get_search_space(args) | ||||||
|         train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) |         # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||||
|         self.swap_scores = [] |         self.swap_scores = [] | ||||||
|         import csv |         import csv | ||||||
|         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: |         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||||
|  |         with open(self.cfg.general.swap_result, 'r') as f: | ||||||
|         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: |         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: | ||||||
|             reader = csv.reader(f) |             reader = csv.reader(f) | ||||||
|             header = next(reader) |             header = next(reader) | ||||||
| @@ -824,12 +831,15 @@ class Dataset(InMemoryDataset): | |||||||
|         device = torch.device('cuda:2') |         device = torch.device('cuda:2') | ||||||
|         with tqdm(total = len_data) as pbar: |         with tqdm(total = len_data) as pbar: | ||||||
|             active_nodes = set() |             active_nodes = set() | ||||||
|             file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' |             import os | ||||||
|  |             # file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||||
|  |             file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json') | ||||||
|             with open(file_path, 'r') as f: |             with open(file_path, 'r') as f: | ||||||
|                 graph_list = json.load(f) |                 graph_list = json.load(f) | ||||||
|             i = 0 |             i = 0 | ||||||
|             flex_graph_list = [] |             flex_graph_list = [] | ||||||
|             flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' |             # flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' | ||||||
|  |             flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json') | ||||||
|             for graph in graph_list: |             for graph in graph_list: | ||||||
|                 print(f'iterate every graph in graph_list, here is {i}') |                 print(f'iterate every graph in graph_list, here is {i}') | ||||||
|                 arch_info = graph['arch_str'] |                 arch_info = graph['arch_str'] | ||||||
| @@ -837,7 +847,8 @@ class Dataset(InMemoryDataset): | |||||||
|                 for op in ops: |                 for op in ops: | ||||||
|                     if op not in active_nodes: |                     if op not in active_nodes: | ||||||
|                         active_nodes.add(op) |                         active_nodes.add(op) | ||||||
|                 data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)  |                 # data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)  | ||||||
|  |                 data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)  | ||||||
|                 i += 1 |                 i += 1 | ||||||
|                 if data is None: |                 if data is None: | ||||||
|                     pbar.update(1) |                     pbar.update(1) | ||||||
| @@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|         self.task = task_name |         self.task = task_name | ||||||
|         self.task_type = tasktype_dict.get(task_name, "regression") |         self.task_type = tasktype_dict.get(task_name, "regression") | ||||||
|         self.ensure_connected = cfg.model.ensure_connected |         self.ensure_connected = cfg.model.ensure_connected | ||||||
|  |         self.cfg = cfg | ||||||
|         # self.api = dataset.api |         # self.api = dataset.api | ||||||
|  |  | ||||||
|         datadir = cfg.dataset.datadir |         datadir = cfg.dataset.datadir | ||||||
| @@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos): | |||||||
|             # len_ops.add(len(ops)) |             # len_ops.add(len(ops)) | ||||||
|             # graphs.append((adj_matrix, ops)) |             # graphs.append((adj_matrix, ops)) | ||||||
|         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') |         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') | ||||||
|         graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') |         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') | ||||||
|  |         graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')) | ||||||
|  |  | ||||||
|         # check first five graphs |         # check first five graphs | ||||||
|         for i in range(5): |         for i in range(5): | ||||||
|             print(f'graph {i} : {graphs[i]}') |             print(f'graph {i} : {graphs[i]}') | ||||||
|         # print(f'ops_type: {ops_type}') |         # print(f'ops_type: {ops_type}') | ||||||
|  |  | ||||||
|         meta_dict = new_graphs_to_json(graphs, 'nasbench-201') |         meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg) | ||||||
|         self.base_path = base_path |         self.base_path = base_path | ||||||
|         self.active_nodes = meta_dict['active_nodes'] |         self.active_nodes = meta_dict['active_nodes'] | ||||||
|         self.max_n_nodes = meta_dict['max_n_nodes'] |         self.max_n_nodes = meta_dict['max_n_nodes'] | ||||||
| @@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index): | |||||||
|         'transition_E': tansition_E.tolist(), |         'transition_E': tansition_E.tolist(), | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f: |     # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f: | ||||||
|  |     with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f: | ||||||
|         json.dump(meta_dict, f) |         json.dump(meta_dict, f) | ||||||
|      |      | ||||||
|     return meta_dict |     return meta_dict | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None) |     dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None) | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) |         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) | ||||||
|  |  | ||||||
|         from nas_201_api import NASBench201API as API |         from nas_201_api import NASBench201API as API | ||||||
|         self.api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') |         self.api = API(cfg.general.nas_201) | ||||||
|  |  | ||||||
|         input_dims = dataset_infos.input_dims |         input_dims = dataset_infos.input_dims | ||||||
|         output_dims = dataset_infos.output_dims |         output_dims = dataset_infos.output_dims | ||||||
| @@ -44,7 +44,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         self.args.batch_size = 128 |         self.args.batch_size = 128 | ||||||
|         self.args.GPU = '0' |         self.args.GPU = '0' | ||||||
|         self.args.dataset = 'cifar10-valid' |         self.args.dataset = 'cifar10-valid' | ||||||
|         self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |         self.args.api_loc = cfg.general.nas_201 | ||||||
|         self.args.data_loc = '../cifardata/' |         self.args.data_loc = '../cifardata/' | ||||||
|         self.args.seed = 777 |         self.args.seed = 777 | ||||||
|         self.args.init = '' |         self.args.init = '' | ||||||
| @@ -177,7 +177,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|                 rewards = [] |                 rewards = [] | ||||||
|                 if reward_model == 'swap': |                 if reward_model == 'swap': | ||||||
|                     import csv |                     import csv | ||||||
|                     with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: |                     with open(self.cfg.general.swap_result, 'r') as f: | ||||||
|                         reader = csv.reader(f) |                         reader = csv.reader(f) | ||||||
|                         header = next(reader) |                         header = next(reader) | ||||||
|                         data = [row for row in reader] |                         data = [row for row in reader] | ||||||
| @@ -345,10 +345,15 @@ class Graph_DiT(pl.LightningModule): | |||||||
|                     num_examples = self.val_y_collection.size(0) |                     num_examples = self.val_y_collection.size(0) | ||||||
|                 batch_y = self.val_y_collection[start_index:start_index + to_generate]                 |                 batch_y = self.val_y_collection[start_index:start_index + to_generate]                 | ||||||
|                 all_ys.append(batch_y) |                 all_ys.append(batch_y) | ||||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, |                 cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||||
|                                                 save_final=to_save, |                                                 save_final=to_save, | ||||||
|                                                 keep_chain=chains_save, |                                                 keep_chain=chains_save, | ||||||
|                                                 number_chain_steps=self.number_chain_steps)) |                                                 number_chain_steps=self.number_chain_steps) | ||||||
|  |                 samples.extend(cur_sample) | ||||||
|  |                 # samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||||
|  |                 #                                 save_final=to_save, | ||||||
|  |                 #                                 keep_chain=chains_save, | ||||||
|  |                 #                                 number_chain_steps=self.number_chain_steps)) | ||||||
|                 ident += to_generate |                 ident += to_generate | ||||||
|                 start_index += to_generate |                 start_index += to_generate | ||||||
|  |  | ||||||
| @@ -423,7 +428,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|  |  | ||||||
|             cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, |             cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) |                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) | ||||||
|             samples.append(cur_sample)   |             samples.extend(cur_sample)   | ||||||
|              |              | ||||||
|             all_ys.append(batch_y) |             all_ys.append(batch_y) | ||||||
|             batch_id += to_generate |             batch_id += to_generate | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user