import os

from omegaconf import OmegaConf, open_dict
import torch
import torch_geometric.utils
from torch_geometric.utils import to_dense_adj, to_dense_batch


def create_folders(args):
    """Create the output directories ``graphs/<name>`` and ``chains/<name>``.

    ``<name>`` is ``args.general.name``. Parent directories are created as
    needed and directories that already exist are left untouched.
    """
    # Each directory is created independently. Previously one try/except
    # wrapped two makedirs calls, so an already-existing 'graphs' directory
    # raised OSError and silently skipped creating 'chains' as well.
    for root in ('graphs', 'chains'):
        os.makedirs(root + '/' + args.general.name, exist_ok=True)


def _diagonal_mask(E):
    """Return a boolean (E.shape[0], n, n) mask that is True where dim 1 == dim 2 of E."""
    n = E.shape[1]
    # Built on E's device so boolean indexing never mixes CPU and GPU tensors.
    eye = torch.eye(n, dtype=torch.bool, device=E.device)
    return eye.unsqueeze(0).expand(E.shape[0], -1, -1)


def normalize(X, E, y, norm_values, norm_biases, node_mask):
    """Shift and scale node, edge and global features, then apply the node mask.

    Each feature tensor is transformed as ``(value - bias) / scale``.

    Args:
        X: node features.
        E: edge features; its first three dims are indexed by a
            (batch, n, n) diagonal mask.
        y: global features.
        norm_values: scales, ordered [X, E, y].
        norm_biases: biases, same order as ``norm_values``.
        node_mask: mask forwarded to ``PlaceHolder.mask``.

    Returns:
        A masked PlaceHolder holding the normalized X, E and y. The diagonal
        (self-loop) entries of E are set to 0.
    """
    X = (X - norm_biases[0]) / norm_values[0]
    E = (E - norm_biases[1]) / norm_values[1]
    y = (y - norm_biases[2]) / norm_values[2]

    # E is a fresh tensor here, so zeroing in place does not touch the caller's E.
    E[_diagonal_mask(E)] = 0

    return PlaceHolder(X=X, E=E, y=y).mask(node_mask)


def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """Undo the affine scaling of ``normalize``: ``value * scale + bias``.

    Args:
        X: node features.
        E: edge features.
        y: global features.
        norm_values: scales, ordered [norm value X, norm value E, norm value y].
        norm_biases: biases, same order as ``norm_values``.
        node_mask: mask forwarded to ``PlaceHolder.mask``.
        collapse: forwarded to ``PlaceHolder.mask``. When True, X and E are
            replaced by their argmax over the last dim, with -1 at masked
            positions.

    Returns:
        A masked PlaceHolder holding the unnormalized X, E and y.
    """
    X = X * norm_values[0] + norm_biases[0]
    E = E * norm_values[1] + norm_biases[1]
    y = y * norm_values[2] + norm_biases[2]

    return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse)


def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
    """Convert a sparse PyG batch into dense node and edge tensors.

    Self-loops are removed before densifying, and edge slots that hold no
    edge are one-hot encoded in channel 0 (see ``encode_no_edge``).

    Args:
        x: node feature matrix of the whole batch.
        edge_index: edge indices of the whole batch.
        edge_attr: edge features aligned with ``edge_index``.
        batch: graph assignment vector for each node.
        max_num_nodes: padded graph size. If None, it is taken from the
            dense node tensor produced by ``to_dense_batch``.

    Returns:
        ``(PlaceHolder(X, E, y=None), node_mask)``, where ``node_mask`` is the
        mask returned by ``to_dense_batch``.
    """
    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    if max_num_nodes is None:
        max_num_nodes = X.size(1)
    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr,
                     max_num_nodes=max_num_nodes)
    E = encode_no_edge(E)

    return PlaceHolder(X=X, E=E, y=None), node_mask


def encode_no_edge(E):
    """Mark absent edges in channel 0 of a dense 4-D edge tensor.

    Every (b, i, j) slot whose features sum to 0 gets ``E[b, i, j, 0] = 1``.
    Afterwards all diagonal (i == j) slots are zeroed. ``E`` is modified in
    place and also returned. A tensor with an empty last dimension is returned
    unchanged.
    """
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E

    no_edge = torch.sum(E, dim=3) == 0
    # Basic indexing returns a view, so this assignment writes straight into E.
    first_elt = E[:, :, :, 0]
    first_elt[no_edge] = 1

    # Diagonal slots would otherwise have been flagged as "no edge" above.
    E[_diagonal_mask(E)] = 0
    return E


def update_config_with_new_keys(cfg, saved_cfg):
    """Copy keys present in ``saved_cfg`` but missing from ``cfg`` into ``cfg``.

    Only the ``dataset``, ``general``, ``train`` and ``model`` sections are
    merged, and keys already present in ``cfg`` are never overwritten. Each of
    these sections of ``cfg`` is put in struct mode. New keys are added inside
    ``open_dict`` so that struct mode does not reject them.

    Returns:
        ``cfg``, modified in place.
    """
    # Read all saved sections up front, so a missing section fails before
    # anything in cfg has been modified.
    saved_sections = {
        name: getattr(saved_cfg, name)
        for name in ('general', 'train', 'model', 'dataset')
    }

    for name in ('dataset', 'general', 'train', 'model'):
        target = getattr(cfg, name)
        OmegaConf.set_struct(target, True)
        with open_dict(target):
            for key, val in saved_sections[name].items():
                # Compare against keys() rather than `key in target`:
                # DictConfig.__contains__ reports mandatory-missing ('???')
                # values as absent, which would overwrite them here.
                if key not in target.keys():
                    setattr(target, key, val)
    return cfg


class PlaceHolder:
    """Bundle of node (X), edge (E) and global (y) features for a batch of graphs."""

    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor, categorical: bool = False):
        """Cast X and E to the dtype and device of ``x``, and y only when ``categorical`` is True.

        Returns ``self`` so calls can be chained.
        """
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        # NOTE(review): y is left untouched unless categorical=True; presumably
        # y may be None or must keep its dtype otherwise. Confirm with callers.
        if categorical:
            self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        """Apply ``node_mask`` to X and E, replacing the attributes, and return ``self``.

        Args:
            node_mask: (bs, n) mask whose zero/False entries mark padding nodes.
            collapse: if True, replace X and E by the argmax over their last
                dim and write -1 at masked positions. Otherwise multiply X and
                E by the mask so padded entries become 0, then assert that E is
                symmetric in its two node dimensions.

        y is never modified.
        """
        x_mask = node_mask.unsqueeze(-1)    # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)       # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)       # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            # An edge slot is masked if either endpoint is a padding node.
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self