Source code for PlanetAlign.algorithms.parrot

import time
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
import time
import psutil
import os

from PlanetAlign.data import Dataset
from PlanetAlign.utils import get_anchor_pairs, get_batch_rwr_scores, get_normalized_neg_exp_dist
from PlanetAlign.metrics import hits_ks_scores, mrr_score
from .base_model import BaseModel


class PARROT(BaseModel):
    """OT-based method PARROT for pairwise network alignment.

    PARROT is proposed by the paper "`PARROT: Position-Aware Regularized Optimal Transport for Network Alignment.
    <https://dl.acm.org/doi/abs/10.1145/3543507.3583357>`_" in WWW 2023.

    Parameters
    ----------
    alpha: float, optional
        The hyperparameter balancing the costs computed by RWR and node attributes: the RWR cost is
        weighted by ``alpha`` and the attribute cost by ``1 - alpha``. Default is 0.5.
    rwr_restart_prob: float, optional
        The restart probability for the random walk with restart (RWR). Default is 0.15.
    gamma: float, optional
        The discount factor of RWR. Default is 0.75.
    lambda_p: float, optional
        The weight of the proximal point term. Default is 5e-4.
    lambda_e: float, optional
        The weight of the edge consistency term. Default is 1e-5.
    lambda_n: float, optional
        The weight of the neighborhood consistency term. Default is 5e-3.
    lambda_a: float, optional
        The weight of the alignment preference term. Default is 5e-4.
    dtype: torch.dtype, optional
        Data type of the tensors, choose from torch.float32 or torch.float64. Default is torch.float32.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 rwr_restart_prob: float = 0.15,
                 gamma: float = 0.75,
                 lambda_p: float = 5e-4,
                 lambda_e: float = 1e-5,
                 lambda_n: float = 5e-3,
                 lambda_a: float = 5e-4,
                 dtype: torch.dtype = torch.float32):
        super(PARROT, self).__init__(dtype=dtype)
        self.alpha = alpha
        self.rwr_restart_prob = rwr_restart_prob
        self.gamma = gamma
        self.lambda_p = lambda_p
        self.lambda_e = lambda_e
        self.lambda_n = lambda_n
        self.lambda_a = lambda_a

    def train(self,
              dataset: Dataset,
              gid1: int,
              gid2: int,
              use_attr: bool = True,
              max_iters_sep_rwr: int = 100,
              max_iters_prod_rwr: int = 50,
              inner_iters: int = 5,
              outer_iters: int = 10,
              save_log: bool = True,
              verbose: bool = True):
        """Align two graphs of ``dataset`` and return the alignment matrix and the training logger.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing graphs to be aligned and the training/test data.
        gid1 : int
            The graph id of the first graph to be aligned.
        gid2 : int
            The graph id of the second graph to be aligned.
        use_attr : bool, optional
            Flag for using attributes. Default is True.
        max_iters_sep_rwr : int, optional
            Maximum number of iterations for separate RWR. Default is 100.
        max_iters_prod_rwr : int, optional
            Maximum number of iterations for product RWR. Default is 50.
        inner_iters : int, optional
            Number of inner iterations for the proximal point optimization. Default is 5.
        outer_iters : int, optional
            Number of outer iterations for the proximal point optimization. Default is 10.
        save_log : bool, optional
            Flag for saving the logs. Default is True.
        verbose : bool, optional
            Flag for printing the logs. Default is True.

        Returns
        -------
        S : torch.Tensor
            The ``(n1, n2)`` alignment matrix produced by :meth:`con_prox_pt_opt`.
        logger
            The training logger created by ``init_training_logger``, holding one entry per outer iteration.
        """
        self.check_inputs(dataset, (gid1, gid2), plain_method=False, use_attr=use_attr, pairwise=True,
                          supervised=True)
        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'],
                                           save_log=save_log)
        process = psutil.Process(os.getpid())

        graph1, graph2 = dataset.pyg_graphs[gid1], dataset.pyg_graphs[gid2]
        anchor_links = get_anchor_pairs(dataset.train_data, gid1, gid2)
        test_pairs = get_anchor_pairs(dataset.test_data, gid1, gid2)

        # Compute transport cost matrices; the elapsed time is counted into the reported inference time.
        inf_t0 = time.time()
        cross_dist, intra_dist1, intra_dist2 = self.get_transport_cost(graph1, graph2, anchor_links,
                                                                       max_iters_sep_rwr, max_iters_prod_rwr,
                                                                       use_attr, verbose)
        cost_time = time.time() - inf_t0

        # Constraint proximal point optimization
        S, logger = self.con_prox_pt_opt(graph1, graph2, cross_dist, intra_dist1, intra_dist2,
                                         inner_iters, outer_iters, anchor_links, test_pairs,
                                         logger, cost_time, process, verbose)
        return S, logger

    def _get_dense_adj(self, graph) -> torch.Tensor:
        """Return the dense adjacency matrix of ``graph`` on ``self.device`` in ``self.dtype``.

        Rows that sum to zero are filled with ones so that a later L1 row normalization
        never divides an all-zero row.
        """
        num_nodes = graph.num_nodes
        # to_dense_adj returns a leading batch dimension; squeeze(0) drops only that one,
        # so a single-node graph keeps its 2-D shape.
        adj = to_dense_adj(graph.edge_index, max_num_nodes=num_nodes).squeeze(0).to(self.dtype).to(self.device)
        adj[adj.sum(1) == 0] = 1
        return adj

    def con_prox_pt_opt(self, graph1, graph2, cross_dist, intra_dist1, intra_dist2, inner_iters, outer_iters,
                        anchor_links, test_pairs, logger, cost_time, process, verbose):
        """Run the constrained proximal point iterations and log metrics after every outer iteration.

        Parameters
        ----------
        graph1, graph2
            The two graphs to align; ``num_nodes`` and ``edge_index`` are read from each.
        cross_dist : torch.Tensor
            Cross-network transport cost of shape ``(n1, n2)``.
        intra_dist1, intra_dist2 : torch.Tensor
            Intra-network transport costs of shape ``(n1, n1)`` and ``(n2, n2)``.
        inner_iters : int
            Number of Sinkhorn-style dual updates per outer iteration.
        outer_iters : int
            Number of proximal point (outer) iterations.
        anchor_links
            Training anchor pairs, indexed as ``[:, 0]`` (graph1 node) and ``[:, 1]`` (graph2 node).
        test_pairs
            Test pairs forwarded to ``hits_ks_scores`` and ``mrr_score``.
        logger
            Training logger; ``logger.log`` is called once per outer iteration.
        cost_time : float
            Time already spent computing the transport costs; the starting value of the logged ``infer_time``.
        process : psutil.Process
            Process handle whose resident set size is logged as ``memory`` (in GiB).
        verbose : bool
            Flag for printing progress.

        Returns
        -------
        S : torch.Tensor
            The ``(n1, n2)`` alignment matrix after the last outer iteration.
        logger
            The logger passed in, after all entries were added.
        """
        n1, n2 = graph1.num_nodes, graph2.num_nodes
        infer_time = cost_time

        # Normalize adjacency matrices
        row_norm_adj1 = F.normalize(self._get_dense_adj(graph1), p=1, dim=1)
        row_norm_adj2 = F.normalize(self._get_dense_adj(graph2), p=1, dim=1)

        # Constraint proximal point iterations
        lambda_total = self.lambda_n + self.lambda_a + self.lambda_p
        lambda_e = self.lambda_e * n1 * n2

        one_vec_n1 = torch.ones((n1, 1), dtype=self.dtype).to(self.device)
        one_vec_n2 = torch.ones((n2, 1), dtype=self.dtype).to(self.device)
        # r / c: uniform marginals as a column (n1, 1) and a row (1, n2); never reassigned below.
        r = one_vec_n1 / n1
        c = one_vec_n2.T / n2
        # a / b: dual variables, initialised to the same uniform values as r / c.
        a = one_vec_n1 / n1
        b = one_vec_n2.T / n2

        S = torch.ones((n1, n2), dtype=self.dtype).to(self.device) / (n1 * n2)
        # Alignment preference: 1 on training anchors, a small positive value elsewhere so log(H) stays finite.
        H = torch.zeros((n1, n2), dtype=self.dtype).to(self.device) + 1e-6
        H[anchor_links[:, 0], anchor_links[:, 1]] = 1

        def mina(H_in, epsilon):
            # Soft minimum over rows (dim 0), weighted by r; result has shape (1, n2).
            return -epsilon * torch.log(torch.sum(r * torch.exp(-H_in / epsilon), dim=0, keepdim=True))

        def minb(H_in, epsilon):
            # Soft minimum over columns (dim 1), weighted by c; result has shape (n1, 1).
            return -epsilon * torch.log(torch.sum(c * torch.exp(-H_in / epsilon), dim=1, keepdim=True))

        def minaa(H_in, epsilon):
            # Shift by the column-wise minimum before exponentiating, then add it back.
            col_min = torch.min(H_in, dim=0).values.view(1, -1)
            return mina(H_in - col_min, epsilon) + col_min

        def minbb(H_in, epsilon):
            # Shift by the row-wise minimum before exponentiating, then add it back.
            row_min = torch.min(H_in, dim=1).values.view(-1, 1)
            return minb(H_in - row_min, epsilon) + row_min

        # Part of the edge-consistency cost that does not depend on S.
        L_fixed = (intra_dist1 ** 2) @ r @ one_vec_n2.T + one_vec_n1 @ c @ (intra_dist2 ** 2).T

        C_old = None
        if verbose:
            print('Starting constraint proximal point iteration')
        for i in range(outer_iters):
            if verbose:
                print(f'Iteration {i + 1}/{outer_iters}:', end=" ")
            t0 = time.time()

            S_old = torch.clone(S)
            L = L_fixed - 2 * intra_dist1 @ S @ intra_dist2.T
            C = (cross_dist
                 + lambda_e * L
                 - self.lambda_n * torch.log(row_norm_adj1.T @ S @ row_norm_adj2)
                 - self.lambda_a * torch.log(H))

            # Keep the new cost only if it does not increase the transport objective under the current S.
            if C_old is None:
                C_old = C
            else:
                W_old = torch.sum(S * C_old)
                W = torch.sum(S * C)
                if W <= W_old:
                    C_old = C
                else:
                    C = C_old

            Cost = C - self.lambda_p * torch.log(S)
            if verbose:
                print('sinkhorn iteration', end=" ")
            for j in range(inner_iters):
                # After the first pass a has shape (1, n2) and b has shape (n1, 1);
                # broadcasting makes a + b an (n1, n2) matrix either way.
                a = minaa(Cost - b, lambda_total)
                b = minbb(Cost - a, lambda_total)
                if verbose:
                    print(j + 1, end=" ")

            # Damped update: mix 5% of the previous plan into the new one.
            S = 0.05 * S_old + 0.95 * r * torch.exp((a + b - Cost) / lambda_total) * c
            diff_S = torch.sum(torch.abs(S - S_old))

            t1 = time.time()
            infer_time += t1 - t0
            if verbose:
                print(f'done, time spent: {t1 - t0:.2f}s')

            hits = hits_ks_scores(S, test_pairs, mode='mean')
            mrr = mrr_score(S, test_pairs, mode='mean')
            mem_gb = process.memory_info().rss / 1024 ** 3
            logger.log(epoch=i + 1,
                       loss=diff_S.item(),
                       epoch_time=t1 - t0,
                       hits=hits,
                       mrr=mrr,
                       memory=round(mem_gb, 4),
                       infer_time=round(infer_time, 4),
                       verbose=verbose)

        return S, logger

    def get_transport_cost(self, graph1, graph2, anchor_links, max_iters_sep_rwr, max_iters_prod_rwr, use_attr,
                           verbose):
        """Compute the cross-network and intra-network transport cost matrices.

        Parameters
        ----------
        graph1, graph2
            The two graphs to align; ``x`` is read from each only when ``use_attr`` is True.
        anchor_links
            Training anchor pairs, indexed as ``[:, 0]`` (graph1 node) and ``[:, 1]`` (graph2 node).
        max_iters_sep_rwr : int
            Maximum number of iterations for the per-graph RWR.
        max_iters_prod_rwr : int
            Maximum number of iterations for the product RWR.
        use_attr : bool
            If True, node attributes supply the attribute cost; otherwise the RWR scores are used in their place.
        verbose : bool
            Flag for printing progress.

        Returns
        -------
        cross_dist : torch.Tensor
            Cross-network transport cost returned by :meth:`get_product_rwr_mat`.
        intra_dist1, intra_dist2 : torch.Tensor
            Intra-network costs: each graph's dense adjacency multiplied elementwise by its pairwise node cost.
        """
        # Compute position-aware cross-network transport cost
        if verbose:
            print('Computing separate RWR scores ...', end=" ")
        t0 = time.time()
        rwr_emb1 = get_batch_rwr_scores(graph1, anchor_links[:, 0],
                                        restart_prob=self.rwr_restart_prob,
                                        max_iters=max_iters_sep_rwr,
                                        connect_isolated=True,
                                        device=self.device).to(self.dtype)
        rwr_emb2 = get_batch_rwr_scores(graph2, anchor_links[:, 1],
                                        restart_prob=self.rwr_restart_prob,
                                        max_iters=max_iters_sep_rwr,
                                        connect_isolated=True,
                                        device=self.device).to(self.dtype)
        t1 = time.time()
        if verbose:
            print(f'done, time spent: {t1 - t0:.2f}s')

        cross_rwr_dist = get_normalized_neg_exp_dist(rwr_emb1, rwr_emb2, device=self.device).to(self.dtype)
        if use_attr:
            cross_attr_dist = get_normalized_neg_exp_dist(graph1.x, graph2.x, device=self.device).to(self.dtype)
        else:
            cross_attr_dist = get_normalized_neg_exp_dist(rwr_emb1, rwr_emb2, device=self.device).to(self.dtype)
        # alpha balances the two costs, so the weights must sum to one. Without the (1 - alpha) factor
        # the attribute cost is always fully weighted and alpha only rescales the RWR cost.
        cross_node_dist = self.alpha * cross_rwr_dist + (1 - self.alpha) * cross_attr_dist
        # Known anchor pairs get zero cost.
        cross_node_dist[anchor_links[:, 0], anchor_links[:, 1]] = 0

        adj1 = self._get_dense_adj(graph1)
        adj2 = self._get_dense_adj(graph2)

        if verbose:
            print('Computing product RWR scores ...', end=" ")
        t0 = time.time()
        cross_dist = self.get_product_rwr_mat(adj1, adj2, in_cross_dist=cross_node_dist,
                                              max_iters=max_iters_prod_rwr)
        t1 = time.time()
        if verbose:
            print(f'done, time spent: {t1 - t0:.2f}s')

        # Compute intra-network transport cost
        if use_attr:
            feat1, feat2 = graph1.x, graph2.x
        else:
            feat1, feat2 = rwr_emb1, rwr_emb2
        intra_dist1 = adj1 * get_normalized_neg_exp_dist(feat1, feat1, device=self.device).to(self.dtype)
        intra_dist2 = adj2 * get_normalized_neg_exp_dist(feat2, feat2, device=self.device).to(self.dtype)

        return cross_dist, intra_dist1, intra_dist2

    def get_product_rwr_mat(self, adj1, adj2, in_cross_dist, max_iters=50, tol=1e-2):
        """Propagate a node-level cross cost over both graphs with a fixed-point iteration.

        Starting from zeros, repeats
        ``X <- (1 + p) * in_cross_dist + (1 - p) * gamma * W1 @ X @ W2.T``
        where ``p`` is ``rwr_restart_prob`` and ``W1`` / ``W2`` are the L1 row-normalized adjacencies,
        until the largest absolute entry change drops below ``tol`` or ``max_iters`` is reached.

        Parameters
        ----------
        adj1, adj2 : torch.Tensor
            Dense adjacency matrices of the two graphs.
        in_cross_dist : torch.Tensor
            Node-level cross-network cost matrix.
        max_iters : int, optional
            Maximum number of iterations. Default is 50.
        tol : float, optional
            Convergence threshold on the largest absolute entry change. Default is 1e-2.

        Returns
        -------
        torch.Tensor
            The propagated cost scaled by ``1 - gamma``, with the same shape as ``in_cross_dist``.
        """
        row_norm_adj1 = F.normalize(adj1.to(self.device), p=1, dim=1)
        row_norm_adj2 = F.normalize(adj2.to(self.device), p=1, dim=1)

        # Loop-invariant pieces of the update.
        restart_term = (1 + self.rwr_restart_prob) * in_cross_dist
        scaled_adj1 = (1 - self.rwr_restart_prob) * self.gamma * row_norm_adj1
        adj2_t = row_norm_adj2.T

        out_cross_dist = torch.zeros(in_cross_dist.shape).to(self.device).to(self.dtype)
        for _ in range(max_iters):
            # The update rebinds out_cross_dist to a new tensor, so keeping a reference is enough.
            prev_cross_dist = out_cross_dist
            out_cross_dist = restart_term + scaled_adj1 @ prev_cross_dist @ adj2_t
            if torch.max(torch.abs(out_cross_dist - prev_cross_dist)) < tol:
                break

        out_cross_dist = (1 - self.gamma) * out_cross_dist
        return out_cross_dist