need to run the jupyternotebook

This commit is contained in:
Hanzhang Ma 2024-06-12 17:56:08 +02:00
parent 99163a5150
commit e04ad5fbe7
2 changed files with 376 additions and 46082 deletions

View File

@ -13,6 +13,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from rdkit import Chem, RDLogger from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdchem
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -24,6 +25,9 @@ import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes from diffusion.distributions import DistributionNodes
import networkx as nx
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
op_to_atom = { op_to_atom = {
@ -111,8 +115,74 @@ class DataModule(AbstractDataModule):
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
return train_index, val_index, test_index, [] return train_index, val_index, test_index, []
def parse_architecture_string(arch_str):
stages = arch_str.split('+')
nodes = ['input']
edges = []
for stage in stages:
operations = stage.strip('|').split('|')
for op in operations:
operation, idx = op.split('~')
idx = int(idx)
edges.append((idx, len(nodes))) # Add edge from idx to the new node
nodes.append(operation)
nodes.append('output') # Add the output node
return nodes, edges
def create_molecule_from_graph(nodes, edges):
mol = Chem.RWMol() # RWMol allows for building the molecule step by step
atom_indices = {}
# Add atoms to the molecule
for i, node in enumerate(nodes):
atom_symbol = op_to_atom[node]
atom = Chem.Atom(atom_symbol)
atom_idx = mol.AddAtom(atom)
atom_indices[i] = atom_idx
# Add bonds to the molecule
for start, end in edges:
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
return mol
def arch_str_to_smiles(self, arch_str):
nodes, edges = self.parse_architecture_string(arch_str)
mol = self.create_molecule_from_graph(nodes, edges)
smiles = Chem.MolToSmiles(mol)
return smiles
def get_train_smiles(self): def get_train_smiles(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.") # raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
# train_arch_strs = []
# test_arch_strs = []
# for idx in self.train_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# train_arch_strs.append(arch_str)
# for idx in self.test_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# test_arch_strs.append(arch_str)
train_smiles = []
test_smiles = []
for idx in self.train_index:
graph = self.train_dataset[idx]
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
train_smiles.append(Chem.MolToSmiles(mol))
for idx in self.test_index:
graph = self.train_dataset[idx]
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
test_smiles.append(Chem.MolToSmiles(mol))
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
# test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
return train_smiles, test_smiles
def get_data_split(self): def get_data_split(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.") raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
@ -543,6 +613,96 @@ def graphs_to_json(graphs, filename):
json.dump(meta_dict, f) json.dump(meta_dict, f)
return meta_dict return meta_dict
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source
self.api = API(source) # Initialize NAS-Bench-201 API
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return [] # NAS-Bench-201 data is loaded via the API, no raw files needed
@property
def processed_file_names(self):
return [f'{self.source}.pt']
def process(self):
def parse_architecture_string(arch_str):
stages = arch_str.split('+')
nodes = ['input']
edges = []
for stage in stages:
operations = stage.strip('|').split('|')
for op in operations:
operation, idx = op.split('~')
idx = int(idx)
edges.append((idx, len(nodes))) # Add edge from idx to the new node
nodes.append(operation)
nodes.append('output') # Add the output node
return nodes, edges
def create_graph(nodes, edges):
G = nx.DiGraph()
for i, node in enumerate(nodes):
G.add_node(i, label=node)
G.add_edges_from(edges)
return G
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
nodes, edges = parse_architecture_string(arch_str)
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
x = torch.LongTensor(node_labels)
edges_list = [(start, end) for start, end in edges]
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
edge_type = torch.tensor(edge_type, dtype=torch.long)
edge_attr = edge_type.view(-1, 1)
if target3 is not None:
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
elif target2 is not None:
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
else:
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
return data, nodes
bonds = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'input': 7,
'output': 5,
'none': 6
}
# Prepare to process NAS-Bench-201 data
data_list = []
len_data = len(self.api) # Number of architectures
with tqdm(total=len_data) as pbar:
for arch_index in range(len_data):
arch_info = self.api.query_meta_info_by_index(arch_index)
arch_str = arch_info.arch_str
sa = np.random.rand() # Placeholder for synthetic accessibility
sc = np.random.rand() # Placeholder for substructure count
target = np.random.rand() # Placeholder for target value
target2 = np.random.rand() # Placeholder for second target value
target3 = np.random.rand() # Placeholder for third target value
data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
data_list.append(data)
pbar.update(1)
torch.save(self.collate(data_list), self.processed_paths[0])
class Dataset_origin(InMemoryDataset): class Dataset_origin(InMemoryDataset):
def __init__(self, source, root, target_prop=None, def __init__(self, source, root, target_prop=None,
transform=None, pre_transform=None, pre_filter=None): transform=None, pre_transform=None, pre_filter=None):
@ -671,7 +831,7 @@ class DataInfos(AbstractDatasetInfos):
length = 15625 length = 15625
ops_type = {} ops_type = {}
len_ops = set() len_ops = set()
api = API('../NAS-Bench-201-v1_0-e61699.pth') api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
for i in range(length): for i in range(length):
arch_info = api.query_meta_info_by_index(i) arch_info = api.query_meta_info_by_index(i)
nodes, edges = parse_architecture_string(arch_info.arch_str) nodes, edges = parse_architecture_string(arch_info.arch_str)

File diff suppressed because one or more lines are too long