need to run the jupyternotebook
This commit is contained in:
parent
99163a5150
commit
e04ad5fbe7
@ -13,6 +13,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from rdkit import Chem, RDLogger
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from rdkit.Chem import rdchem
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -24,6 +25,9 @@ import utils as utils
|
||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||
from diffusion.distributions import DistributionNodes
|
||||
|
||||
import networkx as nx
|
||||
|
||||
|
||||
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
|
||||
op_to_atom = {
|
||||
@ -111,8 +115,74 @@ class DataModule(AbstractDataModule):
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
return train_index, val_index, test_index, []
|
||||
|
||||
def parse_architecture_string(arch_str):
|
||||
stages = arch_str.split('+')
|
||||
nodes = ['input']
|
||||
edges = []
|
||||
|
||||
for stage in stages:
|
||||
operations = stage.strip('|').split('|')
|
||||
for op in operations:
|
||||
operation, idx = op.split('~')
|
||||
idx = int(idx)
|
||||
edges.append((idx, len(nodes))) # Add edge from idx to the new node
|
||||
nodes.append(operation)
|
||||
nodes.append('output') # Add the output node
|
||||
return nodes, edges
|
||||
|
||||
def create_molecule_from_graph(nodes, edges):
|
||||
mol = Chem.RWMol() # RWMol allows for building the molecule step by step
|
||||
atom_indices = {}
|
||||
|
||||
# Add atoms to the molecule
|
||||
for i, node in enumerate(nodes):
|
||||
atom_symbol = op_to_atom[node]
|
||||
atom = Chem.Atom(atom_symbol)
|
||||
atom_idx = mol.AddAtom(atom)
|
||||
atom_indices[i] = atom_idx
|
||||
|
||||
# Add bonds to the molecule
|
||||
for start, end in edges:
|
||||
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
|
||||
|
||||
return mol
|
||||
|
||||
def arch_str_to_smiles(self, arch_str):
|
||||
nodes, edges = self.parse_architecture_string(arch_str)
|
||||
mol = self.create_molecule_from_graph(nodes, edges)
|
||||
smiles = Chem.MolToSmiles(mol)
|
||||
return smiles
|
||||
|
||||
def get_train_smiles(self):
|
||||
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||
# raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||
# train_arch_strs = []
|
||||
# test_arch_strs = []
|
||||
|
||||
# for idx in self.train_index:
|
||||
# arch_info = self.train_dataset[idx]
|
||||
# arch_str = arch_info.arch_str
|
||||
# train_arch_strs.append(arch_str)
|
||||
# for idx in self.test_index:
|
||||
# arch_info = self.train_dataset[idx]
|
||||
# arch_str = arch_info.arch_str
|
||||
# test_arch_strs.append(arch_str)
|
||||
|
||||
train_smiles = []
|
||||
test_smiles = []
|
||||
|
||||
for idx in self.train_index:
|
||||
graph = self.train_dataset[idx]
|
||||
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||
train_smiles.append(Chem.MolToSmiles(mol))
|
||||
|
||||
for idx in self.test_index:
|
||||
graph = self.train_dataset[idx]
|
||||
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||
test_smiles.append(Chem.MolToSmiles(mol))
|
||||
|
||||
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
|
||||
# test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
|
||||
return train_smiles, test_smiles
|
||||
|
||||
def get_data_split(self):
|
||||
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||
@ -543,6 +613,96 @@ def graphs_to_json(graphs, filename):
|
||||
json.dump(meta_dict, f)
|
||||
return meta_dict
|
||||
|
||||
class Dataset(InMemoryDataset):
|
||||
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
|
||||
self.target_prop = target_prop
|
||||
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||
self.source = source
|
||||
self.api = API(source) # Initialize NAS-Bench-201 API
|
||||
super().__init__(root, transform, pre_transform, pre_filter)
|
||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
return [] # NAS-Bench-201 data is loaded via the API, no raw files needed
|
||||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return [f'{self.source}.pt']
|
||||
|
||||
def process(self):
|
||||
def parse_architecture_string(arch_str):
|
||||
stages = arch_str.split('+')
|
||||
nodes = ['input']
|
||||
edges = []
|
||||
|
||||
for stage in stages:
|
||||
operations = stage.strip('|').split('|')
|
||||
for op in operations:
|
||||
operation, idx = op.split('~')
|
||||
idx = int(idx)
|
||||
edges.append((idx, len(nodes))) # Add edge from idx to the new node
|
||||
nodes.append(operation)
|
||||
nodes.append('output') # Add the output node
|
||||
return nodes, edges
|
||||
|
||||
def create_graph(nodes, edges):
|
||||
G = nx.DiGraph()
|
||||
for i, node in enumerate(nodes):
|
||||
G.add_node(i, label=node)
|
||||
G.add_edges_from(edges)
|
||||
return G
|
||||
|
||||
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
|
||||
nodes, edges = parse_architecture_string(arch_str)
|
||||
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
|
||||
x = torch.LongTensor(node_labels)
|
||||
|
||||
edges_list = [(start, end) for start, end in edges]
|
||||
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
|
||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
|
||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||
edge_attr = edge_type.view(-1, 1)
|
||||
|
||||
if target3 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
|
||||
elif target2 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
|
||||
else:
|
||||
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
|
||||
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
|
||||
return data, nodes
|
||||
|
||||
bonds = {
|
||||
'nor_conv_1x1': 1,
|
||||
'nor_conv_3x3': 2,
|
||||
'avg_pool_3x3': 3,
|
||||
'skip_connect': 4,
|
||||
'input': 7,
|
||||
'output': 5,
|
||||
'none': 6
|
||||
}
|
||||
|
||||
# Prepare to process NAS-Bench-201 data
|
||||
data_list = []
|
||||
len_data = len(self.api) # Number of architectures
|
||||
with tqdm(total=len_data) as pbar:
|
||||
for arch_index in range(len_data):
|
||||
arch_info = self.api.query_meta_info_by_index(arch_index)
|
||||
arch_str = arch_info.arch_str
|
||||
sa = np.random.rand() # Placeholder for synthetic accessibility
|
||||
sc = np.random.rand() # Placeholder for substructure count
|
||||
target = np.random.rand() # Placeholder for target value
|
||||
target2 = np.random.rand() # Placeholder for second target value
|
||||
target3 = np.random.rand() # Placeholder for third target value
|
||||
|
||||
data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
|
||||
data_list.append(data)
|
||||
pbar.update(1)
|
||||
|
||||
torch.save(self.collate(data_list), self.processed_paths[0])
|
||||
|
||||
class Dataset_origin(InMemoryDataset):
|
||||
def __init__(self, source, root, target_prop=None,
|
||||
transform=None, pre_transform=None, pre_filter=None):
|
||||
@ -671,7 +831,7 @@ class DataInfos(AbstractDatasetInfos):
|
||||
length = 15625
|
||||
ops_type = {}
|
||||
len_ops = set()
|
||||
api = API('../NAS-Bench-201-v1_0-e61699.pth')
|
||||
api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
||||
for i in range(length):
|
||||
arch_info = api.query_meta_info_by_index(i)
|
||||
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user