import time
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import degree
import psutil
import os
from PlanetAlign.data import Dataset
from PlanetAlign.metrics import hits_ks_scores, mrr_score
from PlanetAlign.utils import get_anchor_pairs
from .base_model import BaseModel
# [docs]  -- stray Sphinx "view source" artifact from the documentation scrape; not executable code
class IONE(BaseModel):
    """Embedding-based method IONE for pairwise plain network alignment.

    IONE is proposed by the paper "`Aligning Users Across Social Networks Using
    Network Embedding <https://www.ijcai.org/Proceedings/16/Papers/254.pdf>`_"
    in IJCAI 2016.

    Parameters
    ----------
    out_dim: int, optional
        The output dimension of the embeddings. Default is 100.
    dtype: torch.dtype, optional
        Data type of the tensors, choose from torch.float32 or torch.float64. Default is torch.float32.
    """

    def __init__(self,
                 out_dim: int = 100,
                 dtype: torch.dtype = torch.float32):
        super().__init__(dtype=dtype)
        assert out_dim > 0, 'Output dimension must be a positive integer'
        self.out_dim = out_dim
        # Number of sampled stochastic updates performed per outer epoch.
        self.base_epochs = 100000

    def train(self,
              dataset: Dataset,
              gid1: int,
              gid2: int,
              use_attr: bool = False,
              total_epochs: int = 100,
              save_log: bool = True,
              verbose: bool = True):
        """Train IONE on two graphs and return the node-similarity matrix.

        Parameters
        ----------
        dataset : Dataset
            The dataset containing graphs to be aligned and the training/test data.
        gid1 : int
            The graph id of the first graph to be aligned.
        gid2 : int
            The graph id of the second graph to be aligned.
        use_attr : bool, optional
            Flag for using attributes. **Must be False for IONE**. Default is False.
        total_epochs : int, optional
            Maximum number of training epochs. Default is 100.
        save_log : bool, optional
            Flag for saving the log. Default is True.
        verbose : bool, optional
            Flag for printing the log. Default is True.

        Returns
        -------
        tuple
            ``(S, logger)`` where ``S`` is the ``num_nodes(g1) x num_nodes(g2)``
            similarity matrix of normalized embeddings, and ``logger`` is the
            training logger with per-epoch metrics.
        """
        assert total_epochs > 0, 'Total epochs must be a positive integer'
        self.check_inputs(dataset, (gid1, gid2), plain_method=True, use_attr=use_attr, pairwise=True, supervised=True)
        logger = self.init_training_logger(dataset, use_attr, additional_headers=['memory', 'infer_time'], save_log=save_log)
        process = psutil.Process(os.getpid())

        graph1, graph2 = dataset.pyg_graphs[gid1], dataset.pyg_graphs[gid2]
        anchor_links = get_anchor_pairs(dataset.train_data, gid1, gid2)
        test_pairs = get_anchor_pairs(dataset.test_data, gid1, gid2)

        # One embedding/updater state per network.
        two_order_x = IONEUpdate(graph1, self.out_dim, self.dtype).to(self.device)
        two_order_y = IONEUpdate(graph2, self.out_dim, self.dtype).to(self.device)
        # Anchor maps in both directions: graph1 node -> graph2 node, and the reverse.
        anchor_map1 = {anchor[0].item(): anchor[1].item() for anchor in anchor_links}
        anchor_map2 = {anchor[1].item(): anchor[0].item() for anchor in anchor_links}

        infer_time = 0
        S = torch.zeros(graph1.num_nodes, graph2.num_nodes, dtype=self.dtype).to(self.device)
        for epoch in range(total_epochs):
            t0 = time.time()
            for _ in range(self.base_epochs):
                # Update graph1's embeddings within its own network.
                two_order_x(i=epoch,
                            iter_count=total_epochs,
                            two_order_embeddings=two_order_x.embeddings,
                            two_order_emb_context_input=two_order_x.emb_context_input,
                            two_order_emb_context_output=two_order_x.emb_context_output,
                            anchors=anchor_map1,
                            same_network=True)
                # Update graph2's embeddings, writing anchor-node gradients into
                # graph1's tensors so anchored users share representations.
                two_order_y(i=epoch,
                            iter_count=total_epochs,
                            two_order_embeddings=two_order_x.embeddings,
                            two_order_emb_context_input=two_order_x.emb_context_input,
                            two_order_emb_context_output=two_order_x.emb_context_output,
                            anchors=anchor_map2,
                            same_network=False)
            t1 = time.time()
            infer_time += t1 - t0

            # Pseudo-loss: Frobenius norm of the change in the similarity matrix.
            S_old = S.clone()
            emb_x = F.normalize(two_order_x.embeddings, p=2, dim=1)
            emb_y = F.normalize(two_order_y.embeddings, p=2, dim=1)
            S = emb_x @ emb_y.T
            diff = torch.norm(S - S_old)

            hits = hits_ks_scores(S, test_pairs, mode='mean')
            mrr = mrr_score(S, test_pairs, mode='mean')
            mem_gb = process.memory_info().rss / 1024 ** 3
            logger.log(epoch=epoch+1,
                       loss=diff.item(),
                       epoch_time=t1-t0,
                       mrr=mrr,
                       hits=hits,
                       memory=round(mem_gb, 4),
                       infer_time=round(infer_time, 4),
                       verbose=verbose)
        return S, logger
class IONEUpdate:
    """Per-network embedding state and stochastic update rules for IONE.

    Holds the node embeddings and two context-vector tables (one per edge
    direction) for a single graph, samples edges in O(1) via the alias method,
    and draws negatives from a degree^0.75 table. When a node is an anchor and
    the update concerns the counterpart network (``same_network=False``), the
    gradient is written into the counterpart's tensors instead, so aligned
    users share representations.
    """

    def __init__(self,
                 graph: 'Data',
                 out_dim: int,
                 dtype: torch.dtype = torch.float32):
        assert dtype in [torch.float32, torch.float64], 'Invalid floating point dtype'
        self.dtype = dtype
        self.device = 'cpu'
        self.graph = graph
        self.dimension = out_dim
        # Node embeddings, uniformly initialized in [-0.5/dim, 0.5/dim).
        self.embeddings = torch.empty((self.graph.num_nodes, self.dimension), dtype=self.dtype).uniform_(
            -0.5 / self.dimension, 0.5 / self.dimension)
        # Context vectors for the input/output edge directions, zero-initialized.
        self.emb_context_input = torch.zeros(self.graph.num_nodes, self.dimension, dtype=self.dtype)
        self.emb_context_output = torch.zeros(self.graph.num_nodes, self.dimension, dtype=self.dtype)
        # Total (in + out) degree per node; drives the negative-sampling table.
        self.vertex = (degree(self.graph.edge_index[0], num_nodes=self.graph.num_nodes) +
                       degree(self.graph.edge_index[1], num_nodes=self.graph.num_nodes))
        self.init_rho = 0.025      # initial SGD learning rate
        self.rho = 0               # current learning rate; set on the first forward() call
        self.num_negative = 5      # negative samples drawn per positive edge
        self.neg_table_size = 10000000
        self.edge_weight = []
        # Alias-method tables over edges (filled by init_alias_table).
        self.prob = torch.zeros(self.graph.num_edges, dtype=self.dtype)
        self.alias = torch.zeros(self.graph.num_edges, dtype=torch.int64)
        self.neg_table = torch.zeros(self.neg_table_size, dtype=torch.int64)
        # Initialize tables
        start = time.time()
        self.init_alias_table()
        print(f'{self.graph.name}: alias table initialized in {time.time() - start:.2f} seconds')
        start = time.time()
        self.init_neg_table()
        print(f'{self.graph.name}: negative table initialized in {time.time() - start:.2f} seconds')

    def forward(self, i, iter_count, two_order_embeddings, two_order_emb_context_input, two_order_emb_context_output,
                anchors, same_network=True):
        """Perform one stochastic update step.

        Samples one edge, then applies one positive and ``num_negative``
        negative gradient steps in both edge directions, finally applying the
        accumulated source-node gradient.

        ``i``/``iter_count`` drive the learning-rate schedule; the
        ``two_order_*`` tensors belong to the counterpart network and
        ``anchors`` maps this network's node ids to the counterpart's.
        """
        vec_error = torch.zeros(self.dimension, dtype=self.dtype).to(self.device)
        # Linearly decay the learning rate, refreshed every ~10% of the run.
        # max(..., 1) guards against a modulo-by-zero when iter_count < 10
        # (the original int(iter_count / 10) divisor could be 0).
        if i % max(int(iter_count / 10), 1) == 0:
            self.rho = self.init_rho * (1.0 - i / iter_count)
            if self.rho < self.init_rho * 0.0001:
                self.rho = self.init_rho * 0.0001
        edge_id = self.sample_edge(torch.rand(1).item(), torch.rand(1).item())
        uid_1, uid_2 = self.graph.edge_index[:, edge_id].cpu()
        uid_1, uid_2 = uid_1.item(), uid_2.item()
        d = 0
        while d < self.num_negative + 1:
            if d == 0:
                # First pass: the sampled edge's endpoint is the positive target.
                label = 1
                target = uid_2
            else:
                # Negative passes: draw from the degree^0.75 table, rejecting
                # the edge endpoints themselves (retry without advancing d).
                neg_index = torch.randint(0, self.neg_table_size, (1,)).item()
                target = self.neg_table[neg_index].cpu().item()
                assert not isinstance(target, torch.Tensor), 'Target should not be a tensor'
                if target == uid_1 or target == uid_2:
                    continue
                label = 0
            vec_error += self.update(vec_u=self.embeddings[uid_1],
                                     vec_v=self.emb_context_input[target],
                                     label=label,
                                     source=uid_1,
                                     target=target,
                                     two_order_embeddings=two_order_embeddings,
                                     two_order_emb_context=two_order_emb_context_input,
                                     anchors=anchors,
                                     same_network=same_network)
            self.update_reverse(vec_u=self.embeddings[target],
                                vec_v=self.emb_context_output[uid_1],
                                label=label,
                                source=target,
                                target=uid_1,
                                two_order_embeddings=two_order_embeddings,
                                two_order_emb_context=two_order_emb_context_output,
                                anchors=anchors,
                                same_network=same_network)
            d = d + 1
        # Apply the accumulated gradient for the source node, redirecting it to
        # the counterpart network's embedding when uid_1 is a cross-network anchor.
        if uid_1 in anchors:
            vec_u = two_order_embeddings[anchors[uid_1]] if not same_network else None
            if vec_u is None:
                self.embeddings[uid_1] += vec_error
            else:
                two_order_embeddings[anchors[uid_1]] += vec_error
        else:
            self.embeddings[uid_1] += vec_error

    def init_alias_table(self):
        """Build the alias-method tables (``prob``/``alias``) over all edges.

        Edge weights are all 1, so sampling is currently uniform; the alias
        structure keeps draws O(1) even if weights become non-uniform.
        """
        self.edge_weight = torch.ones(self.graph.num_edges, dtype=self.dtype)
        norm_prob = F.normalize(self.edge_weight, p=1, dim=0) * self.graph.num_edges
        small_block = torch.flip(torch.argwhere(norm_prob < 1).flatten(), dims=[0])
        large_block = torch.flip(torch.argwhere(norm_prob >= 1).flatten(), dims=[0])
        num_small_block = len(small_block)
        num_large_block = len(large_block)
        # Pair each under-full bucket with an over-full one (Vose's construction).
        while num_small_block > 0 and num_large_block > 0:
            num_small_block = num_small_block - 1
            cur_small_block = small_block[num_small_block]
            num_large_block = num_large_block - 1
            cur_large_block = large_block[num_large_block]
            self.prob[cur_small_block] = norm_prob[cur_small_block]
            self.alias[cur_small_block] = cur_large_block
            norm_prob[cur_large_block] = norm_prob[cur_large_block] + norm_prob[cur_small_block] - 1
            if norm_prob[cur_large_block] < 1:
                small_block[num_small_block] = cur_large_block
                num_small_block = num_small_block + 1
            else:
                large_block[num_large_block] = cur_large_block
                num_large_block = num_large_block + 1
        # Any leftover buckets are exactly full: probability 1, no alias needed.
        while num_large_block > 0:
            num_large_block = num_large_block - 1
            self.prob[large_block[num_large_block]] = 1
        while num_small_block > 0:
            num_small_block = num_small_block - 1
            self.prob[small_block[num_small_block]] = 1

    def sample_edge(self, rand1: float, rand2: float) -> int:
        """Draw an edge id via the alias method from two uniform samples in [0, 1)."""
        k = int(len(self.edge_weight) * rand1)
        # Cast the alias entry to a plain int so both branches honor the
        # annotated return type (self.alias[k] is a 0-dim tensor).
        return k if rand2 < self.prob[k] else int(self.alias[k])

    def init_neg_table(self):
        """Fill ``neg_table`` so each node appears proportionally to degree^0.75.

        Nodes are visited in a random permutation; each occupies a run of table
        slots whose length tracks its share of the total degree^0.75 mass.
        """
        total_sum = torch.sum(self.vertex ** 0.75).cpu().item()
        cumulative_sum = 0
        por = 0
        perm_node_list = torch.randperm(self.graph.num_nodes).numpy().tolist()
        list_iter = iter(perm_node_list)
        current = next(list_iter)
        vertex = self.vertex.cpu().numpy().tolist()
        self.neg_table = []
        for i in range(self.neg_table_size):
            # Advance to the next node once this slot's fraction passes the
            # cumulative probability mass consumed so far.
            if (i + 1) / self.neg_table_size > por:
                cumulative_sum += vertex[current] ** 0.75
                por = cumulative_sum / total_sum
                if por >= 1:
                    # Floating-point tail: pad the rest with the last node.
                    self.neg_table.append(current)
                    continue
                if i != 0:
                    current = next(list_iter)
            self.neg_table.append(current)
        self.neg_table = torch.tensor(self.neg_table, dtype=torch.int64).to(self.vertex.device)

    def update(self, vec_u, vec_v, label, source, target, two_order_embeddings, two_order_emb_context,
               anchors, same_network=True):
        """One sigmoid-gradient step on (source, target); returns the source-side
        error to be accumulated by the caller.

        For anchor nodes the operand rows are swapped for the shared tensors, so
        the in-place ``+=`` below mutates the tensor the row view was taken from.
        """
        if source in anchors:
            vec_u = two_order_embeddings[anchors[source]] if not same_network else two_order_embeddings[source]
        if target in anchors:
            vec_v = two_order_emb_context[anchors[target]] if not same_network else two_order_emb_context[target]
        x = vec_u @ vec_v
        g = (label - torch.sigmoid(x)) * self.rho
        vec_error = g * vec_v
        if target in anchors:
            if same_network:
                vec_v += g * vec_u
            else:
                two_order_emb_context[anchors[target]] += g * vec_u
        else:
            vec_v += g * vec_u
        return vec_error

    def update_reverse(self, vec_u, vec_v, label, source, target, two_order_embeddings, two_order_emb_context,
                       anchors, same_network=True):
        """Mirror of :meth:`update` for the reversed edge direction; unlike
        ``update`` it applies the source-side error immediately instead of
        returning it."""
        if source in anchors:
            vec_u = two_order_embeddings[anchors[source]] if not same_network else two_order_embeddings[source]
        if target in anchors:
            vec_v = two_order_emb_context[anchors[target]] if not same_network else two_order_emb_context[target]
        x = vec_u @ vec_v
        g = (label - torch.sigmoid(x)) * self.rho
        vec_error = g * vec_v
        if target in anchors:
            if same_network:
                vec_v += g * vec_u
            else:
                two_order_emb_context[anchors[target]] += g * vec_u
        else:
            vec_v += g * vec_u
        # Apply the source-side gradient in place, redirected for cross-network anchors.
        uid_1 = source
        if uid_1 in anchors:
            if same_network:
                self.embeddings[uid_1] += vec_error
            else:
                two_order_embeddings[anchors[uid_1]] += vec_error
        else:
            self.embeddings[uid_1] += vec_error

    def to(self, device):
        """Move all tensor state to ``device`` and return self (fluent, mirrors nn.Module.to)."""
        assert device in ['cpu', 'cuda'] or isinstance(device, torch.device), 'Invalid device'
        self.device = device
        self.embeddings = self.embeddings.to(device)
        self.emb_context_input = self.emb_context_input.to(device)
        self.emb_context_output = self.emb_context_output.to(device)
        self.vertex = self.vertex.to(device)
        self.prob = self.prob.to(device)
        self.alias = self.alias.to(device)
        self.neg_table = self.neg_table.to(device)
        return self

    def __call__(self, *args, **kwargs):
        # Allow the instance to be invoked like a function, delegating to forward().
        return self.forward(*args, **kwargs)