# Source code for PlanetAlign.algorithms.nettrans.main

import os
import time

import numpy as np
import psutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import degree

from PlanetAlign.data import Dataset
from PlanetAlign.utils import get_anchor_pairs, get_anchor_embeddings, balance_samples
from PlanetAlign.algorithms.base_model import BaseModel
from PlanetAlign.metrics import hits_ks_scores, mrr_score

from .model import NetTransModel
from .data import ContextDataset
from .sampling import *


class NetTrans(BaseModel):
    r"""Embedding-based method NetTrans for pairwise network alignment via network transformation.

    NetTrans is proposed by the paper "`NetTrans: Neural Cross-Network Transformation.
    <https://dl.acm.org/doi/10.1145/3394486.3403141>`_" in KDD 2020.

    Parameters
    ----------
    hid_dim : int, optional
        Hidden dimension of the model. Default is 128.
    depth : int, optional
        Network depth of the model. Default is 2.
    pooling_ratio : float, optional
        Pooling ratio of the model. Default is 0.2.
    attr_coeff : float, optional
        Coefficient for the attribute loss. Default is 1.0.
        Stored on the instance; the loss assembled in :meth:`train` does not reference it.
    adj_coeff : float, optional
        Coefficient for the structural loss. Default is 1.0.
    rank_coeff : float, optional
        Coefficient for the ranking loss. Default is 1.0.
    margin : float, optional
        Margin for the ranking loss. Default is 1.0.
    neg_size : int, optional
        Number of negative samples per anchor link. Default is 20.
    batch_size : int, optional
        Batch size for training. Default is 300.
    lr : float, optional
        Learning rate for the optimizer. Default is 0.001.
    temperature : float, optional
        Initial temperature for the model. Default is 1.0.
    min_temperature : float, optional
        Minimum temperature for the model. Default is 0.1.
    anneal_rate : float, optional
        Anneal rate for the temperature. Default is 2e-5.
    dtype : torch.dtype, optional
        Data type of the tensors, choose from torch.float32 or torch.float64. Default is torch.float32.
    """

    def __init__(self,
                 hid_dim: int = 128,
                 depth: int = 2,
                 pooling_ratio: float = 0.2,
                 attr_coeff: float = 1.,
                 adj_coeff: float = 1.,
                 rank_coeff: float = 1.,
                 margin: float = 1.,
                 neg_size: int = 20,
                 batch_size: int = 300,
                 lr: float = 0.001,
                 temperature: float = 1.,
                 min_temperature: float = 0.1,
                 anneal_rate: float = 2e-5,
                 dtype: torch.dtype = torch.float32):
        super().__init__(dtype=dtype)

        # Architecture hyperparameters
        self.hid_dim = hid_dim
        self.depth = depth
        self.pooling_ratio = pooling_ratio

        # Loss weighting
        self.attr_coeff = attr_coeff
        self.adj_coeff = adj_coeff
        self.rank_coeff = rank_coeff
        self.margin = margin

        # Optimisation / sampling settings
        self.neg_size = neg_size
        self.batch_size = batch_size
        self.lr = lr

        # Temperature schedule (initial value, floor, decay rate)
        self.temperature = temperature
        self.min_temperature = min_temperature
        self.anneal_rate = anneal_rate

    def train(self,
              dataset: Dataset,
              gid1: int,
              gid2: int,
              use_attr: bool = True,
              total_epochs: int = 50,
              save_log: bool = True,
              verbose: bool = True):
        """Train NetTrans on a pair of graphs and evaluate after every epoch.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing graphs to be aligned and the training/test data.
        gid1 : int
            The graph id of the first graph to be aligned.
        gid2 : int
            The graph id of the second graph to be aligned.
        use_attr : bool, optional
            Flag for using attributes. Default is True.
        total_epochs : int, optional
            Maximum number of training epochs. Default is 50.
        save_log : bool, optional
            Flag for saving the training log. Default is True.
        verbose : bool, optional
            Whether to print the progress during training. Default is True.

        Returns
        -------
        tuple
            ``(emb1, emb2, logger)`` where ``emb1``/``emb2`` are the L2-normalised
            model outputs of the last batch of the final epoch (``None`` when
            ``total_epochs`` is 0) and ``logger`` is the training logger.

        Raises
        ------
        RuntimeError
            If an epoch yields no training batch, so no embeddings exist to evaluate.

        Notes
        -----
        The temperature is annealed in a local variable; ``self.temperature``
        keeps the configured initial value so repeated calls start from the
        same schedule.
        """
        self.check_inputs(dataset, (gid1, gid2), plain_method=False, use_attr=use_attr, pairwise=True, supervised=True)

        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'], save_log=save_log)
        process = psutil.Process(os.getpid())

        graph1, graph2 = dataset.pyg_graphs[gid1], dataset.pyg_graphs[gid2]
        n1, n2 = graph1.num_nodes, graph2.num_nodes
        anchor_links = get_anchor_pairs(dataset.train_data, gid1, gid2)
        test_pairs = get_anchor_pairs(dataset.test_data, gid1, gid2)

        # Embedding initialization
        anchor_embeddings1 = get_anchor_embeddings(graph1, anchor_links[:, 0]).to(self.dtype)
        anchor_embeddings2 = get_anchor_embeddings(graph2, anchor_links[:, 1]).to(self.dtype)
        node_attr1 = torch.cat([graph1.x.to(self.dtype), anchor_embeddings1], dim=1) if use_attr else anchor_embeddings1
        node_attr2 = torch.cat([graph2.x.to(self.dtype), anchor_embeddings2], dim=1) if use_attr else anchor_embeddings2
        node_attr1, node_attr2 = node_attr1.to(self.device), node_attr2.to(self.device)
        edge_weight1 = torch.ones(graph1.edge_index.shape[1], dtype=self.dtype).to(self.device)
        edge_weight2 = torch.ones(graph2.edge_index.shape[1], dtype=self.dtype).to(self.device)

        # Data preparation
        anchor_context_pairs1 = self._sample_anchor_neighbor_pairs(graph1, anchor_links[:, 0])
        anchor_context_pairs2 = self._sample_anchor_neighbor_pairs(graph2, anchor_links[:, 1])
        anchor_context_pairs1, anchor_context_pairs2 = balance_samples(anchor_context_pairs1, anchor_context_pairs2)
        neg_context_prob1, anchor_map1 = self._get_neg_context_prob(graph1, anchor_links[:, 0])
        neg_context_prob2, anchor_map2 = self._get_neg_context_prob(graph2, anchor_links[:, 1])

        # Model initialization
        anchor_context_dataset = ContextDataset(anchor_context_pairs1, anchor_context_pairs2)
        data_loader = DataLoader(dataset=anchor_context_dataset, batch_size=self.batch_size, shuffle=True)
        model = NetTransModel(in_dim=node_attr1.shape[1],
                              hid_dim=self.hid_dim,
                              out_dim=node_attr2.shape[1],
                              pooling_ratio=self.pooling_ratio,
                              depth=self.depth,
                              margin=self.margin).to(self.dtype).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

        # Training
        edge_index1, edge_index2 = graph1.edge_index.to(self.device), graph2.edge_index.to(self.device)
        anchor_links = anchor_links.to(self.device)
        emb1, emb2 = None, None

        # Anneal a local copy so the configured hyperparameter is not overwritten.
        temperature = self.temperature

        for epoch in range(total_epochs):
            t0 = time.time()
            model.train()
            if epoch % 10 == 1:
                annealed = float(temperature * np.exp(-self.anneal_rate * epoch))
                temperature = max(annealed, self.min_temperature)

            total_loss = 0
            x, y = None, None
            infer_time = 0
            for i, (context1, context2) in enumerate(data_loader):
                context1, context2 = context1.to(self.device), context2.to(self.device)
                optimizer.zero_grad()

                infer_t0 = time.time()
                x, _, y, _, _, _ = model(x=node_attr1, edge_index=edge_index1, edge_weight=edge_weight1,
                                         y=node_attr2, edge_index_y=edge_index2, edge_weight_y=edge_weight2,
                                         anchor_links=anchor_links.T, temperature=temperature)
                infer_time += time.time() - infer_t0

                # Negative sampling
                # NOTE(review): only index selection runs under no_grad; the embedding
                # lookups below must stay outside it so the losses can backpropagate.
                with torch.no_grad():
                    # Anchor ids go back to CPU because they index the CPU-side anchor maps.
                    anchor_nodes1 = context1[:, 0].cpu().reshape(-1)
                    pos_context_nodes1 = context1[:, 1].reshape(-1)
                    anchor_nodes2 = context2[:, 0].cpu().reshape(-1)
                    pos_context_nodes2 = context2[:, 1].reshape(-1)
                    negs1, negs2 = uniform_negative_sampling(anchor_nodes1, anchor_nodes2, n1, n2, self.neg_size)
                    neg_context1 = negative_edge_sampling(neg_context_prob1, anchor_map1[anchor_nodes1], self.neg_size)
                    neg_context2 = negative_edge_sampling(neg_context_prob2, anchor_map2[anchor_nodes2], self.neg_size)
                    negs1, negs2 = negs1.flatten().to(self.device), negs2.flatten().to(self.device)
                    neg_context1, neg_context2 = neg_context1.flatten().to(self.device), neg_context2.flatten().to(self.device)

                neg_emb1, neg_emb2 = y[negs1], x[negs2]
                neg_context_emb1, neg_context_emb2 = x[neg_context1], y[neg_context2]
                anchor_emb1, anchor_emb2 = x[anchor_nodes1], y[anchor_nodes2]
                pos_emb1, pos_emb2 = x[pos_context_nodes1], y[pos_context_nodes2]

                adj_loss = model.adj_loss(anchor_emb1, pos_emb1, neg_context_emb1) + model.adj_loss(anchor_emb2, pos_emb2, neg_context_emb2)
                align_loss = model.align_loss(anchor_emb1, anchor_emb2, neg_emb1, neg_emb2)
                batch_loss = self.adj_coeff * adj_loss + self.rank_coeff * align_loss
                total_loss += batch_loss.item()

                batch_loss.backward()
                optimizer.step()

            if x is None or y is None:
                # No batch was produced, so there are no embeddings to evaluate.
                raise RuntimeError('No anchor-context training pairs available; cannot train NetTrans.')

            t1 = time.time()
            model.eval()
            # Evaluate with the outputs of the last forward pass of this epoch.
            emb1 = F.normalize(x.detach(), p=2, dim=1)
            emb2 = F.normalize(y.detach(), p=2, dim=1)
            S = emb1 @ emb2.T
            hits = hits_ks_scores(S, test_pairs, mode='mean')
            mrr = mrr_score(S, test_pairs, mode='mean')

            mem_gb = process.memory_info().rss / 1024 ** 3
            logger.log(epoch=epoch+1,
                       loss=total_loss,
                       epoch_time=t1-t0,
                       mrr=mrr,
                       hits=hits,
                       memory=round(mem_gb, 4),
                       infer_time=round(infer_time, 4),
                       verbose=verbose)

        return emb1, emb2, logger

    @staticmethod
    def _sample_anchor_neighbor_pairs(graph, anchors, max_neighbors: int = 100):
        """Collect shuffled ``(anchor, neighbor)`` pairs for the given anchor nodes.

        For each anchor, all targets of its outgoing edges in ``graph.edge_index``
        are used. When an anchor has more than ``max_neighbors`` neighbors,
        exactly ``max_neighbors`` are drawn with replacement, weighted by the
        neighbors' degree.

        Parameters
        ----------
        graph
            Graph object exposing ``edge_index`` and ``num_nodes``.
        anchors
            Iterable of anchor node ids.
        max_neighbors : int, optional
            Cap on the number of context neighbors kept per anchor. Default is 100.

        Returns
        -------
        torch.Tensor
            Tensor with one ``(anchor, neighbor)`` pair per row, in random order.
        """
        sources, targets = graph.edge_index[0], graph.edge_index[1]
        degrees = degree(sources, num_nodes=graph.num_nodes)

        # Seed with an empty chunk so the final concatenation works with no anchors.
        chunks = [torch.empty((0, 2), dtype=torch.int64)]
        for node in anchors:
            neighbors = targets[sources == node]
            if len(neighbors) > max_neighbors:
                weights = degrees[neighbors]
                total = weights.sum()
                # All-zero degrees would give NaN probabilities; fall back to uniform.
                p = weights / total if total > 0 else torch.ones_like(weights)
                neighbors = neighbors[torch.multinomial(p, max_neighbors, replacement=True)]
            anchor_column = torch.full((len(neighbors),), int(node), dtype=torch.int64)
            chunks.append(torch.stack([anchor_column, neighbors], dim=1))

        # Concatenate once instead of re-copying the accumulated tensor every iteration.
        sampled_context_pairs = torch.cat(chunks, dim=0)

        # Shuffle the context pairs
        return sampled_context_pairs[torch.randperm(sampled_context_pairs.size(0))]

    def _get_neg_context_prob(self, graph, anchors):
        """Build per-anchor sampling weights for negative context nodes.

        Parameters
        ----------
        graph
            Graph object exposing ``edge_index`` and ``num_nodes``.
        anchors
            Anchor node ids; row ``i`` of the result belongs to ``anchors[i]``.

        Returns
        -------
        tuple
            ``(prob, anchor_node_map)``. ``prob`` has shape
            ``(len(anchors), graph.num_nodes)`` and is 1 everywhere except 0 at
            each anchor's neighbors. ``anchor_node_map`` maps a node id to its
            row in ``prob`` and is -1 for non-anchor nodes.
        """
        sources, targets = graph.edge_index[0], graph.edge_index[1]

        prob = torch.ones((len(anchors), graph.num_nodes), dtype=self.dtype)
        for i, anchor in enumerate(anchors):
            # Actual neighbors must never be drawn as negatives.
            prob[i, targets[sources == anchor]] = 0

        anchor_node_map = -1 * torch.ones(graph.num_nodes, dtype=torch.int64)
        anchor_node_map[anchors] = torch.arange(len(anchors), dtype=torch.int64)
        return prob, anchor_node_map