Source code for PlanetAlign.algorithms.wlalign.main

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import psutil
import os

from PlanetAlign.data import Dataset
from PlanetAlign.utils import get_anchor_pairs, merge_pyg_graphs_on_anchors, pairwise_cosine_similarity
from PlanetAlign.metrics import hits_ks_scores, mrr_score
from PlanetAlign.algorithms.base_model import BaseModel

from .model import aggregate_label, EmbeddingModel, NegativeSamplingLoss
from .utils import (get_anchor_based_embeddings, get_degree_exp_distribution, get_pyg_successors, get_pyg_predecessors,
                    single_hop_subgraph, get_closest_cross_node_pairs, get_non_anchor_from_merged_graph, get_nodes_outside_subgraph)
from .data import ContextDataset


class WLAlign(BaseModel):
    """Embedding-based method WLAlign for pairwise plain network alignment.

    WLAlign is proposed by the paper "`WL-Align: Weisfeiler-Lehman Relabeling for Aligning Users Across Networks
    via Regularized Representation Learning. <https://doi.org/10.1109/TKDE.2023.3277843>`_" in TKDE 2023.

    Parameters
    ----------
    emb_dim : int, optional
        The dimension of the node embeddings. Default is 128.
    struct_lr : float, optional
        The learning rate for the structural model. Default is 5e-3.
    batch_size : int, optional
        The batch size for training. Default is 1000.
    neg_sample_size : int, optional
        The number of negative samples for training. Default is 20.
    dtype : torch.dtype, optional
        Data type of the tensors, choose from torch.float32 or torch.float64. Default is torch.float32.
    """

    def __init__(self,
                 emb_dim: int = 128,
                 struct_lr: float = 5e-3,
                 batch_size: int = 1000,
                 neg_sample_size: int = 20,
                 dtype: torch.dtype = torch.float32):
        super().__init__(dtype=dtype)
        self.emb_dim = emb_dim
        self.struct_lr = struct_lr
        self.batch_size = batch_size
        self.neg_sample_size = neg_sample_size

        # Created in ``train`` once the size of the merged graph is known.
        self.model = None
        self.criterion = None
        self.cos = None
        self.optimizer = None
        self.verbose = True

    def train(self,
              dataset: Dataset,
              gid1: int,
              gid2: int,
              use_attr: bool = False,
              total_epochs: int = 50,
              struct_epochs: int = 100,
              save_log: bool = True,
              verbose: bool = True):
        """Train WLAlign on a pair of graphs and evaluate it after every epoch.

        Each epoch builds the single-hop subgraph around the current anchor set, selects cross-graph candidate
        pairs from the aggregated label embeddings, trains the embedding model on the accumulated candidates
        (see ``train_anchor``) and finally adds the nodes of the new candidate pairs to the anchor set.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing the graphs to be aligned and the training/test data.
        gid1 : int
            The index of the first graph in the dataset to be aligned.
        gid2 : int
            The index of the second graph in the dataset to be aligned.
        use_attr : bool, optional
            Whether to use node and edge attributes for alignment. Default is False.
        total_epochs : int, optional
            The maximum number of epochs for the optimization. Default is 50.
        struct_epochs : int, optional
            The number of epochs for the structural model training. Default is 100.
        save_log : bool, optional
            Whether to save the evaluation logs. Default is True.
        verbose : bool, optional
            Whether to print the progress during training. Default is True.

        Returns
        -------
        tuple
            ``(output_emb1, output_emb2, logger)``: the L2-normalised embeddings of the nodes of the first and
            second graph from the last epoch (both ``None`` when ``total_epochs`` is 0) and the training logger.
        """
        self.check_inputs(dataset, (gid1, gid2), plain_method=True, use_attr=use_attr, pairwise=True, supervised=True)

        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'],
                                           save_log=save_log)
        process = psutil.Process(os.getpid())

        graph1, graph2 = dataset.pyg_graphs[gid1], dataset.pyg_graphs[gid2]
        anchor_links = get_anchor_pairs(dataset.train_data, gid1, gid2)
        test_pairs = get_anchor_pairs(dataset.test_data, gid1, gid2)
        self.verbose = verbose

        inf_t0 = time.time()
        merged_graph, merged_anchors, id2node, node2id = merge_pyg_graphs_on_anchors([graph1, graph2], anchor_links)

        # Subgraph induced by the ground-truth (training) anchors; it stays fixed for the whole run.
        gnd_induced_subset, gnd_induced_edge_index, gnd_induced_mapping = single_hop_subgraph(
            merged_anchors, edge_index=merged_graph.edge_index, relabel_nodes=True)
        gnd_induced_subgraph = Data(edge_index=gnd_induced_edge_index,
                                    num_nodes=len(gnd_induced_subset),
                                    mapping=gnd_induced_subset.numpy(),
                                    anchors=gnd_induced_mapping)
        gnd_induced_subgraph.anchor_successors = {
            anchor: get_pyg_successors(gnd_induced_subgraph, anchor).numpy()
            for anchor in gnd_induced_subgraph.anchors.numpy()
        }
        gnd_induced_subgraph.anchor_predecessors = {
            anchor: get_pyg_predecessors(gnd_induced_subgraph, anchor).numpy()
            for anchor in gnd_induced_subgraph.anchors.numpy()
        }
        gnd_sub_id2node = {i: id2node[nid] for i, nid in enumerate(gnd_induced_subset.numpy())}
        # NOTE(review): these helpers receive gid1/gid2, while the embedding lookup below uses the literal
        # graph keys 0/1 in node2id - confirm the two agree when (gid1, gid2) != (0, 1).
        gnd_subgraph_nodes1 = get_non_anchor_from_merged_graph(gnd_induced_subgraph, gnd_sub_id2node, gid1)
        gnd_subgraph_nodes2 = get_non_anchor_from_merged_graph(gnd_induced_subgraph, gnd_sub_id2node, gid2)
        current_anchors = torch.clone(merged_anchors)
        infer_time = time.time() - inf_t0

        # Model initialization
        self.model = EmbeddingModel(merged_graph.num_nodes, self.emb_dim).to(self.dtype).to(self.device)
        self.criterion = NegativeSamplingLoss()
        self.cos = nn.CosineEmbeddingLoss(margin=0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.struct_lr, betas=(0.9, 0.999), eps=1e-8)

        # Rows of the merged embedding table that belong to each graph. They do not change between epochs,
        # so they are built once instead of once per evaluation.
        graph1_rows = [node2id[(0, node)] for node in range(graph1.num_nodes)]
        graph2_rows = [node2id[(1, node)] for node in range(graph2.num_nodes)]

        all_candidate_pairs = np.empty((0, 2), dtype=np.int64)
        output_emb1, output_emb2 = None, None
        for epoch in range(total_epochs):
            t0 = time.time()
            subset, edge_index, mapping = single_hop_subgraph(current_anchors, edge_index=merged_graph.edge_index,
                                                              relabel_nodes=True)
            subset_np = subset.numpy()
            anchor_induced_subgraph = Data(edge_index=edge_index, num_nodes=len(subset), mapping=subset_np,
                                           anchors=mapping)
            if verbose:
                print('Number of nodes in the anchor-induced subgraph:', anchor_induced_subgraph.num_nodes)

            sub_id2node = {i: id2node[nid] for i, nid in enumerate(subset_np)}
            # Merged-graph id -> subgraph id, with -1 for nodes outside the current subgraph.
            merged2sub = np.full(merged_graph.num_nodes, -1, dtype=np.int64)
            merged2sub[subset_np] = np.arange(len(subset_np))

            subgraph_nodes1 = get_non_anchor_from_merged_graph(anchor_induced_subgraph, sub_id2node, gid1)
            subgraph_nodes2 = get_non_anchor_from_merged_graph(anchor_induced_subgraph, sub_id2node, gid2)

            onehot_embs = get_anchor_based_embeddings(anchor_induced_subgraph).to(self.dtype).to(self.device)
            layer_embs_list = aggregate_label(onehot_embs.weight.data, anchor_induced_subgraph, num_layers=1,
                                              device=self.device)
            layer_emb = F.normalize(layer_embs_list[0], p=2, dim=1)

            cross_candidate_pairs = get_closest_cross_node_pairs(layer_emb, subgraph_nodes1, subgraph_nodes2,
                                                                 device=self.device)
            cross_candidate_pairs_by_gnd = get_closest_cross_node_pairs(
                layer_emb,
                merged2sub[gnd_induced_subgraph.mapping[gnd_subgraph_nodes1]],
                merged2sub[gnd_induced_subgraph.mapping[gnd_subgraph_nodes2]],
                device=self.device)
            candidate_pairs = np.vstack([cross_candidate_pairs_by_gnd, cross_candidate_pairs])
            # mapped node idx from the anchor induced subgraph to the original merged graph
            candidate_pairs_mapped = anchor_induced_subgraph.mapping[candidate_pairs]
            all_candidate_pairs = np.unique(np.vstack([all_candidate_pairs, candidate_pairs_mapped]), axis=0)

            # Train the anchor embeddings
            output_embeddings, loss = self.train_anchor(merged_graph, gnd_induced_subgraph, all_candidate_pairs,
                                                        struct_epochs, curr_epoch=epoch)

            # Nodes of this epoch's candidate pairs act as anchors from the next epoch on.
            new_anchors = torch.unique(torch.from_numpy(candidate_pairs_mapped.flatten()))
            current_anchors = torch.unique(torch.cat([current_anchors, new_anchors]))

            t1 = time.time()
            infer_time += t1 - t0

            # Evaluate the model
            with torch.no_grad():
                output_embeddings = F.normalize(output_embeddings, p=2, dim=1)
                output_emb1 = output_embeddings[graph1_rows]
                output_emb2 = output_embeddings[graph2_rows]
                S = pairwise_cosine_similarity(output_emb1, output_emb2)
                hits = hits_ks_scores(S, test_pairs, mode='mean')
                mrr = mrr_score(S, test_pairs, mode='mean')
                mem_gb = process.memory_info().rss / 1024 ** 3
                logger.log(epoch=epoch+1,
                           loss=loss,
                           epoch_time=t1-t0,
                           hits=hits,
                           mrr=mrr,
                           memory=round(mem_gb, 4),
                           infer_time=round(infer_time, 4),
                           verbose=verbose)

        return output_emb1, output_emb2, logger

    def train_anchor(self, merged_graph, gnd_subgraph, candidate_pairs, struct_epochs, curr_epoch):
        """Run ``struct_epochs`` passes of embedding training and return the current embeddings.

        Every pass samples edges from ``gnd_subgraph`` (see ``_get_training_samples``), optimises the negative
        sampling criterion on them and, when candidate pairs exist, adds a cosine embedding loss that pulls
        the two nodes of each candidate pair together.

        Parameters
        ----------
        merged_graph
            The merged graph; its ``edge_index`` feeds the first noise distribution.
        gnd_subgraph
            The subgraph induced by the ground-truth anchors, with a ``mapping`` back to merged-graph ids.
        candidate_pairs : numpy.ndarray
            Array with two columns of merged-graph node ids that should be embedded close to each other.
        struct_epochs : int
            Number of training passes.
        curr_epoch : int
            Index of the outer epoch, used for progress messages only.

        Returns
        -------
        tuple
            ``(embeddings, loss)``: ``self.model.self_embed.weight.data`` and the summed batch loss averaged
            over the passes (0 when ``struct_epochs`` is 0).
        """
        num_candidates = candidate_pairs.shape[0]
        candidate_nodes1 = torch.from_numpy(candidate_pairs[:, 0]).to(self.device)
        candidate_nodes2 = torch.from_numpy(candidate_pairs[:, 1]).to(self.device)
        # Target of the cosine loss (every candidate pair is a positive pair); built once, not per batch.
        candidate_targets = torch.ones(num_candidates, dtype=self.dtype, device=self.device)

        node_noise = get_nodes_outside_subgraph(gnd_subgraph, merged_graph)
        nodes_dist = torch.from_numpy(get_degree_exp_distribution(merged_graph.edge_index.flatten())).to(self.device)
        noise_dist = torch.from_numpy(get_degree_exp_distribution(node_noise)).to(self.device)
        if noise_dist.shape[0] == 0:
            # No distribution could be built from the outside nodes: fall back to the merged-graph one.
            noise_dist = nodes_dist

        t0 = time.time()
        overall_loss = 0
        for epoch in range(struct_epochs):
            total_loss = 0
            sampled_src, sampled_tgt = self._get_training_samples(gnd_subgraph)
            sampled_src_mapped, sampled_tgt_mapped = gnd_subgraph.mapping[sampled_src], gnd_subgraph.mapping[sampled_tgt]
            sampled_dataset = ContextDataset(sampled_src_mapped, sampled_tgt_mapped)
            sampled_loader = torch.utils.data.DataLoader(sampled_dataset, batch_size=self.batch_size, shuffle=True)

            for batch_src, batch_tgt in sampled_loader:
                curr_batch_size = batch_src.shape[0]
                batch_src = batch_src.to(self.device)
                batch_tgt = batch_tgt.to(self.device)

                self.optimizer.zero_grad()

                target_input_vecs = self.model.forward_input(batch_tgt)
                source_output_vecs = self.model.forward_output(batch_src)
                self_input_vecs = self.model.forward_self(batch_src)
                self_output_vecs = self.model.forward_self(batch_tgt)
                self_left = self.model.forward_self(candidate_nodes1)
                self_right = self.model.forward_self(candidate_nodes2)

                noise_vecs_self1, noise_vecs_input1, noise_vecs_output1 = self.model.forward_noise(
                    curr_batch_size, self.neg_sample_size, nodes_dist)
                noise_vecs_self2, noise_vecs_input2, noise_vecs_output2 = self.model.forward_noise(
                    curr_batch_size, self.neg_sample_size, noise_dist)

                loss = self.criterion(self_input_vecs, self_output_vecs, target_input_vecs, source_output_vecs,
                                      noise_vecs_self1, noise_vecs_input1, noise_vecs_output1,
                                      noise_vecs_self2, noise_vecs_input2, noise_vecs_output2)
                if num_candidates > 0:
                    # The mean-reduced cosine loss over zero pairs is undefined (NaN) and would poison
                    # every parameter on the next optimizer step, so it is only added when pairs exist.
                    loss += self.cos(self_left, self_right, candidate_targets)
                total_loss += loss.item()

                loss.backward()
                self.optimizer.step()

            if epoch % 10 == 9 and self.verbose:
                print(f'Epoch {curr_epoch + 1}, Struct epoch {epoch + 1}/{struct_epochs}, Loss: {total_loss:.4f}')
            overall_loss += total_loss

        if struct_epochs > 0:
            overall_loss /= struct_epochs
        if self.verbose:
            print(f'Epoch {curr_epoch + 1}, Time: {time.time() - t0:.2f}s')

        return self.model.self_embed.weight.data, overall_loss

    def _get_training_samples(self, gnd_subgraph, num_batches=1):
        """Sample training edges from ``gnd_subgraph``.

        ``num_batches * self.batch_size`` edges are drawn uniformly with replacement. The result is the
        concatenation of two groups:

        1. every sampled edge, where a source that is an anchor is replaced by a random entry of that anchor's
           ``anchor_successors`` and a target that is an anchor by a random entry of its ``anchor_predecessors``;
        2. the unmodified sampled edges that have at least one anchor endpoint.

        Anchor membership is computed once per endpoint array with ``np.isin``. Compared with testing every
        node against the anchor tensor individually, this changes only how many random draws are consumed,
        not the sampling distribution.

        Parameters
        ----------
        gnd_subgraph
            Graph providing ``edge_index``, ``num_edges``, ``anchors``, ``anchor_successors`` and
            ``anchor_predecessors``.
        num_batches : int, optional
            Number of batches worth of edges to draw. Default is 1.

        Returns
        -------
        tuple of torch.Tensor
            Source and target node ids (subgraph indexing) as two ``int64`` tensors of equal length.
        """
        edge_index = gnd_subgraph.edge_index
        num_edges = gnd_subgraph.num_edges

        # Sample edges
        batch_indices = [np.random.choice(num_edges, size=self.batch_size, replace=True)
                         for _ in range(num_batches)]
        all_sampled_indices = np.concatenate(batch_indices)
        sample_indices = torch.from_numpy(all_sampled_indices).to(edge_index.device)
        sampled_edges = edge_index[:, sample_indices]  # shape: [2, num_batches * batch_size]

        # Convert the sampled edges to numpy arrays for vectorized membership tests.
        sampled_edges = sampled_edges.cpu().numpy()
        src, tgt = sampled_edges[0], sampled_edges[1]

        anchors = gnd_subgraph.anchors
        anchor_ids = anchors.cpu().numpy() if torch.is_tensor(anchors) else np.asarray(anchors)
        src_is_anchor = np.isin(src, anchor_ids)
        tgt_is_anchor = np.isin(tgt, anchor_ids)

        # Replace anchor nodes with their successors/predecessors (on copies: the originals are reused below).
        anchor_successors = gnd_subgraph.anchor_successors
        replaced_src = src.copy()
        for pos in np.flatnonzero(src_is_anchor):
            successors = anchor_successors[src[pos]]
            replaced_src[pos] = successors[np.random.choice(len(successors))]

        anchor_predecessors = gnd_subgraph.anchor_predecessors
        replaced_tgt = tgt.copy()
        for pos in np.flatnonzero(tgt_is_anchor):
            predecessors = anchor_predecessors[tgt[pos]]
            replaced_tgt[pos] = predecessors[np.random.choice(len(predecessors))]

        # Keep the original edges for which at least one of the nodes is an anchor
        touches_anchor = src_is_anchor | tgt_is_anchor
        filtered_src = src[touches_anchor]
        filtered_tgt = tgt[touches_anchor]

        # Concatenate replaced and filtered edges
        sampled_src = np.concatenate([replaced_src, filtered_src])
        sampled_tgt = np.concatenate([replaced_tgt, filtered_tgt])

        assert len(sampled_src) == len(sampled_tgt), 'The number of source and target nodes should be equal.'

        return torch.from_numpy(sampled_src).to(torch.int64), torch.from_numpy(sampled_tgt).to(torch.int64)