need to run the jupyternotebook
This commit is contained in:
parent
99163a5150
commit
e04ad5fbe7
@ -13,6 +13,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from rdkit import Chem, RDLogger
|
from rdkit import Chem, RDLogger
|
||||||
from rdkit.Chem.rdchem import BondType as BT
|
from rdkit.Chem.rdchem import BondType as BT
|
||||||
|
from rdkit.Chem import rdchem
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -24,6 +25,9 @@ import utils as utils
|
|||||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||||
from diffusion.distributions import DistributionNodes
|
from diffusion.distributions import DistributionNodes
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
|
||||||
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||||
|
|
||||||
op_to_atom = {
|
op_to_atom = {
|
||||||
@ -111,8 +115,74 @@ class DataModule(AbstractDataModule):
|
|||||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||||
return train_index, val_index, test_index, []
|
return train_index, val_index, test_index, []
|
||||||
|
|
||||||
|
def parse_architecture_string(arch_str):
|
||||||
|
stages = arch_str.split('+')
|
||||||
|
nodes = ['input']
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
for stage in stages:
|
||||||
|
operations = stage.strip('|').split('|')
|
||||||
|
for op in operations:
|
||||||
|
operation, idx = op.split('~')
|
||||||
|
idx = int(idx)
|
||||||
|
edges.append((idx, len(nodes))) # Add edge from idx to the new node
|
||||||
|
nodes.append(operation)
|
||||||
|
nodes.append('output') # Add the output node
|
||||||
|
return nodes, edges
|
||||||
|
|
||||||
|
def create_molecule_from_graph(nodes, edges):
|
||||||
|
mol = Chem.RWMol() # RWMol allows for building the molecule step by step
|
||||||
|
atom_indices = {}
|
||||||
|
|
||||||
|
# Add atoms to the molecule
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
atom_symbol = op_to_atom[node]
|
||||||
|
atom = Chem.Atom(atom_symbol)
|
||||||
|
atom_idx = mol.AddAtom(atom)
|
||||||
|
atom_indices[i] = atom_idx
|
||||||
|
|
||||||
|
# Add bonds to the molecule
|
||||||
|
for start, end in edges:
|
||||||
|
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
|
||||||
|
|
||||||
|
return mol
|
||||||
|
|
||||||
|
def arch_str_to_smiles(self, arch_str):
|
||||||
|
nodes, edges = self.parse_architecture_string(arch_str)
|
||||||
|
mol = self.create_molecule_from_graph(nodes, edges)
|
||||||
|
smiles = Chem.MolToSmiles(mol)
|
||||||
|
return smiles
|
||||||
|
|
||||||
def get_train_smiles(self):
|
def get_train_smiles(self):
|
||||||
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
# raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||||
|
# train_arch_strs = []
|
||||||
|
# test_arch_strs = []
|
||||||
|
|
||||||
|
# for idx in self.train_index:
|
||||||
|
# arch_info = self.train_dataset[idx]
|
||||||
|
# arch_str = arch_info.arch_str
|
||||||
|
# train_arch_strs.append(arch_str)
|
||||||
|
# for idx in self.test_index:
|
||||||
|
# arch_info = self.train_dataset[idx]
|
||||||
|
# arch_str = arch_info.arch_str
|
||||||
|
# test_arch_strs.append(arch_str)
|
||||||
|
|
||||||
|
train_smiles = []
|
||||||
|
test_smiles = []
|
||||||
|
|
||||||
|
for idx in self.train_index:
|
||||||
|
graph = self.train_dataset[idx]
|
||||||
|
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||||
|
train_smiles.append(Chem.MolToSmiles(mol))
|
||||||
|
|
||||||
|
for idx in self.test_index:
|
||||||
|
graph = self.train_dataset[idx]
|
||||||
|
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||||
|
test_smiles.append(Chem.MolToSmiles(mol))
|
||||||
|
|
||||||
|
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
|
||||||
|
# test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
|
||||||
|
return train_smiles, test_smiles
|
||||||
|
|
||||||
def get_data_split(self):
|
def get_data_split(self):
|
||||||
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||||
@ -543,6 +613,96 @@ def graphs_to_json(graphs, filename):
|
|||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
return meta_dict
|
return meta_dict
|
||||||
|
|
||||||
|
class Dataset(InMemoryDataset):
|
||||||
|
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
|
||||||
|
self.target_prop = target_prop
|
||||||
|
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
self.source = source
|
||||||
|
self.api = API(source) # Initialize NAS-Bench-201 API
|
||||||
|
super().__init__(root, transform, pre_transform, pre_filter)
|
||||||
|
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_file_names(self):
|
||||||
|
return [] # NAS-Bench-201 data is loaded via the API, no raw files needed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processed_file_names(self):
|
||||||
|
return [f'{self.source}.pt']
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
def parse_architecture_string(arch_str):
|
||||||
|
stages = arch_str.split('+')
|
||||||
|
nodes = ['input']
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
for stage in stages:
|
||||||
|
operations = stage.strip('|').split('|')
|
||||||
|
for op in operations:
|
||||||
|
operation, idx = op.split('~')
|
||||||
|
idx = int(idx)
|
||||||
|
edges.append((idx, len(nodes))) # Add edge from idx to the new node
|
||||||
|
nodes.append(operation)
|
||||||
|
nodes.append('output') # Add the output node
|
||||||
|
return nodes, edges
|
||||||
|
|
||||||
|
def create_graph(nodes, edges):
|
||||||
|
G = nx.DiGraph()
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
G.add_node(i, label=node)
|
||||||
|
G.add_edges_from(edges)
|
||||||
|
return G
|
||||||
|
|
||||||
|
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
|
||||||
|
nodes, edges = parse_architecture_string(arch_str)
|
||||||
|
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
|
||||||
|
x = torch.LongTensor(node_labels)
|
||||||
|
|
||||||
|
edges_list = [(start, end) for start, end in edges]
|
||||||
|
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
|
||||||
|
edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
|
||||||
|
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||||
|
edge_attr = edge_type.view(-1, 1)
|
||||||
|
|
||||||
|
if target3 is not None:
|
||||||
|
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
|
||||||
|
elif target2 is not None:
|
||||||
|
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
|
||||||
|
else:
|
||||||
|
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
|
||||||
|
|
||||||
|
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
|
||||||
|
return data, nodes
|
||||||
|
|
||||||
|
bonds = {
|
||||||
|
'nor_conv_1x1': 1,
|
||||||
|
'nor_conv_3x3': 2,
|
||||||
|
'avg_pool_3x3': 3,
|
||||||
|
'skip_connect': 4,
|
||||||
|
'input': 7,
|
||||||
|
'output': 5,
|
||||||
|
'none': 6
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prepare to process NAS-Bench-201 data
|
||||||
|
data_list = []
|
||||||
|
len_data = len(self.api) # Number of architectures
|
||||||
|
with tqdm(total=len_data) as pbar:
|
||||||
|
for arch_index in range(len_data):
|
||||||
|
arch_info = self.api.query_meta_info_by_index(arch_index)
|
||||||
|
arch_str = arch_info.arch_str
|
||||||
|
sa = np.random.rand() # Placeholder for synthetic accessibility
|
||||||
|
sc = np.random.rand() # Placeholder for substructure count
|
||||||
|
target = np.random.rand() # Placeholder for target value
|
||||||
|
target2 = np.random.rand() # Placeholder for second target value
|
||||||
|
target3 = np.random.rand() # Placeholder for third target value
|
||||||
|
|
||||||
|
data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
|
||||||
|
data_list.append(data)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
torch.save(self.collate(data_list), self.processed_paths[0])
|
||||||
|
|
||||||
class Dataset_origin(InMemoryDataset):
|
class Dataset_origin(InMemoryDataset):
|
||||||
def __init__(self, source, root, target_prop=None,
|
def __init__(self, source, root, target_prop=None,
|
||||||
transform=None, pre_transform=None, pre_filter=None):
|
transform=None, pre_transform=None, pre_filter=None):
|
||||||
@ -671,7 +831,7 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
length = 15625
|
length = 15625
|
||||||
ops_type = {}
|
ops_type = {}
|
||||||
len_ops = set()
|
len_ops = set()
|
||||||
api = API('../NAS-Bench-201-v1_0-e61699.pth')
|
api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
arch_info = api.query_meta_info_by_index(i)
|
arch_info = api.query_meta_info_by_index(i)
|
||||||
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user