Try to update the API in DataInfos

This commit is contained in:
mhz 2024-06-26 22:09:46 +02:00
parent 82299e5213
commit 0c7c525680
2 changed files with 16 additions and 11 deletions

View File

@ -50,12 +50,12 @@ class DataModule(AbstractDataModule):
def prepare_data(self) -> None:
target = getattr(self.cfg.dataset, 'guidance_target', None)
print("target", target)
print("target", target) # nasbench-201
# try:
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
# except NameError:
# base_path = pathlib.Path(os.getcwd()).parent[2]
base_path = '/home/stud/hanzhang/Graph-Dit'
base_path = '/home/stud/hanzhang/nasbenchDiT'
root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path
@ -68,13 +68,15 @@ class DataModule(AbstractDataModule):
# Dataset has target property, root path, and transform
source = './NAS-Bench-201-v1_1-096897.pth'
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
self.dataset = dataset
# if len(self.task.split('-')) == 2:
# train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
# else:
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
train_index, val_index, test_index, unlabeled_index)
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
if len(unlabeled_index) > 0:
train_index = torch.cat([train_index, unlabeled_index], dim=0)
@ -477,14 +479,17 @@ def graphs_to_json(graphs, filename):
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source
super().__init__(root, transform, pre_transform, pre_filter)
print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
self.api = API(source) # Initialize NAS-Bench-201 API
print('API loaded')
super().__init__(root, transform, pre_transform, pre_filter)
print('Dataset initialized')
print(self.processed_paths[0])
self.data, self.slices = torch.load(self.processed_paths[0])
self.data.edge_attr = self.data.edge_attr.squeeze()
self.data.idx = torch.arange(len(self.data.y))
print(f"self.data={self.data}, self.slices={self.slices}")
@property
def raw_file_names(self):
@ -676,7 +681,7 @@ def create_adj_matrix_and_ops(nodes, edges):
adj_matrix[src][dst] = 1
return adj_matrix, nodes
class DataInfos(AbstractDatasetInfos):
def __init__(self, datamodule, cfg):
def __init__(self, datamodule, cfg, dataset):
tasktype_dict = {
'hiv_b': 'classification',
'bace_b': 'classification',
@ -689,6 +694,7 @@ class DataInfos(AbstractDatasetInfos):
self.task = task_name
self.task_type = tasktype_dict.get(task_name, "regression")
self.ensure_connected = cfg.model.ensure_connected
self.api = dataset.api
datadir = cfg.dataset.datadir
@ -699,9 +705,9 @@ class DataInfos(AbstractDatasetInfos):
length = 15625
ops_type = {}
len_ops = set()
api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
# api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
for i in range(length):
arch_info = api.query_meta_info_by_index(i)
arch_info = self.api.query_meta_info_by_index(i)
nodes, edges = parse_architecture_string(arch_info.arch_str)
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
if i < 5:
@ -716,7 +722,6 @@ class DataInfos(AbstractDatasetInfos):
graphs.append((adj_matrix, ops))
meta_dict = graphs_to_json(graphs, 'nasbench-201')
self.base_path = base_path
self.active_atoms = meta_dict['active_atoms']
self.max_n_nodes = meta_dict['max_node']
@ -930,4 +935,4 @@ def compute_meta(root, source_name, train_index, test_index):
if __name__ == "__main__":
pass
dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)

0
graph_dit/workingdoc.md Normal file
View File