add train_loader and searchspace codes

This commit is contained in:
mhz 2024-07-30 00:12:37 +02:00
parent 5e66aa74e7
commit f5d00be56e

View File

@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split
import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
# from naswot.score_networks import get_nasbench201_idx_score
import networkx as nx
@ -679,7 +680,8 @@ class Dataset(InMemoryDataset):
self.api = API(source)
data_list = []
len_data = len(self.api)
# len_data = len(self.api)
len_data = 1000
def check_valid_graph(nodes, edges):
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
return False
@ -726,8 +728,11 @@ class Dataset(InMemoryDataset):
if rand < random_ratio:
edges[i, j] = 1
return nodes, edges
def get_nasbench_201_val(idx):
pass
def graph_to_graph_data(graph):
def graph_to_graph_data(graph, idx):
ops = graph[1]
adj = graph[0]
nodes = []
@ -742,13 +747,14 @@ class Dataset(InMemoryDataset):
if adj[start][end] == 1:
edges_list.append((start, end))
edge_type.append(1)
# edges_list.append((end, start))
# edge_type.append(1)
edges_list.append((end, start))
edge_type.append(1)
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
edge_type = torch.tensor(edge_type, dtype=torch.long)
edge_attr = edge_type
y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
# y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
y = get_nasbench_201_val(idx)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
return data
graph_list = []