add train_loader and searchspace codes
This commit is contained in:
parent
5e66aa74e7
commit
f5d00be56e
@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split
|
||||
import utils as utils
|
||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||
from diffusion.distributions import DistributionNodes
|
||||
# from naswot.score_networks import get_nasbench201_idx_score
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@ -679,7 +680,8 @@ class Dataset(InMemoryDataset):
|
||||
self.api = API(source)
|
||||
|
||||
data_list = []
|
||||
len_data = len(self.api)
|
||||
# len_data = len(self.api)
|
||||
len_data = 1000
|
||||
def check_valid_graph(nodes, edges):
|
||||
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
|
||||
return False
|
||||
@ -726,8 +728,11 @@ class Dataset(InMemoryDataset):
|
||||
if rand < random_ratio:
|
||||
edges[i, j] = 1
|
||||
return nodes, edges
|
||||
|
||||
def get_nasbench_201_val(idx):
|
||||
pass
|
||||
|
||||
def graph_to_graph_data(graph):
|
||||
def graph_to_graph_data(graph, idx):
|
||||
ops = graph[1]
|
||||
adj = graph[0]
|
||||
nodes = []
|
||||
@ -742,13 +747,14 @@ class Dataset(InMemoryDataset):
|
||||
if adj[start][end] == 1:
|
||||
edges_list.append((start, end))
|
||||
edge_type.append(1)
|
||||
# edges_list.append((end, start))
|
||||
# edge_type.append(1)
|
||||
edges_list.append((end, start))
|
||||
edge_type.append(1)
|
||||
|
||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||
edge_attr = edge_type
|
||||
y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
||||
# y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
||||
y = get_nasbench_201_val(idx)
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
||||
return data
|
||||
graph_list = []
|
||||
|
Loading…
Reference in New Issue
Block a user