add get_train_graphs

This commit is contained in:
mhz 2024-06-26 22:42:06 +02:00
parent 062a27b83f
commit a222c514d9

View File

@ -69,6 +69,7 @@ class DataModule(AbstractDataModule):
source = './NAS-Bench-201-v1_1-096897.pth'
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
self.dataset = dataset
self.api = dataset.api
# if len(self.task.split('-')) == 2:
# train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
@ -177,6 +178,27 @@ class DataModule(AbstractDataModule):
smiles = Chem.MolToSmiles(mol)
return smiles
def get_train_graphs(self):
train_graphs = []
test_graphs = []
for graph in self.train_dataset:
train_graphs.append(graph)
for graph in self.test_dataset:
test_graphs.append(graph)
return train_graphs, test_graphs
# def get_train_smiles(self):
# filename = f'{self.task}.csv.gz'
# df = pd.read_csv(f'{self.root_path}/raw/{filename}')
# df_test = df.iloc[self.test_index]
# df = df.iloc[self.train_index]
# smiles_list = df['smiles'].tolist()
# smiles_list_test = df_test['smiles'].tolist()
# smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
# smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
# return smiles_list, smiles_list_test
def get_train_smiles(self):
train_smiles = []
test_smiles = []