add some comments
@@ -124,3 +124,7 @@ class AbstractDatasetInfos:
         self.output_dims = {'X': example_batch_x.size(1),
                             'E': example_batch_edge_attr.size(1),
                             'y': example_batch['y'].size(1)}
+        print('input dims')
+        print(self.input_dims)
+        print('output dims')
+        print(self.output_dims)
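For context, a minimal sketch (not the repository's code; every size below is invented) of how `output_dims` falls out of one example batch:

```python
# Invented example batch mirroring the names used in the hunk above.
import torch

example_batch_x = torch.randn(30, 11)         # 30 nodes, 11 atom-type channels
example_batch_edge_attr = torch.randn(60, 5)  # 60 edge entries, 5 bond-type channels
example_batch = {'y': torch.randn(4, 2)}      # 4 graphs, 2 target properties

output_dims = {'X': example_batch_x.size(1),          # -> 11
               'E': example_batch_edge_attr.size(1),  # -> 5
               'y': example_batch['y'].size(1)}       # -> 2
print('output dims')
print(output_dims)  # {'X': 11, 'E': 5, 'y': 2}
```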
@@ -28,19 +28,38 @@ class DataModule(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
         self.task = cfg.dataset.task_name
+        print("DataModule")
+        print("task", self.task)
+        print("datadir", self.datadir)
         super().__init__(cfg)
 
     def prepare_data(self) -> None:
         target = getattr(self.cfg.dataset, 'guidance_target', None)
+        print("target", target)
         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
 
         batch_size = self.cfg.train.batch_size
+
         num_workers = self.cfg.train.num_workers
         pin_memory = self.cfg.dataset.pin_memory
 
+        # Load the dataset into memory.
+        # Dataset takes the task source, root path, target property, and transform.
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
+        print("len dataset", len(dataset))
+        def print_data(dataset):
+            print("dataset", dataset)
+            print("dataset keys", dataset.keys)
+            print("dataset x", dataset.x)
+            print("dataset edge_index", dataset.edge_index)
+            print("dataset edge_attr", dataset.edge_attr)
+            print("dataset y", dataset.y)
+            print("")
+        print_data(dataset=dataset[0])
+        print_data(dataset=dataset[1])
+
 
         if len(self.task.split('-')) == 2:
             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
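The `print_data` helper added above inspects the standard fields of one graph sample. Assuming the dataset yields `torch_geometric.data.Data` objects (an assumption; the imports are not shown in this diff), a self-contained example of what those fields look like:

```python
# Sketch only: a minimal PyG Data object with the fields print_data reads.
import torch
from torch_geometric.data import Data

# A 3-atom molecule: 2 undirected bonds stored as 4 directed edges.
data = Data(
    x=torch.eye(3),                                         # one-hot atom types
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),  # COO connectivity
    edge_attr=torch.eye(4),                                 # one-hot bond types
    y=torch.tensor([[0.5]]),                                # graph-level target
)
print("dataset", data)  # Data(x=[3, 3], edge_index=[2, 4], ...)
print("dataset x", data.x)
print("dataset edge_index", data.edge_index)
print("dataset edge_attr", data.edge_attr)
print("dataset y", data.y)
```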
@@ -54,7 +73,11 @@ class DataModule(AbstractDataModule):
 
         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
+        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
+        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
+
         self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
 
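A note on the loaders above: if `DataLoader` here is torch_geometric's (an assumption), each batch is itself one large disconnected graph, so `len(loader)` is `ceil(len(dataset) / batch_size)`. A small self-contained sketch with invented graphs:

```python
# Sketch assuming torch_geometric's DataLoader; graph sizes are illustrative.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

graphs = [Data(x=torch.randn(n, 11)) for n in (3, 5, 4)]
loader = DataLoader(graphs, batch_size=2, shuffle=False)

batch = next(iter(loader))
print(batch.num_graphs)  # 2: first two graphs merged into one sparse graph
print(batch.batch)       # tensor([0, 0, 0, 1, 1, 1, 1, 1])  node -> graph map
```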
@@ -253,6 +276,9 @@ class DataInfos(AbstractDatasetInfos):
 
 
 def compute_meta(root, source_name, train_index, test_index):
+    # Initialize the periodic table (118 elements + 1 for *).
+    # Initialize arrays to count the number of atoms per molecule, bond types,
+    # valencies, and transition probabilities between atom types.
     pt = Chem.GetPeriodicTable()
     atom_name_list = []
     atom_count_list = []
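The new comments mention the periodic table ("118 elements + 1 for *"). For reference, a short sketch of the RDKit periodic-table lookups a metadata pass like this typically relies on (the exact calls inside compute_meta are not shown in this hunk):

```python
# RDKit periodic-table lookups; expected outputs noted in comments.
from rdkit import Chem

pt = Chem.GetPeriodicTable()
print(pt.GetElementSymbol(6))     # 'C'  (atomic number -> symbol)
print(pt.GetAtomicNumber('N'))    # 7    (symbol -> atomic number)
print(pt.GetDefaultValence('O'))  # 2    (default valence of oxygen)
```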
@@ -267,11 +293,13 @@ def compute_meta(root, source_name, train_index, test_index):
     valencies = [0] * 500
     tansition_E = np.zeros((118, 118, 5))
 
+    # Load the data from the source file
     filename = f'{source_name}.csv.gz'
     df = pd.read_csv(f'{root}/{filename}')
     all_index = list(range(len(df)))
     non_test_index = list(set(all_index) - set(test_index))
     df = df.iloc[non_test_index]
+    # Extract the SMILES strings from the dataframe
     tot_smiles = df['smiles'].tolist()
 
     n_atom_list = []
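One detail worth noting about the load step above: pandas infers gzip compression from the `.csv.gz` suffix, so no explicit `compression` argument is needed. A self-contained round trip (the file name is a throwaway example):

```python
# Round-trip sketch; 'example.csv.gz' is an illustrative scratch file.
import pandas as pd

pd.DataFrame({'smiles': ['CCO', 'c1ccccc1']}).to_csv('example.csv.gz', index=False)
df = pd.read_csv('example.csv.gz')  # compression inferred from the suffix
print(df['smiles'].tolist())        # ['CCO', 'c1ccccc1']
```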
@@ -323,6 +351,11 @@ def compute_meta(root, source_name, train_index, test_index):
             bond_index = bond_type_to_index[bond_type]
             bond_count_list[bond_index] += 2
 
+            # Update the transition matrix.
+            # The transition matrix is symmetric, so we update both directions.
+            # We also update the temporary transition matrix, which is used to
+            # check for errors in the atom count.
+
             tansition_E[start_index, end_index, bond_index] += 2
             tansition_E[end_index, start_index, bond_index] += 2
             tansition_E_temp[start_index, end_index, bond_index] += 2
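The symmetric update the new comments describe can be checked in isolation. A sketch with made-up indices (the real `bond_type_to_index` mapping is not shown in this hunk); the variable keeps the codebase's `tansition_E` spelling:

```python
# Sketch: symmetric transition-count update; all indices are illustrative.
import numpy as np

tansition_E = np.zeros((118, 118, 5))        # (atom_i, atom_j, bond_type)
start_index, end_index, bond_index = 5, 6, 1  # e.g. a C-N bond, one bond type

tansition_E[start_index, end_index, bond_index] += 2
tansition_E[end_index, start_index, bond_index] += 2

# Each bond-type slice stays symmetric under this update scheme.
assert (tansition_E[..., bond_index] == tansition_E[..., bond_index].T).all()
```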
@@ -76,12 +76,16 @@ class Graph_DiT(pl.LightningModule):
                                                               timesteps=cfg.model.diffusion_steps)
 
 
+        print("__init__")
+        print("dataset_info.node_types", self.dataset_info.node_types)
+        # dataset_info.node_types: tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
 
         e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
         x_marginals = x_marginals / x_marginals.sum()
         e_marginals = e_marginals / e_marginals.sum()
 
+        # transition_E[x1, x2, e]: probability of transitioning from x1 to x2 via edge type e
         xe_conditions = self.dataset_info.transition_E.float()
         xe_conditions = xe_conditions[self.active_index][:, self.active_index]
 
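The marginal computation above amounts to normalizing empirical type counts into a categorical prior. A standalone sketch with invented counts (the commented tensor in the diff is real logged output; these numbers are not):

```python
# Sketch: node-type counts -> normalized marginal distribution.
import torch

node_types = torch.tensor([748.0, 27.0, 94.0, 45.0])  # invented counts
x_marginals = node_types.float() / torch.sum(node_types.float())
x_marginals = x_marginals / x_marginals.sum()  # defensive re-normalization
print(x_marginals)        # tensor([0.8184, 0.0295, 0.1028, 0.0492])
print(x_marginals.sum())  # sums to 1.0
```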
@@ -82,6 +82,7 @@ def main(cfg: DictConfig):
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
     train_smiles, reference_smiles = datamodule.get_train_smiles()
 
+    # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
 
@@ -84,7 +84,7 @@ class BondMetricsCE(MetricCollection):
         ce_TR = TripleCE(3)
         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
 
-
+# 
 class TrainMolecularMetricsDiscrete(nn.Module):
     def __init__(self, dataset_infos):
         super().__init__()
@@ -75,28 +75,55 @@ class Denoiser(nn.Module):
             _constant_init(block.adaLN_modulation[0], 0)
         _constant_init(self.out_layer.adaLN_modulation[0], 0)
 
+    """
+    Input Parameters:
+    x: Node features.
+    e: Edge features.
+    node_mask: Mask indicating valid nodes.
+    y: Condition features.
+    t: Current timestep in the diffusion process.
+    unconditioned: Boolean flag indicating whether to ignore conditions.
+    """
     def forward(self, x, e, node_mask, y, t, unconditioned):
 
+        print("Denoiser Forward")
+        print(x.shape, e.shape, y.shape, t.shape, unconditioned)
         force_drop_id = torch.zeros_like(y.sum(-1))
+        # drop conditions whose y values are NaN
         force_drop_id[torch.isnan(y.sum(-1))] = 1
         if unconditioned:
             force_drop_id = torch.ones_like(y[:, 0])
 
         x_in, e_in, y_in = x, e, y
+        # bs = batch size, n = number of nodes
         bs, n, _ = x.size()
         x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
+        print("X after concat with E")
+        print(x.shape)
+        # self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
         x = self.x_embedder(x)
+        print("X after x_embedder")
+        print(x.shape)
 
+        # self.t_embedder = TimestepEmbedder(hidden_size)
         c1 = self.t_embedder(t)
+        print("C1 after t_embedder")
+        print(c1.shape)
         for i in range(1, self.ydim):
             if i == 1:
                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
             else:
                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
+        print("C2 after y_embedding_list")
+        print(c2.shape)
+        print("C1 + C2")
         c = c1 + c2
+        print(c.shape)
 
         for i, block in enumerate(self.encoders):
             x = block(x, c, node_mask)
+        print("X after block")
+        print(x.shape)
 
         # X: B * N * dx, E: B * N * N * de
         X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
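The shape flow the new prints trace can be reproduced standalone. A sketch with invented sizes, mirroring the node/edge concat and the commented x_embedder shape (Xdim + max_n_nodes * Edim -> hidden_size):

```python
# Sketch of the Denoiser's node/edge concat; all sizes are invented.
import torch
import torch.nn as nn

bs, n, dx, de, hidden = 2, 8, 11, 5, 64
x = torch.randn(bs, n, dx)     # node features
e = torch.randn(bs, n, n, de)  # edge features

# Each node gets its full row of edge features appended.
x_cat = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
print(x_cat.shape)  # torch.Size([2, 8, 51]) = dx + n * de

x_embedder = nn.Linear(dx + n * de, hidden, bias=False)
print(x_embedder(x_cat).shape)  # torch.Size([2, 8, 64])
```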