add config paths (nas_201, swap_result, root) and replace hardcoded paths with cfg.general entries
@@ -16,9 +16,12 @@ general:
     final_model_chains_to_save: 1
     enable_progress_bar: False
     save_model: True
-    log_dir: '/nfs/data3/hanzhang/nasbenchDiT'
+    log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
     number_checkpoint_limit: 3
     type: 'Trainer'
+    nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+    swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
+    root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
 model:
     type: 'discrete'
     transition: 'marginal'
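The three new keys move every hard-coded benchmark and result path behind cfg.general. A minimal consumption sketch, assuming an OmegaConf/Hydra-style config (which the cfg.general.* access pattern in the code suggests); the loader call and config file name below are illustrative, not part of this commit:

# Sketch only: the config file name and the OmegaConf loader are assumptions.
import os
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")      # hypothetical location of the YAML above
api_file  = cfg.general.nas_201                  # NAS-Bench-201 benchmark checkpoint
swap_csv  = cfg.general.swap_result              # precomputed SWAP scores
meta_json = os.path.join(cfg.general.root, "nasbench-201-meta.json")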
@@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-from naswot.score_networks import get_nasbench201_idx_score
 from naswot import nasspace
 from naswot import datasets as dt

@@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         # base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        # base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        base_path = os.path.join(self.cfg.general.root, "..")
+
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path

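Joining cfg.general.root with ".." leaves a literal parent-directory component in base_path, which os.path.join handles fine but which prints awkwardly in logs. A small sketch of an optional normalisation step (not part of this commit), using the root value from the config above:

import os

root = "/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/"
base_path = os.path.normpath(os.path.join(root, ".."))
# -> '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'  (the trailing '..' is collapsed)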
@@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         source = './NAS-Bench-201-v1_1-096897.pth'
-        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
+        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
         self.dataset = dataset
         # self.api = dataset.api

@@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader

-def new_graphs_to_json(graphs, filename):
+def new_graphs_to_json(graphs, filename, cfg):
     source_name = "nasbench-201"
     num_graph = len(graphs)

@@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
         'num_active_nodes': len(active_nodes),
         'transition_E': transition_E.tolist(),
     }
-
-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    import os
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f:
         json.dump(meta_dict, f)

     return meta_dict
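The same os.path.join(cfg.general.root, <name>) pattern now appears in several places in this file. A hypothetical helper (not in the commit) that would centralise it:

import os

def cfg_path(cfg, *parts):
    """Join file names under cfg.general.root (hypothetical convenience wrapper)."""
    return os.path.join(cfg.general.root, *parts)

# e.g. cfg_path(cfg, 'nasbench-201-meta.json') or cfg_path(cfg, 'flex-nasbench201-graph.json')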
@@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
         json.dump(meta_dict, f)
     return meta_dict
 class Dataset(InMemoryDataset):
-    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
         self.target_prop = target_prop
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.cfg = cfg
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
         self.source = source
         # self.api = API(source)  # Initialize NAS-Bench-201 API
         # print('API loaded')
@@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
         return [f'{self.source}.pt']

     def process(self):
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = self.cfg.general.nas_201
         # self.api = API(source)

         data_list = []
@@ -748,7 +753,8 @@ class Dataset(InMemoryDataset):
             return  edges,nodes


-        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        # def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        def graph_to_graph_data(graph, idx,  args, device):
         # def graph_to_graph_data(graph):
             ops = graph[1]
             adj = graph[0]
@@ -797,7 +803,7 @@ class Dataset(InMemoryDataset):
         args.batch_size = 128
         args.GPU = '0'
         args.dataset = 'cifar10'
-        args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        args.api_loc = self.cfg.general.nas_201
         args.data_loc = '../cifardata/'
         args.seed = 777
         args.init = ''
@@ -812,10 +818,11 @@ class Dataset(InMemoryDataset):
         args.num_modules_per_stack = 3
         args.num_labels = 1
         searchspace = nasspace.get_search_space(args)
-        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
+        # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
         self.swap_scores = []
         import csv
-        with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+        # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+        with open(self.cfg.general.swap_result, 'r') as f:
         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
            reader = csv.reader(f)
            header = next(reader)
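With the live scoring path disabled (the dt.get_data train_loader is commented out), the SWAP scores come entirely from the CSV named by cfg.general.swap_result. A sketch of that read as a standalone helper; the column layout is an assumption, since the real header of swap_results.csv is not visible in this diff:

import csv

def load_swap_scores(cfg):
    """Read precomputed SWAP scores from the CSV named in the config (hypothetical helper)."""
    scores = []
    with open(cfg.general.swap_result, 'r') as f:
        reader = csv.reader(f)
        header = next(reader)                    # skip the header row
        for row in reader:
            scores.append(float(row[-1]))        # assumes the score is the last column
    return scores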
@@ -824,12 +831,15 @@ class Dataset(InMemoryDataset):
         device = torch.device('cuda:2')
         with tqdm(total = len_data) as pbar:
             active_nodes = set()
-            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            import os
+            # file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
             with open(file_path, 'r') as f:
                 graph_list = json.load(f)
             i = 0
             flex_graph_list = []
-            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            # flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json')
             for graph in graph_list:
                 print(f'iterate every graph in graph_list, here is {i}')
                 arch_info = graph['arch_str']
@@ -837,7 +847,8 @@ class Dataset(InMemoryDataset):
                 for op in ops:
                     if op not in active_nodes:
                         active_nodes.add(op)
-                data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                # data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)
                 i += 1
                 if data is None:
                     pbar.update(1)
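Each entry of nasbench-201-graph.json carries an 'arch_str' in the standard NAS-Bench-201 format, from which the operation list is recovered before graph_to_graph_data is called. A standalone sketch of that parsing (the helper name is illustrative, not the code's own):

def arch_str_to_ops(arch_str):
    """Split '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|' into its six cell operations."""
    ops = []
    for node in arch_str.split('+'):
        for edge in node.split('|'):
            if edge and '~' in edge:
                ops.append(edge.split('~')[0])
    return ops

# arch_str_to_ops('|nor_conv_3x3~0|+|none~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|')
# -> ['nor_conv_3x3', 'none', 'skip_connect', 'nor_conv_1x1', 'avg_pool_3x3', 'none']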
@@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
         self.task = task_name
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
+        self.cfg = cfg
         # self.api = dataset.api

         datadir = cfg.dataset.datadir
@@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos):
             # len_ops.add(len(ops))
             # graphs.append((adj_matrix, ops))
         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
-        graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))

         # check first five graphs
         for i in range(5):
             print(f'graph {i} : {graphs[i]}')
         # print(f'ops_type: {ops_type}')

-        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
+        meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
         self.base_path = base_path
         self.active_nodes = meta_dict['active_nodes']
         self.max_n_nodes = meta_dict['max_n_nodes']
@@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
         'transition_E': tansition_E.tolist(),
         }

-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
         json.dump(meta_dict, f)

     return meta_dict


 if __name__ == "__main__":
-    dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
+    dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
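Note that Dataset.__init__ now dereferences self.cfg.general.root, so the __main__ block above (and compute_meta, which references self.cfg although its signature has no self parameter) would fail when run standalone without a config object. A minimal sketch of a standalone invocation, again assuming an OmegaConf YAML with the config file name as a placeholder:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")      # hypothetical config file
dataset = Dataset(source='nasbench',
                  root=cfg.general.root,
                  target_prop='Class',
                  transform=None,
                  cfg=cfg)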
@@ -24,7 +24,7 @@ class Graph_DiT(pl.LightningModule):
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)

         from nas_201_api import NASBench201API as API
-        self.api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+        self.api = API(cfg.general.nas_201)

         input_dims = dataset_infos.input_dims
         output_dims = dataset_infos.output_dims
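self.api is now built from the configured benchmark file. A brief usage sketch, assuming the standard nas_201_api interface; the helper name and the architecture string are arbitrary examples:

from nas_201_api import NASBench201API as API

def lookup_index(cfg, arch_str):
    """Map an architecture string to its NAS-Bench-201 index (hypothetical helper)."""
    api = API(cfg.general.nas_201)               # load the benchmark from the configured path
    return api.query_index_by_arch(arch_str)

# lookup_index(cfg, '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|avg_pool_3x3~2|')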
@@ -44,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
         self.args.batch_size = 128
         self.args.GPU = '0'
         self.args.dataset = 'cifar10-valid'
-        self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.args.api_loc = cfg.general.nas_201
         self.args.data_loc = '../cifardata/'
         self.args.seed = 777
         self.args.init = ''
@@ -177,7 +177,7 @@ class Graph_DiT(pl.LightningModule):
                 rewards = []
                 if reward_model == 'swap':
                     import csv
-                    with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+                    with open(self.cfg.general.swap_result, 'r') as f:
                         reader = csv.reader(f)
                         header = next(reader)
                         data = [row for row in reader]
@@ -345,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
                     num_examples = self.val_y_collection.size(0)
                 batch_y = self.val_y_collection[start_index:start_index + to_generate]
                 all_ys.append(batch_y)
-                samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+                cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                                 save_final=to_save,
                                                 keep_chain=chains_save,
-                                                number_chain_steps=self.number_chain_steps))
+                                                number_chain_steps=self.number_chain_steps)
+                samples.extend(cur_sample)
+                # samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+                #                                 save_final=to_save,
+                #                                 keep_chain=chains_save,
+                #                                 number_chain_steps=self.number_chain_steps))
                 ident += to_generate
                 start_index += to_generate

@@ -423,7 +428,7 @@ class Graph_DiT(pl.LightningModule):

             cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
-            samples.append(cur_sample)
+            samples.extend(cur_sample)

             all_ys.append(batch_y)
             batch_id += to_generate
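The append-to-extend change matters because sample_batch returns a list of sampled graphs (plus log-probabilities, as the tuple unpacking above shows): append would nest that whole list as a single element of samples, while extend flattens it. A short illustration of the difference in plain Python:

samples = []
samples.append(['g1', 'g2'])   # -> [['g1', 'g2']]  (nested; len == 1)

samples = []
samples.extend(['g1', 'g2'])   # -> ['g1', 'g2']    (flat;   len == 2)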