# Source code for PlanetAlign.utils.anchors

from typing import List, Tuple, Union
from collections import defaultdict
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import degree, to_dense_adj
import torch
import torch.nn.functional as F

from PlanetAlign.data import Dataset


def get_anchor_pairs(anchor_links: torch.Tensor, gid1: int, gid2: int):
    """Get anchor node pairs between two graphs identified by gid1 and gid2.

    Parameters
    ----------
    anchor_links : torch.Tensor
        A tensor of shape (num_anchors, num_graphs) containing anchor node
        indices for each graph; -1 marks an anchor absent from a graph.
    gid1 : int
        The graph ID of the first graph (column index into ``anchor_links``).
    gid2 : int
        The graph ID of the second graph (column index into ``anchor_links``).

    Returns
    -------
    anchor_pairs : torch.Tensor
        A tensor of shape (num_anchor_pairs, 2) containing the anchor node
        pairs present in both graphs.
    """
    # Select the two columns of interest, then keep only rows where the
    # anchor exists in both graphs (neither entry is the -1 sentinel).
    candidates = anchor_links[:, [gid1, gid2]]
    valid_mask = (candidates != -1).all(dim=1)
    return candidates[valid_mask]
def get_pairwise_anchor_pairs(anchor_links: torch.Tensor) -> dict:
    """Get anchor node pairs for all pairs of graphs.

    Parameters
    ----------
    anchor_links : torch.Tensor
        A tensor of shape (num_anchors, num_graphs) containing anchor node
        indices for each graph; -1 marks an anchor absent from a graph.

    Returns
    -------
    anchor_pairs_dict : dict
        A dictionary where keys are tuples of graph IDs (gid1, gid2) with
        gid1 < gid2, and values are tensors containing the anchor node pairs
        between the two graphs.
    """
    num_graphs = anchor_links.shape[1]
    # Enumerate each unordered pair of graphs exactly once (gid1 < gid2).
    return {
        (gid1, gid2): get_anchor_pairs(anchor_links, gid1, gid2)
        for gid1 in range(num_graphs)
        for gid2 in range(gid1 + 1, num_graphs)
    }
def get_anchor_embeddings(graph: Data, anchors: torch.Tensor) -> torch.Tensor:
    """Generate one-hot anchor embeddings for nodes in a graph based on given anchor nodes.

    Parameters
    ----------
    graph : torch_geometric.data.Data
        The input graph.
    anchors : torch.Tensor
        A tensor containing the indices of anchor nodes in the graph.

    Returns
    -------
    anchor_embeddings : torch.Tensor
        A tensor of shape (num_nodes, num_anchors) representing the one-hot
        anchor embeddings: column j is 1 exactly at row ``anchors[j]``.
    """
    n_anchors = anchors.shape[0]
    # Allocate on the same device as the anchor indices.
    onehot = torch.zeros(graph.num_nodes, n_anchors, dtype=torch.float32).to(anchors.device)
    # Set entry (anchors[j], j) to 1 for every anchor column j.
    onehot[anchors, torch.arange(n_anchors)] = 1
    return onehot
def get_rwr_embeddings(graph: Data,
                       anchors: torch.Tensor,
                       restart_prob: float = 0.15,
                       max_iters: int = 1000,
                       tol: float = 1e-6,
                       connect_isolated: bool = False,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Compute Random Walk with Restart (RWR) embeddings for nodes in a graph
    based on given anchor nodes.

    Parameters
    ----------
    graph : torch_geometric.data.Data
        The input graph.
    anchors : torch.Tensor
        A tensor containing the indices of anchor nodes in the graph.
    restart_prob : float, optional
        The probability of restarting the random walk at each step. Default is 0.15.
    max_iters : int, optional
        The maximum number of iterations for RWR. Default is 1000.
    tol : float, optional
        The tolerance for convergence. Default is 1e-6.
    connect_isolated : bool, optional
        Whether to connect isolated nodes to all other nodes. Default is False.
    dtype : torch.dtype, optional
        The data type for computations. Default is torch.float32.

    Returns
    -------
    rwr_embeddings : torch.Tensor
        A tensor of shape (num_nodes, num_anchors) representing the RWR embeddings.
    """
    device = anchors.device
    # One-hot restart distribution: column j restarts at anchors[j].
    batch_landmark_vecs = torch.zeros(graph.num_nodes, len(anchors)).to(dtype).to(device)
    batch_landmark_vecs[anchors, torch.arange(len(anchors))] = 1
    batch_rwr_vecs = torch.ones(graph.num_nodes, len(anchors)).to(dtype).to(device)
    # BUGFIX: squeeze(0) instead of squeeze() — to_dense_adj returns a
    # (1, N, N) batch; a bare squeeze() would also drop the node dims for a
    # single-node graph, collapsing the adjacency to 0-D.
    adj = to_dense_adj(graph.edge_index, max_num_nodes=graph.num_nodes).squeeze(0).to(dtype)
    if connect_isolated:
        # Rows with zero out-degree get uniform connections to all nodes so
        # the row-normalization below is well-defined for them.
        adj[torch.where(~adj.sum(1).bool())] = torch.ones(graph.num_nodes, dtype=dtype)
    # Row-normalize to transition probabilities, then transpose so that
    # trans_mat @ v propagates mass along incoming edges.
    trans_mat = F.normalize(adj.to(device), p=1, dim=1).T
    for _ in range(max_iters):
        batch_rwr_vecs_old = torch.clone(batch_rwr_vecs)
        # Power-iteration update: walk with prob (1 - restart_prob), restart otherwise.
        batch_rwr_vecs = (1 - restart_prob) * trans_mat @ batch_rwr_vecs + restart_prob * batch_landmark_vecs
        diff = torch.max(torch.abs(batch_rwr_vecs - batch_rwr_vecs_old))
        if diff.item() < tol:
            break
    return batch_rwr_vecs
def merge_pyg_graphs_on_anchors(pyg_graphs: Union[List[Data], Tuple[Data, ...]],
                                anchor_links: torch.Tensor) -> Tuple[Data, torch.Tensor, dict, dict]:
    """Merge multiple PyG graphs into a single graph based on anchor links.

    Nodes linked by a common anchor row collapse into a single merged node;
    all other nodes are kept distinct. Node attributes of merged nodes are
    averaged; edges and edge attributes are concatenated.

    Parameters
    ----------
    pyg_graphs : List[Data] or Tuple[Data, ...]
        A list or tuple of PyG Data objects representing the graphs to be merged.
    anchor_links : torch.Tensor
        A tensor of shape (num_anchors, num_graphs) containing anchor node
        indices for each graph; -1 marks an anchor absent from a graph.

    Returns
    -------
    merged_graph : Data
        The merged PyG graph.
    merged_anchors : torch.Tensor
        A tensor containing the indices of merged anchor nodes in the merged graph.
    id2node_dict : dict
        A dictionary mapping merged node IDs to lists of (graph_id, original_node_id) tuples.
    node2id_dict : dict
        A dictionary mapping (graph_id, original_node_id) tuples to merged node IDs.
    """
    assert len(pyg_graphs) >= 2, 'At least two graphs are required for merging'
    assert len(pyg_graphs) == anchor_links.shape[1], 'Number of PyG graphs and anchor links dimension do not match'
    # For each anchor row, the first graph (lowest gid) that contains the
    # anchor becomes the canonical "true" node; every later occurrence is
    # mapped onto it via anchor_maps: (gid, node) -> (true_gid, true_node).
    anchor_maps = dict()
    anchor_links = anchor_links.cpu().numpy()
    for anchor_link in anchor_links:
        true_anchor = None
        for gid, anchor in enumerate(anchor_link):
            if anchor > -1:
                if true_anchor is None:
                    true_anchor = (gid, int(anchor))
                else:
                    anchor_maps[(gid, int(anchor))] = true_anchor
    # Assign merged node IDs graph by graph. NOTE: processing order matters —
    # because the canonical anchor always lives in an earlier graph, its ID is
    # already in node2id_list when a later duplicate looks it up.
    merged_node_cnt = 0
    id2node_dict, node2id_list, node2id_dict = defaultdict(list), list(), dict()
    merged_anchors = set()
    for gid, g in enumerate(pyg_graphs):
        node2id = np.zeros(g.num_nodes, dtype=np.int64)  # per-graph original -> merged ID lookup
        for node in range(g.num_nodes):
            if (gid, node) in anchor_maps:
                # Duplicate anchor: reuse the canonical node's merged ID.
                nid = node2id_list[anchor_maps[(gid, node)][0]][anchor_maps[(gid, node)][1]]
                merged_anchors.add(nid)
            else:
                # Fresh node: allocate the next merged ID.
                nid = merged_node_cnt
                merged_node_cnt += 1
            node2id[node] = nid
            node2id_dict[(gid, node)] = nid
            id2node_dict[nid].append((gid, node))
        node2id_list.append(node2id)
    merged_anchors = torch.tensor(list(merged_anchors), dtype=torch.int64)
    # Build merged pyg graph
    assert all([g.num_node_features == pyg_graphs[0].num_node_features for g in pyg_graphs]), 'Node features must match'
    assert all([g.num_edge_features == pyg_graphs[0].num_edge_features for g in pyg_graphs]), 'Edge features must match'
    num_node_attr = pyg_graphs[0].num_node_features
    num_edge_attr = pyg_graphs[0].num_edge_features
    merged_edge_index = torch.empty((2, 0), dtype=torch.int64)
    # Node attributes are summed per merged node, then averaged by how many
    # source graphs contributed to that node (merged_node_attr_cnt).
    # NOTE(review): presence of x/edge_attr is decided from pyg_graphs[0];
    # presumably all graphs agree on attribute presence — confirm with callers.
    merged_node_attr = torch.zeros((merged_node_cnt, num_node_attr), dtype=pyg_graphs[0].x.dtype) if pyg_graphs[0].x is not None else None
    merged_node_attr_cnt = torch.zeros(merged_node_cnt, dtype=torch.int)
    merged_edge_attr = torch.empty((0, num_edge_attr), dtype=pyg_graphs[0].edge_attr.dtype) if pyg_graphs[0].edge_attr is not None else None
    for gid, g in enumerate(pyg_graphs):
        # Remap this graph's edge endpoints to merged IDs and append.
        lookup = torch.from_numpy(node2id_list[gid])
        edge_index_mapped = lookup[g.edge_index]
        merged_edge_index = torch.cat([merged_edge_index, edge_index_mapped], dim=1)
        if g.x is not None:
            # Within one graph every merged ID in `lookup` is unique, so this
            # indexed add accumulates exactly one contribution per graph.
            merged_node_attr[lookup] += g.x
            merged_node_attr_cnt[lookup] += 1
        if g.edge_attr is not None:
            merged_edge_attr = torch.cat([merged_edge_attr, g.edge_attr], dim=0)
    if merged_node_attr is not None:
        # Average the accumulated attributes over contributing graphs.
        merged_node_attr /= merged_node_attr_cnt.view(-1, 1)
    merged_graph = Data(x=merged_node_attr, edge_index=merged_edge_index, edge_attr=merged_edge_attr, num_nodes=merged_node_cnt)
    return merged_graph, merged_anchors, id2node_dict, node2id_dict
def infer_anchors_from_degree(dataset: Dataset, topk_ratio: float = 0.1) -> torch.Tensor:
    """Infer anchor node pairs based on degree similarity.

    The number of anchor pairs is calculated by taking the top k% of the
    minimum number of nodes in the two graphs.

    Parameters
    ----------
    dataset : Dataset
        The dataset containing the graphs to be aligned.
    topk_ratio : float, optional
        The ratio of the number of anchor pairs to the minimum number of
        nodes in the two graphs. Default is 0.1.

    Returns
    -------
    anchors : torch.Tensor
        The inferred anchor pairs, shape (k, 2).
    """
    g1, g2 = dataset.pyg_graphs[0], dataset.pyg_graphs[1]
    # Out-degrees of both graphs, shaped (n1, 1) and (1, n2) for broadcasting.
    deg_first = degree(g1.edge_index[0], g1.num_nodes).unsqueeze(1)
    deg_second = degree(g2.edge_index[0], g2.num_nodes).unsqueeze(0)
    # Similarity decays with the absolute degree difference; identical degrees score 1.
    similarity = 1 / (1 + torch.abs(deg_first - deg_second).float())
    n1, n2 = similarity.shape
    num_pairs = int(topk_ratio * min(n1, n2))
    # Pick the top-k cells of the flattened similarity matrix and recover
    # (row, col) = (node in g1, node in g2) from the flat indices.
    _, top_idx = torch.topk(similarity.view(-1), num_pairs)
    rows = top_idx // n2
    cols = top_idx % n2
    return torch.stack([rows, cols], dim=1)
def infer_anchors_from_attributes(dataset: Dataset, topk_ratio: float = 0.1) -> torch.Tensor:
    """Infer anchor node pairs based on attribute similarity.

    The number of anchor pairs is calculated by taking the top k% of the
    minimum number of nodes in the two graphs.

    Parameters
    ----------
    dataset : Dataset
        The dataset containing the graphs to be aligned.
    topk_ratio : float, optional
        The ratio of the number of anchor pairs to the minimum number of
        nodes in the two graphs. Default is 0.1.

    Returns
    -------
    anchors : torch.Tensor
        The inferred anchor pairs, shape (k, 2).
    """
    g1, g2 = dataset.pyg_graphs[0], dataset.pyg_graphs[1]
    assert g1.x is not None and g2.x is not None, "Graph attributes are required for anchor inference."
    assert g1.x.shape[1] == g2.x.shape[1], "Graph attributes must have the same dimension."
    # Cosine similarity: L2-normalize both feature matrices, then take inner products.
    feats_first = F.normalize(g1.x, p=2, dim=1)
    feats_second = F.normalize(g2.x, p=2, dim=1)
    similarity = feats_first @ feats_second.T
    n1, n2 = similarity.shape
    num_pairs = int(topk_ratio * min(n1, n2))
    # Pick the top-k cells of the flattened similarity matrix and recover
    # (row, col) = (node in g1, node in g2) from the flat indices.
    _, top_idx = torch.topk(similarity.view(-1), num_pairs)
    rows = top_idx // n2
    cols = top_idx % n2
    return torch.stack([rows, cols], dim=1)