need to run the jupyternotebook

This commit is contained in:
Hanzhang Ma 2024-06-12 17:56:08 +02:00
parent 99163a5150
commit e04ad5fbe7
2 changed files with 376 additions and 46082 deletions

View File

@ -13,6 +13,7 @@ import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdchem
from tqdm import tqdm
import numpy as np
import pandas as pd
@ -24,6 +25,9 @@ import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
import networkx as nx
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
op_to_atom = {
@ -111,8 +115,74 @@ class DataModule(AbstractDataModule):
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
return train_index, val_index, test_index, []
def parse_architecture_string(arch_str):
stages = arch_str.split('+')
nodes = ['input']
edges = []
for stage in stages:
operations = stage.strip('|').split('|')
for op in operations:
operation, idx = op.split('~')
idx = int(idx)
edges.append((idx, len(nodes))) # Add edge from idx to the new node
nodes.append(operation)
nodes.append('output') # Add the output node
return nodes, edges
def create_molecule_from_graph(nodes, edges):
mol = Chem.RWMol() # RWMol allows for building the molecule step by step
atom_indices = {}
# Add atoms to the molecule
for i, node in enumerate(nodes):
atom_symbol = op_to_atom[node]
atom = Chem.Atom(atom_symbol)
atom_idx = mol.AddAtom(atom)
atom_indices[i] = atom_idx
# Add bonds to the molecule
for start, end in edges:
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
return mol
def arch_str_to_smiles(self, arch_str):
nodes, edges = self.parse_architecture_string(arch_str)
mol = self.create_molecule_from_graph(nodes, edges)
smiles = Chem.MolToSmiles(mol)
return smiles
def get_train_smiles(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
# raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
# train_arch_strs = []
# test_arch_strs = []
# for idx in self.train_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# train_arch_strs.append(arch_str)
# for idx in self.test_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# test_arch_strs.append(arch_str)
train_smiles = []
test_smiles = []
for idx in self.train_index:
graph = self.train_dataset[idx]
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
train_smiles.append(Chem.MolToSmiles(mol))
for idx in self.test_index:
graph = self.train_dataset[idx]
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
test_smiles.append(Chem.MolToSmiles(mol))
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
# test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
return train_smiles, test_smiles
def get_data_split(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
@ -543,6 +613,96 @@ def graphs_to_json(graphs, filename):
json.dump(meta_dict, f)
return meta_dict
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source
self.api = API(source) # Initialize NAS-Bench-201 API
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return [] # NAS-Bench-201 data is loaded via the API, no raw files needed
@property
def processed_file_names(self):
return [f'{self.source}.pt']
def process(self):
def parse_architecture_string(arch_str):
stages = arch_str.split('+')
nodes = ['input']
edges = []
for stage in stages:
operations = stage.strip('|').split('|')
for op in operations:
operation, idx = op.split('~')
idx = int(idx)
edges.append((idx, len(nodes))) # Add edge from idx to the new node
nodes.append(operation)
nodes.append('output') # Add the output node
return nodes, edges
def create_graph(nodes, edges):
G = nx.DiGraph()
for i, node in enumerate(nodes):
G.add_node(i, label=node)
G.add_edges_from(edges)
return G
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
nodes, edges = parse_architecture_string(arch_str)
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
x = torch.LongTensor(node_labels)
edges_list = [(start, end) for start, end in edges]
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
edge_type = torch.tensor(edge_type, dtype=torch.long)
edge_attr = edge_type.view(-1, 1)
if target3 is not None:
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
elif target2 is not None:
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
else:
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
return data, nodes
bonds = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'input': 7,
'output': 5,
'none': 6
}
# Prepare to process NAS-Bench-201 data
data_list = []
len_data = len(self.api) # Number of architectures
with tqdm(total=len_data) as pbar:
for arch_index in range(len_data):
arch_info = self.api.query_meta_info_by_index(arch_index)
arch_str = arch_info.arch_str
sa = np.random.rand() # Placeholder for synthetic accessibility
sc = np.random.rand() # Placeholder for substructure count
target = np.random.rand() # Placeholder for target value
target2 = np.random.rand() # Placeholder for second target value
target3 = np.random.rand() # Placeholder for third target value
data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
data_list.append(data)
pbar.update(1)
torch.save(self.collate(data_list), self.processed_paths[0])
class Dataset_origin(InMemoryDataset):
def __init__(self, source, root, target_prop=None,
transform=None, pre_transform=None, pre_filter=None):
@ -671,7 +831,7 @@ class DataInfos(AbstractDatasetInfos):
length = 15625
ops_type = {}
len_ops = set()
api = API('../NAS-Bench-201-v1_0-e61699.pth')
api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
for i in range(length):
arch_info = api.query_meta_info_by_index(i)
nodes, edges = parse_architecture_string(arch_info.arch_str)

File diff suppressed because one or more lines are too long