add get_train_graphs
This commit is contained in:
parent
062a27b83f
commit
a222c514d9
@ -69,6 +69,7 @@ class DataModule(AbstractDataModule):
|
||||
source = './NAS-Bench-201-v1_1-096897.pth'
|
||||
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
|
||||
self.dataset = dataset
|
||||
self.api = dataset.api
|
||||
|
||||
# if len(self.task.split('-')) == 2:
|
||||
# train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
|
||||
@ -177,6 +178,27 @@ class DataModule(AbstractDataModule):
|
||||
smiles = Chem.MolToSmiles(mol)
|
||||
return smiles
|
||||
|
||||
def get_train_graphs(self):
|
||||
train_graphs = []
|
||||
test_graphs = []
|
||||
for graph in self.train_dataset:
|
||||
train_graphs.append(graph)
|
||||
for graph in self.test_dataset:
|
||||
test_graphs.append(graph)
|
||||
return train_graphs, test_graphs
|
||||
|
||||
|
||||
# def get_train_smiles(self):
|
||||
# filename = f'{self.task}.csv.gz'
|
||||
# df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
||||
# df_test = df.iloc[self.test_index]
|
||||
# df = df.iloc[self.train_index]
|
||||
# smiles_list = df['smiles'].tolist()
|
||||
# smiles_list_test = df_test['smiles'].tolist()
|
||||
# smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
|
||||
# smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
|
||||
# return smiles_list, smiles_list_test
|
||||
|
||||
def get_train_smiles(self):
|
||||
train_smiles = []
|
||||
test_smiles = []
|
||||
|
Loading…
Reference in New Issue
Block a user