Source code for PlanetAlign.algorithms.crossmna.main

import os
import time
from typing import List, Tuple, Union

import numpy as np
import psutil
import torch
import torch.nn.functional as F
from torch_geometric.utils import degree

from PlanetAlign.data import Dataset
from PlanetAlign.algorithms.base_model import BaseModel
from PlanetAlign.utils import merge_pyg_graphs_on_anchors, get_pairwise_anchor_pairs
from PlanetAlign.metrics import hits_ks_scores, mrr_score

from .model import MultiNetworkEmb
from .sampler import AliasSampler


class CrossMNA(BaseModel):
    """Embedding-based method CrossMNA for plain multi-network alignment.

    CrossMNA is proposed in the paper "`Cross-Network Embedding for Multi-Network
    Alignment <https://doi.org/10.1145/3308558.3313499>`_", WWW 2019.

    Parameters
    ----------
    batch_size : int, optional
        Batch size for training. Default is 4096.
    neg_samples : int, optional
        Number of negative samples per positive sample. Default is 1.
    node_emb_dims : int, optional
        Dimensions of output node embeddings. Default is 200.
    graph_emb_dims : int, optional
        Dimensions of output graph embeddings. Default is 100.
    lr : float, optional
        Learning rate for the optimizer. Default is 0.02.
    dtype : torch.dtype, optional
        Data type of the tensors, chosen from torch.float32 or torch.float64.
        Default is torch.float32.
    """

    def __init__(self,
                 batch_size: int = 4096,
                 neg_samples: int = 1,
                 node_emb_dims: int = 200,
                 graph_emb_dims: int = 100,
                 lr: float = 0.02,
                 dtype: torch.dtype = torch.float32):
        super(CrossMNA, self).__init__(dtype=dtype)
        assert isinstance(batch_size, int), 'Batch size must be an integer'
        assert isinstance(neg_samples, int), 'Number of negative samples must be an integer'
        assert isinstance(node_emb_dims, int), 'Node embedding dimensions must be an integer'
        assert isinstance(graph_emb_dims, int), 'Graph embedding dimensions must be an integer'
        assert lr > 0, 'Learning rate must be positive'

        self.batch_size = batch_size
        self.neg_samples = neg_samples
        self.node_emb_dims = node_emb_dims
        self.graph_emb_dims = graph_emb_dims
        self.lr = lr
    def train(self,
              dataset: Dataset,
              gids: Union[List[int], Tuple[int, ...]],
              use_attr: bool = False,
              total_epochs: int = 400,
              save_log: bool = True,
              verbose: bool = True):
        """Train CrossMNA on the selected graphs and evaluate alignment quality after each epoch.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing the graphs to be aligned and the training/test data.
        gids : list or tuple
            The indices of the graphs in the dataset to be aligned.
        use_attr : bool, optional
            Whether to use node attributes for alignment. Default is False.
        total_epochs : int, optional
            Total number of training epochs. Default is 400.
        save_log : bool, optional
            Whether to save the training log. Default is True.
        verbose : bool, optional
            Whether to print training progress. Default is True.

        Returns
        -------
        out_embs_dict : dict
            Mapping from each graph id in `gids` to its learned node embeddings.
        logger
            The training logger with per-epoch loss, MRR, Hits@k, memory, and time.
        """
        self.check_inputs(dataset, gids, plain_method=True, use_attr=use_attr, pairwise=False, supervised=True)
        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'], save_log=save_log)
        process = psutil.Process(os.getpid())

        graphs = [dataset.pyg_graphs[gid] for gid in gids]
        anchor_links = dataset.train_data[:, gids]
        test_pairs_dict = get_pairwise_anchor_pairs(dataset.test_data[:, gids])

        # Initialization: one alias sampler per graph for negative sampling
        alias_samplers_list = []
        for gid in gids:
            graph = dataset.pyg_graphs[gid]
            alias_sampler = self._init_sampler(graph)
            alias_samplers_list.append(alias_sampler)
        _, _, id2node, node2id = merge_pyg_graphs_on_anchors(graphs, anchor_links)

        # Model initialization
        model = MultiNetworkEmb(num_of_nodes=len(id2node),
                                num_layer=len(graphs),
                                batch_size=self.batch_size,
                                K=self.neg_samples,
                                node_emb_dims=self.node_emb_dims,
                                layer_emb_dims=self.graph_emb_dims).to(self.dtype).to(self.device)
        optimizer = torch.optim.RMSprop(model.parameters(), lr=self.lr, alpha=0.99,
                                        eps=1.0, centered=True, momentum=0.0)

        # Training
        out_embs_dict = {}
        infer_time = 0
        for epoch in range(total_epochs):
            t0 = time.time()
            train_samples = self._generate_samples(graphs, alias_samplers_list)
            total_loss = 0
            for u_i, u_j, label, gid_vec in train_samples:
                u_i = u_i.cpu().numpy()
                u_j = u_j.cpu().numpy()
                gid_vec = gid_vec.cpu().numpy()
                label = label.to(self.dtype).to(self.device)
                # Map (graph id, local node id) pairs to global ids in the merged graph
                mapped_u_i = np.array([node2id[(gid, u)] for gid, u in zip(gid_vec, u_i)])
                mapped_u_j = np.array([node2id[(gid, u)] for gid, u in zip(gid_vec, u_j)])

                optimizer.zero_grad()
                loss = model(mapped_u_i, mapped_u_j, gid_vec, label)
                total_loss += loss.item()
                loss.backward()
                optimizer.step()
            t1 = time.time()
            infer_time += t1 - t0

            # Evaluation on every pair of graphs
            with torch.no_grad():
                mem_gb = process.memory_info().rss / 1024 ** 3
                out_embs_dict = {}
                embeddings = F.normalize(model.embedding, p=2, dim=1)
                for gid, graph in enumerate(graphs):
                    embs = embeddings[[node2id[(gid, u)] for u in range(graph.num_nodes)]]
                    out_embs_dict[gids[gid]] = embs

                for id1 in range(len(gids)):
                    for id2 in range(id1 + 1, len(gids)):
                        if verbose:
                            print(f'Graph {gids[id1]} vs Graph {gids[id2]}')
                        test_pairs = test_pairs_dict[(id1, id2)]
                        emb1 = out_embs_dict[gids[id1]]
                        emb2 = out_embs_dict[gids[id2]]
                        S = emb1 @ emb2.T
                        hits = hits_ks_scores(S, test_pairs, mode='mean')
                        mrr = mrr_score(S, test_pairs, mode='mean')
                        logger.log(epoch=epoch + 1, loss=total_loss, epoch_time=t1 - t0,
                                   mrr=mrr, hits=hits, memory=round(mem_gb, 4),
                                   infer_time=round(infer_time, 4), verbose=verbose)

        return out_embs_dict, logger
    def _generate_samples(self, graphs, alias_samplers_list):
        all_edges = torch.empty((2, 0), dtype=torch.int64)
        all_gid_vec = torch.empty((0,), dtype=torch.int)
        all_neg_samples = torch.empty((self.neg_samples, 0), dtype=torch.int64)
        for gid, graph in enumerate(graphs):
            all_gid_vec = torch.cat([all_gid_vec, torch.tensor([gid] * graph.num_edges, dtype=torch.int)])
            all_edges = torch.hstack([all_edges, graph.edge_index])
            neg_samples = alias_samplers_list[gid].sample(
                num_samples=self.neg_samples * graph.num_edges).reshape(self.neg_samples, graph.num_edges)
            all_neg_samples = torch.hstack([all_neg_samples, torch.from_numpy(neg_samples)])

        # Shuffle sampled node pairs
        num_all_edges = all_edges.shape[1]
        perm = torch.randperm(num_all_edges)
        all_edges = all_edges[:, perm]
        all_gid_vec = all_gid_vec[perm]
        all_neg_samples = all_neg_samples[:, perm]

        # Interleave each positive edge with its negative samples (label +1 vs. -1)
        ui_samples = torch.repeat_interleave(all_edges[0, :], repeats=self.neg_samples + 1, dim=0)
        uj_samples = torch.vstack([all_edges[1, :], all_neg_samples]).T.flatten()
        gid_samples = torch.repeat_interleave(all_gid_vec, repeats=self.neg_samples + 1)
        label_samples = torch.vstack([torch.ones(num_all_edges),
                                      -torch.ones(self.neg_samples, num_all_edges)]).T.flatten()
        assert ui_samples.shape == uj_samples.shape == gid_samples.shape == label_samples.shape, 'Shape mismatch'

        # Divide sampled node pairs into batches
        batched_samples = []
        num_batches = ui_samples.shape[0] // self.batch_size
        for i in range(num_batches):
            left = i * self.batch_size
            right = (i + 1) * self.batch_size
            u_i = ui_samples[left:right]
            u_j = uj_samples[left:right]
            label = label_samples[left:right]
            gid_vec = gid_samples[left:right]
            batched_samples.append((u_i, u_j, label, gid_vec))
        return batched_samples

    @staticmethod
    def _init_sampler(graph):
        # Negative-sampling distribution proportional to degree^0.75, word2vec-style
        node_positive_distribution = degree(graph.edge_index[0], num_nodes=graph.num_nodes) ** 0.75
        node_positive_distribution /= node_positive_distribution.sum()
        return AliasSampler(prob=node_positive_distribution)
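The interleaved layout built in `_generate_samples` is easier to see on a toy example. The sketch below is illustrative only; the edge and negative-sample values are made up, and K plays the role of `self.neg_samples`:

import torch

K = 2                                    # negatives per positive edge (assumed)
edges = torch.tensor([[0, 1],            # source nodes of 2 edges
                      [3, 4]])           # true (positive) targets
negs = torch.tensor([[7, 8],             # K x num_edges negative targets
                     [9, 5]])

u_i = torch.repeat_interleave(edges[0], K + 1)     # tensor([0, 0, 0, 1, 1, 1])
u_j = torch.vstack([edges[1], negs]).T.flatten()   # tensor([3, 7, 9, 4, 8, 5])
labels = torch.vstack([torch.ones(2),
                       -torch.ones(K, 2)]).T.flatten()
# labels: tensor([ 1., -1., -1.,  1., -1., -1.])
# Each source node appears K + 1 times in a row: once with its true neighbour
# (label +1), then with its K sampled negatives (label -1) -- exactly the
# layout the training loop feeds to the model.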
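Finally, a minimal usage sketch for the class as a whole, assuming a `Dataset` has already been constructed (its loading arguments are not shown in this module, so they are elided):

from PlanetAlign.data import Dataset
from PlanetAlign.algorithms.crossmna.main import CrossMNA

dataset = Dataset(...)  # construct/load a Dataset; the exact arguments are elided

# Hyperparameters below mirror the documented defaults.
model = CrossMNA(batch_size=4096, neg_samples=1,
                 node_emb_dims=200, graph_emb_dims=100, lr=0.02)

# Align the first two graphs of the dataset; returns per-graph embeddings
# (L2-normalised) plus the training logger.
embs_dict, logger = model.train(dataset, gids=[0, 1], total_epochs=400, verbose=True)

# Alignment scores between the two graphs: inner products of unit vectors.
S = embs_dict[0] @ embs_dict[1].T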