add train_loader and searchspace codes
This commit is contained in:
		| @@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split | |||||||
| import utils as utils | import utils as utils | ||||||
| from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | ||||||
| from diffusion.distributions import DistributionNodes | from diffusion.distributions import DistributionNodes | ||||||
|  | # from naswot.score_networks import get_nasbench201_idx_score | ||||||
|  |  | ||||||
| import networkx as nx | import networkx as nx | ||||||
|  |  | ||||||
| @@ -679,7 +680,8 @@ class Dataset(InMemoryDataset): | |||||||
|         self.api = API(source) |         self.api = API(source) | ||||||
|  |  | ||||||
|         data_list = [] |         data_list = [] | ||||||
|         len_data = len(self.api) |         # len_data = len(self.api) | ||||||
|  |         len_data = 1000 | ||||||
|         def check_valid_graph(nodes, edges): |         def check_valid_graph(nodes, edges): | ||||||
|             if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: |             if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: | ||||||
|                 return False |                 return False | ||||||
| @@ -727,7 +729,10 @@ class Dataset(InMemoryDataset): | |||||||
|                         edges[i, j] = 1 |                         edges[i, j] = 1 | ||||||
|             return nodes, edges |             return nodes, edges | ||||||
|          |          | ||||||
|         def graph_to_graph_data(graph): |         def get_nasbench_201_val(idx): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         def graph_to_graph_data(graph, idx): | ||||||
|             ops = graph[1] |             ops = graph[1] | ||||||
|             adj = graph[0] |             adj = graph[0] | ||||||
|             nodes = [] |             nodes = [] | ||||||
| @@ -742,13 +747,14 @@ class Dataset(InMemoryDataset): | |||||||
|                     if adj[start][end] == 1: |                     if adj[start][end] == 1: | ||||||
|                         edges_list.append((start, end)) |                         edges_list.append((start, end)) | ||||||
|                         edge_type.append(1) |                         edge_type.append(1) | ||||||
|                         # edges_list.append((end, start)) |                         edges_list.append((end, start)) | ||||||
|                         # edge_type.append(1) |                         edge_type.append(1) | ||||||
|              |              | ||||||
|             edge_index = torch.tensor(edges_list, dtype=torch.long).t() |             edge_index = torch.tensor(edges_list, dtype=torch.long).t() | ||||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) |             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||||
|             edge_attr = edge_type |             edge_attr = edge_type | ||||||
|             y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) |             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||||
|  |             y = get_nasbench_201_val(idx) | ||||||
|             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) |             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||||
|             return data |             return data | ||||||
|         graph_list = [] |         graph_list = [] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user