"""Mean Reciprocal Rank (MRR) metrics: source of ``PlanetAlign.metrics.mrr``."""

import torch


def mrr_score(similarity: torch.Tensor, test_pairs: torch.Tensor, mode: str = 'mean') -> float:
    r"""Mean Reciprocal Rank (MRR) score of pairwise alignment results.

    Parameters
    ----------
    similarity : torch.Tensor
        Similarity matrix of shape (n1, n2) where n1 and n2 are the number
        of nodes in graph1 and graph2.
    test_pairs : torch.Tensor
        Test pairs of shape (m, 2) where m is the number of test pairs.
    mode : str, optional
        How the two alignment directions are scored/combined. Options are
        'mean', 'max', 'ltr' (left-to-right), 'rtl' (right-to-left).
        Default is 'mean'.

    Returns
    -------
    float
        The MRR score computed by the scorer selected with ``mode``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported options.
    """
    # Reject unknown modes before touching the scorer table.
    if mode not in ('mean', 'max', 'ltr', 'rtl'):
        raise ValueError(f"Invalid mode: {mode}")

    scorers = {
        'mean': mrr_mean_score,
        'max': mrr_max_score,
        'ltr': mrr_ltr_score,
        'rtl': mrr_rtl_score,
    }
    return scorers[mode](similarity, test_pairs)
def _check_index_range(indices, size, label):
    """Raise IndexError unless every entry of ``indices`` lies in ``[0, size)``."""
    lo = int(indices.min())
    hi = int(indices.max())
    if lo < 0 or hi >= size:
        raise IndexError(
            f"{label} index out of range: expected values in [0, {size}), "
            f"got min {lo} and max {hi}"
        )


def _directional_mrr(similarity, sources, targets):
    r"""Mean reciprocal rank of ``targets`` when ranking rows ``similarity[sources]``.

    Parameters
    ----------
    similarity : torch.Tensor
        2-D score matrix; each row is a query, each column a candidate.
    sources : torch.Tensor
        1-D integer tensor of query (row) indices.
    targets : torch.Tensor
        1-D integer tensor of ground-truth candidate (column) indices, one
        per entry of ``sources``.

    Returns
    -------
    float
        Mean of ``1 / rank`` where rank 1 is the highest-scoring column.
        ``nan`` when ``sources`` is empty (mean over an empty set).

    Raises
    ------
    IndexError
        If a source index is outside ``[0, similarity.shape[0])`` or a
        target index is outside ``[0, similarity.shape[1])``.
    """
    if sources.numel() == 0:
        # Same result torch.mean gives for an empty tensor, without the work.
        return float('nan')

    n_rows, n_cols = similarity.shape
    # Without these checks a negative source silently wraps to another row,
    # and an unmatched target is silently dropped from the mean (inflating it).
    _check_index_range(sources, n_rows, 'source')
    _check_index_range(targets, n_cols, 'target')

    # Ascending argsort of negated scores = candidates ordered best-first.
    # NOTE: order among exactly equal scores is whatever torch.argsort yields.
    order = torch.argsort(-similarity[sources], dim=1)
    # Each row of ``order`` is a permutation of the column indices, so with
    # in-range targets every row has exactly one hit; its column is rank - 1.
    hit = order == targets.reshape(-1, 1)
    ranks = torch.where(hit)[1].float() + 1
    return torch.mean(1 / ranks).item()


def mrr_ltr_score(similarity, test_pairs):
    r"""Mean Reciprocal Rank (MRR) score of graph1(left) to graph2(right) alignment.

    Each graph1 node ``test_pairs[i, 0]`` ranks all graph2 nodes by its row of
    ``similarity``; the rank of ``test_pairs[i, 1]`` is scored as ``1 / rank``.

    Parameters
    ----------
    similarity : torch.Tensor
        Similarity matrix of shape (n1, n2).
    test_pairs : torch.Tensor
        Integer test pairs of shape (m, 2): (graph1 node, graph2 node).

    Returns
    -------
    float
        Mean reciprocal rank; ``nan`` if ``test_pairs`` is empty.

    Raises
    ------
    IndexError
        If a pair references a node outside the similarity matrix.
    """
    test_pairs = test_pairs.to(similarity.device)
    return _directional_mrr(similarity, test_pairs[:, 0], test_pairs[:, 1])


def mrr_rtl_score(similarity, test_pairs):
    r"""Mean Reciprocal Rank (MRR) score of graph2(right) to graph1(left) alignment.

    Each graph2 node ``test_pairs[i, 1]`` ranks all graph1 nodes by its column
    of ``similarity``; the rank of ``test_pairs[i, 0]`` is scored as ``1 / rank``.

    Parameters
    ----------
    similarity : torch.Tensor
        Similarity matrix of shape (n1, n2).
    test_pairs : torch.Tensor
        Integer test pairs of shape (m, 2): (graph1 node, graph2 node).

    Returns
    -------
    float
        Mean reciprocal rank; ``nan`` if ``test_pairs`` is empty.

    Raises
    ------
    IndexError
        If a pair references a node outside the similarity matrix.
    """
    test_pairs = test_pairs.to(similarity.device)
    return _directional_mrr(similarity.T, test_pairs[:, 1], test_pairs[:, 0])


def mrr_max_score(similarity, test_pairs):
    r"""Max Mean Reciprocal Rank (MRR) score of left-to-right and right-to-left alignments.

    Returns
    -------
    float
        ``max(mrr_ltr_score(...), mrr_rtl_score(...))``.
    """
    mrr_ltr = mrr_ltr_score(similarity, test_pairs)
    mrr_rtl = mrr_rtl_score(similarity, test_pairs)
    return max(mrr_ltr, mrr_rtl)


def mrr_mean_score(similarity, test_pairs):
    r"""Mean Mean Reciprocal Rank (MRR) score of left-to-right and right-to-left alignments.

    Returns
    -------
    float
        Arithmetic mean of ``mrr_ltr_score(...)`` and ``mrr_rtl_score(...)``.
    """
    mrr_ltr = mrr_ltr_score(similarity, test_pairs)
    mrr_rtl = mrr_rtl_score(similarity, test_pairs)
    return (mrr_ltr + mrr_rtl) / 2