import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
import psutil
import os
from PlanetAlign.data import Dataset
from PlanetAlign.utils import get_anchor_pairs, get_batch_rwr_scores, balance_samples, get_anchor_embeddings
from PlanetAlign.metrics import hits_ks_scores, mrr_score
from PlanetAlign.algorithms.base_model import BaseModel
from .utils import load_walks, extract_pairs, merge_graphs, ContextDataset, negative_sampling_exact
from .model import Model
class NeXtAlign(BaseModel):
    """Embedding-based method NeXtAlign for pairwise network alignment.

    NeXtAlign is proposed by the paper: "`Balancing Consistency and Disparity in Network Alignment. <https://dl.acm.org/doi/abs/10.1145/3447548.3467331>`_"
    in KDD 2021

    Parameters
    ----------
    p : float, optional
        Hyperparameter in node2vec. Default is 1.
    q : float, optional
        Hyperparameter in node2vec. Default is 1.
    num_walks : int, optional
        Number of random walks during context pair generation. Default is 10.
    walk_length : int, optional
        Length of each random walk. Default is 80.
    rwr_restart_prob : float, optional
        Restart probability for random walk with restart. Default is 0.15.
    out_dim : int, optional
        Dimension of the output embeddings. Default is 128.
    dist : str, optional
        Distance metric for the similarity matrix. Default is 'L1'.
    batch_size : int, optional
        Batch size for training. Default is 300.
    neg_size : int, optional
        Number of negative samples. Default is 20.
    coeff1 : float, optional
        Coefficient for the within-network link prediction loss. Default is 1.
    coeff2 : float, optional
        Coefficient for the anchor link prediction loss. Default is 1.
    lr : float, optional
        Learning rate for the optimizer. Default is 0.01.
    dtype : torch.dtype, optional
        Data type of the tensors, choose from torch.float32 or torch.float64. Default is torch.float32.
    """

    def __init__(self,
                 p: float = 1,
                 q: float = 1,
                 num_walks: int = 10,
                 walk_length: int = 80,
                 rwr_restart_prob: float = 0.15,
                 out_dim: int = 128,
                 dist: str = 'L1',
                 batch_size: int = 300,
                 neg_size: int = 20,
                 coeff1: float = 1,
                 coeff2: float = 1,
                 lr: float = 0.01,
                 dtype: torch.dtype = torch.float32):
        super().__init__(dtype=dtype)
        assert dist in ['L1', 'inner'], 'Invalid distance metric'
        self.p = p
        self.q = q
        self.num_walks = num_walks
        self.walk_length = walk_length
        self.rwr_restart_prob = rwr_restart_prob
        self.out_dim = out_dim
        self.dist = dist
        self.batch_size = batch_size
        self.neg_size = neg_size
        self.coeff1 = coeff1
        self.coeff2 = coeff2
        self.lr = lr

    def train(self,
              dataset: Dataset,
              gid1: int,
              gid2: int,
              use_attr: bool = True,
              total_epochs: int = 50,
              save_log: bool = True,
              verbose: bool = True):
        """Train NeXtAlign on two graphs of `dataset` and return their embeddings.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing the graphs to be aligned and the training/test data.
        gid1 : int
            The index of the first graph in the dataset to be aligned.
        gid2 : int
            The index of the second graph in the dataset to be aligned.
        use_attr : bool, optional
            Whether to use node attributes in the model. Default is True.
        total_epochs : int, optional
            Total number of epochs for training. Default is 50.
        save_log : bool, optional
            Whether to save the training log. Default is True.
        verbose : bool, optional
            Whether to print the training progress. Default is True.

        Returns
        -------
        tuple
            ``(emb1, emb2, logger)`` — final (L2-normalized) node embeddings of the
            two graphs and the training logger.

        Raises
        ------
        RuntimeError
            If the context-pair loader yields no batches, so no embeddings were
            ever computed.
        """
        self.check_inputs(dataset, (gid1, gid2), plain_method=False, use_attr=use_attr, pairwise=True, supervised=True)
        # Ensure graph1 is the smaller graph; alignment is symmetric in the pair.
        if dataset.pyg_graphs[gid1].num_nodes > dataset.pyg_graphs[gid2].num_nodes:
            gid1, gid2 = gid2, gid1
        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'], save_log=save_log)
        process = psutil.Process(os.getpid())  # for peak-RSS logging per epoch
        graph1, graph2 = dataset.pyg_graphs[gid1], dataset.pyg_graphs[gid2]
        anchor_links = get_anchor_pairs(dataset.train_data, gid1, gid2)
        test_pairs = get_anchor_pairs(dataset.test_data, gid1, gid2)

        # Step 1: node2vec-style random walks -> positive (anchor, context) pairs.
        t0 = time.time()
        if verbose:
            print('Sampling positive context pairs...')
        context_pairs1, context_pairs2 = self._sample_positive_context_pairs(graph1, graph2, anchor_links, verbose=verbose)
        if verbose:
            print('Done, Time Spent: %.2f seconds' % (time.time() - t0))

        # Step 2: RWR-based positional embeddings (optionally concatenated with attrs).
        t0 = time.time()
        if verbose:
            print('Generating initial positional embeddings...', end=' ')
        input_emb1, input_emb2 = self._generate_initial_embeddings(graph1, graph2, anchor_links, use_attr)
        if verbose:
            print('Done, Time Spent: %.2f seconds' % (time.time() - t0))
        rwr_time = time.time() - t0  # counted as part of inference time below

        # Step 3: merge the two graphs on the anchor nodes into a single graph.
        t0 = time.time()
        if verbose:
            print('Merging graphs...', end=' ')
        num_nodes_merged = graph1.num_nodes + graph2.num_nodes - anchor_links.shape[0]
        edge_index, edge_types, x, node_mapping1, node_mapping2 = self.merge_two_graphs(graph1, graph2,
                                                                                       input_emb1, input_emb2,
                                                                                       anchor_links)
        if verbose:
            print('Done, Time Spent: %.2f seconds' % (time.time() - t0))
        # Node indices double as one-hot ids for the model's embedding table.
        onehot_emb = torch.arange(num_nodes_merged, dtype=torch.int64)
        x = (onehot_emb, x[0], x[1]) if use_attr else (onehot_emb, x)

        # Context dataset
        node2vec_context_dataset = ContextDataset(context_pairs1, context_pairs2)
        data_loader = DataLoader(dataset=node2vec_context_dataset, batch_size=self.batch_size, shuffle=True)

        # Model
        model = Model(num_nodes=num_nodes_merged,
                      out_features=self.out_dim,
                      anchor_nodes=anchor_links[:, 0],
                      distance=self.dist,
                      num_anchors=anchor_links.shape[0],
                      num_attrs=x[2].shape[1] if use_attr else 0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

        # Move to device
        model = model.to(self.dtype).to(self.device)
        x = tuple([sub.to(self.dtype).to(self.device) for sub in x])
        edge_index = edge_index.to(self.device)
        edge_types = edge_types.to(self.device)
        node_mapping1 = node_mapping1.to(self.device)
        node_mapping2 = node_mapping2.to(self.device)

        # Train
        if verbose:
            print('Training...')
        for epoch in range(total_epochs):
            t0 = time.time()
            model.train()
            epoch_loss = 0
            out_x = None  # embeddings from the last forward pass of this epoch
            infer_time = rwr_time
            for i, data in enumerate(data_loader):
                nodes1, nodes2 = data
                nodes1 = nodes1.to(self.device)
                nodes2 = nodes2.to(self.device)
                # Column 0 holds the anchor node, column 1 its positive context node.
                anchor_nodes1 = nodes1[:, 0].reshape((-1,))
                pos_context_nodes1 = nodes1[:, 1].reshape((-1,))
                anchor_nodes2 = nodes2[:, 0].reshape((-1,))
                pos_context_nodes2 = nodes2[:, 1].reshape((-1,))
                # forward pass
                optimizer.zero_grad()
                inf_t0 = time.time()
                out_x = model(edge_index.T, x, edge_types)
                infer_time += time.time() - inf_t0
                context_pos1_emb = out_x[node_mapping1[pos_context_nodes1]]
                context_pos2_emb = out_x[node_mapping2[pos_context_nodes2]]
                # Negative sampling: 'p_n' draws within-network negatives,
                # 'p_nc' draws cross-network negatives for the anchor loss.
                pn_examples1, _ = negative_sampling_exact(out_x, self.neg_size, anchor_nodes1, node_mapping1,
                                                          'p_n', 'g1')
                pn_examples2, _ = negative_sampling_exact(out_x, self.neg_size, anchor_nodes2, node_mapping2,
                                                          'p_n', 'g2')
                pnc_examples1, _ = negative_sampling_exact(out_x, self.neg_size, anchor_nodes1, node_mapping1,
                                                           'p_nc', 'g1', node_mapping2=node_mapping2)
                pnc_examples2, _ = negative_sampling_exact(out_x, self.neg_size, anchor_nodes2, node_mapping2,
                                                           'p_nc', 'g2', node_mapping2=node_mapping1)
                # get node embeddings
                pn_examples1 = torch.from_numpy(pn_examples1).reshape((-1,)).to(self.device)
                pn_examples2 = torch.from_numpy(pn_examples2).reshape((-1,)).to(self.device)
                pnc_examples1 = torch.from_numpy(pnc_examples1).reshape((-1,)).to(self.device)
                pnc_examples2 = torch.from_numpy(pnc_examples2).reshape((-1,)).to(self.device)
                anchor1_emb = out_x[node_mapping1[anchor_nodes1]]
                anchor2_emb = out_x[node_mapping2[anchor_nodes2]]
                context_neg1_emb = out_x[node_mapping1[pn_examples1]]
                context_neg2_emb = out_x[node_mapping2[pn_examples2]]
                # Cross-network negatives are looked up in the OTHER graph's mapping.
                anchor_neg1_emb = out_x[node_mapping2[pnc_examples1]]
                anchor_neg2_emb = out_x[node_mapping1[pnc_examples2]]
                input_embs = (anchor1_emb, anchor2_emb, context_pos1_emb, context_pos2_emb, context_neg1_emb,
                              context_neg2_emb, anchor_neg1_emb, anchor_neg2_emb)
                # compute loss: loss1 = within-network link prediction,
                # loss2 = anchor link prediction (see coeff1/coeff2).
                loss1, loss2 = model.loss(input_embs)
                batch_loss = self.coeff1 * loss1 + self.coeff2 * loss2
                # backward pass
                epoch_loss += batch_loss.item()
                batch_loss.backward()
                optimizer.step()
            t1 = time.time()

            # Fail loudly instead of crashing on F.normalize(None) below.
            if out_x is None:
                raise RuntimeError('Training produced no batches; context-pair DataLoader is empty.')

            # Evaluation (reuses the last forward pass of the epoch).
            model.eval()
            with torch.no_grad():
                out_x = F.normalize(out_x, p=2, dim=1)
                emb1 = out_x[node_mapping1]
                emb2 = out_x[node_mapping2]
                S = self.get_cross_alignment_mat(emb1, emb2, model.score_lin.weight[0].detach())
                hits = hits_ks_scores(S, test_pairs, mode='mean')
                mrr = mrr_score(S, test_pairs, mode='mean')
            mem_gb = process.memory_info().rss / 1024 ** 3
            logger.log(epoch=epoch+1,
                       loss=epoch_loss,
                       epoch_time=t1-t0,
                       hits=hits,
                       mrr=mrr,
                       memory=round(mem_gb, 4),
                       infer_time=round(infer_time, 4),
                       verbose=verbose)
        return emb1, emb2, logger

    def _sample_positive_context_pairs(self, graph1, graph2, anchor_links, verbose=True):
        """Run node2vec walks on both graphs and extract balanced (anchor, context) pairs."""
        walks1 = load_walks(graph1, self.p, self.q, self.num_walks, self.walk_length, verbose=verbose)
        walks2 = load_walks(graph2, self.p, self.q, self.num_walks, self.walk_length, verbose=verbose)
        context_pairs1 = extract_pairs(walks1, anchor_links[:, 0])
        context_pairs2 = extract_pairs(walks2, anchor_links[:, 1])
        # Balance the number of samples drawn from the two graphs.
        context_pairs1, context_pairs2 = balance_samples(context_pairs1, context_pairs2)
        return context_pairs1, context_pairs2

    def _generate_initial_embeddings(self, graph1, graph2, anchor_links, use_attr):
        """Build RWR-based positional embeddings anchored on the training links.

        Anchor rows of the RWR scores are zeroed and replaced by dedicated anchor
        embeddings so anchors carry an exact (not diffused) positional signal.
        Returns positional embeddings, optionally paired with node attributes.
        """
        rwr_emb1 = get_batch_rwr_scores(graph1, anchor_links[:, 0], self.rwr_restart_prob, device=self.device).cpu().to(self.dtype)
        rwr_emb2 = get_batch_rwr_scores(graph2, anchor_links[:, 1], self.rwr_restart_prob, device=self.device).cpu().to(self.dtype)
        anchor_emb1 = get_anchor_embeddings(graph1, anchor_links[:, 0])
        anchor_emb2 = get_anchor_embeddings(graph2, anchor_links[:, 1])
        rwr_emb1[anchor_links[:, 0], :] = 0
        rwr_emb2[anchor_links[:, 1], :] = 0
        pos_emb1 = anchor_emb1 + rwr_emb1
        pos_emb2 = anchor_emb2 + rwr_emb2
        if use_attr:
            input_emb1 = (pos_emb1, graph1.x)
            input_emb2 = (pos_emb2, graph2.x)
        else:
            input_emb1 = pos_emb1
            input_emb2 = pos_emb2
        return input_emb1, input_emb2

    @staticmethod
    def merge_two_graphs(graph1, graph2, input_emb1, input_emb2, anchor_links):
        """Merge the two graphs on their anchor nodes into one typed-edge graph.

        Returns the merged edge index, edge types, node features, and the
        per-graph mappings from original node ids to merged-graph ids
        (graph1's nodes keep their ids; graph2's are remapped by `merge_graphs`).
        """
        node_mapping1 = torch.arange(graph1.num_nodes, dtype=torch.int64)
        edge_index, edge_types, x, node_mapping2 = merge_graphs(graph1, graph2, input_emb1, input_emb2, anchor_links)
        return edge_index, edge_types, x, node_mapping1, node_mapping2

    @torch.no_grad()
    def get_cross_alignment_mat(self, emb1, emb2, weights):
        """Compute the cross-network alignment score matrix.

        Each embedding is split into two halves (consistency / disparity parts);
        the four pairwise half-similarities are blended by the learned `weights`
        and squashed through a sigmoid. Higher scores mean stronger alignment.
        """
        dim = emb1.shape[1]
        emb1_1 = emb1[:, :dim // 2]
        emb1_2 = emb1[:, dim // 2: dim]
        emb2_1 = emb2[:, :dim // 2]
        emb2_2 = emb2[:, dim // 2: dim]
        if self.dist == 'inner':
            # NOTE: torch.Tensor.dot only supports 1-D tensors; matrix products
            # are required here for the (n1 x n2) similarity matrix.
            S = weights[0] * (emb1_1 @ emb2_1.T) + weights[1] * (emb1_1 @ emb2_2.T) + \
                weights[2] * (emb1_2 @ emb2_1.T) + weights[3] * (emb1_2 @ emb2_2.T)
        else:
            S = weights[0] * torch.cdist(emb1_1, emb2_1, p=1) + \
                weights[1] * torch.cdist(emb1_1, emb2_2, p=1) + \
                weights[2] * torch.cdist(emb1_2, emb2_1, p=1) + \
                weights[3] * torch.cdist(emb1_2, emb2_2, p=1)
            S = -S  # L1 is a distance: negate so that larger means more similar
        return torch.sigmoid(S)