diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 8fe95ad..94d9437 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -25,7 +25,9 @@ from sklearn.model_selection import train_test_split import utils as utils from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule from diffusion.distributions import DistributionNodes -# from naswot.score_networks import get_nasbench201_idx_score +from naswot.score_networks import get_nasbench201_idx_score +from naswot import nasspace +from naswot import datasets as dt import networkx as nx @@ -682,7 +684,7 @@ class Dataset(InMemoryDataset): data_list = [] # len_data = len(self.api) - len_data = 1000 + len_data = 15625 def check_valid_graph(nodes, edges): if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: return False @@ -745,11 +747,9 @@ class Dataset(InMemoryDataset): print(f'edges size: {edges.shape}, nodes size: {len(nodes)}') return edges,nodes - def get_nasbench_201_val(idx): - pass - # def graph_to_graph_data(graph, idx): - def graph_to_graph_data(graph): + def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device): + # def graph_to_graph_data(graph): ops = graph[1] adj = graph[0] nodes = [] @@ -770,12 +770,49 @@ class Dataset(InMemoryDataset): edge_index = torch.tensor(edges_list, dtype=torch.long).t() edge_type = torch.tensor(edge_type, dtype=torch.long) edge_attr = edge_type - y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) - # y = get_nasbench_201_val(idx) - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) + # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) + y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) + print(y, idx) + if y > 1600: + print(f'idx={idx}, y={y}') + y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) + else: + print(f'idx={idx}, y={y}') + y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) + return None return data graph_list = [] - + class Args: + pass + args = Args() + args.trainval = True + args.augtype = 'none' + args.repeat = 1 + args.score = 'hook_logdet' + args.sigma = 0.05 + args.nasspace = 'nasbench201' + args.batch_size = 128 + args.GPU = '0' + args.dataset = 'cifar10' + args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' + args.data_loc = '../cifardata/' + args.seed = 777 + args.init = '' + args.save_loc = 'results' + args.save_string = 'naswot' + args.dropout = False + args.maxofn = 1 + args.n_samples = 100 + args.n_runs = 500 + args.stem_out_channels = 16 + args.num_stacks = 3 + args.num_modules_per_stack = 3 + args.num_labels = 1 + searchspace = nasspace.get_search_space(args) + train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) + device = torch.device('cuda:2') with tqdm(total = len_data) as pbar: active_nodes = set() file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' @@ -785,6 +822,7 @@ class Dataset(InMemoryDataset): flex_graph_list = [] flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' for graph in graph_list: + print(f'iterate every graph in graph_list, here is {i}') # arch_info = self.api.query_meta_info_by_index(i) # results = self.api.query_by_index(i, 'cifar100') arch_info = graph['arch_str'] @@ -796,8 +834,11 @@ class Dataset(InMemoryDataset): for op in ops: if op not in active_nodes: active_nodes.add(op) - - data = graph_to_graph_data((adj_matrix, ops)) + data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device) + i += 1 + if data is None: + pbar.update(1) + continue # with open(flex_graph_path, 'a') as f: # flex_graph = { # 'adj_matrix': adj_matrix, @@ -816,18 +857,12 @@ class Dataset(InMemoryDataset): f.write(str(data.edge_attr)) data_list.append(data) - new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5) - flex_graph_list.append({ - 'adj_matrix':new_adj.tolist(), - 'ops': new_ops, - }) - # with open(flex_graph_path, 'w') as f: - # flex_graph = { - # 'adj_matrix': new_adj.tolist(), - # 'ops': new_ops, - # } - # json.dump(flex_graph, f) - data_list.append(graph_to_graph_data((new_adj, new_ops))) + # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5) + # flex_graph_list.append({ + # 'adj_matrix':new_adj.tolist(), + # 'ops': new_ops, + # }) + # data_list.append(graph_to_graph_data((new_adj, new_ops))) # graph_list.append({ # "adj_matrix": adj_matrix, @@ -859,6 +894,7 @@ class Dataset(InMemoryDataset): # "seed": seed, # }for seed, result in results.items()] # }) + # i += 1 pbar.update(1) for graph in graph_list: @@ -872,8 +908,8 @@ class Dataset(InMemoryDataset): graph['ops'] = ops with open(f'nasbench-201-graph.json', 'w') as f: json.dump(graph_list, f) - with open(flex_graph_path, 'w') as f: - json.dump(flex_graph_list, f) + # with open(flex_graph_path, 'w') as f: + # json.dump(flex_graph_list, f) torch.save(self.collate(data_list), self.processed_paths[0]) @@ -1148,7 +1184,8 @@ class DataInfos(AbstractDatasetInfos): # ops_type[op] = len(ops_type) # len_ops.add(len(ops)) # graphs.append((adj_matrix, ops)) - graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') + # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') + graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') # check first five graphs for i in range(5):