get_rwr_embeddings
- class get_rwr_embeddings(graph: Data, anchors: Tensor, restart_prob: float = 0.15, max_iters: int = 1000, tol: float = 1e-06, connect_isolated: bool = False, dtype: dtype = torch.float32)
Compute Random Walk with Restart (RWR) embeddings for every node in a graph with respect to a given set of anchor nodes. Each node's embedding has one entry per anchor.
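For reference, the standard RWR recurrence is shown below. It is an assumption that this function follows exactly this formulation: here A is the adjacency matrix, W is its column-normalized (column-stochastic) transition matrix, c is restart_prob, and E is the n x k restart matrix whose column j is the indicator of the j-th anchor. The iteration would stop once the change falls below tol (an L1 norm is assumed) or after max_iters steps.

```latex
W_{ij} = \frac{A_{ij}}{\deg(j)}, \qquad
E_{i j} = \begin{cases} 1 & \text{if } i = a_j \\ 0 & \text{otherwise} \end{cases}

R^{(0)} = E, \qquad
R^{(t+1)} = (1 - c)\, W R^{(t)} + c\, E

\text{stop when } \lVert R^{(t+1)} - R^{(t)} \rVert_1 < \mathrm{tol}
\text{ or } t + 1 = \mathrm{max\_iters}

R^{\ast} = c\,\bigl(I - (1 - c)\, W\bigr)^{-1} E \in \mathbb{R}^{n \times k}
```

Because 0 < c < 1 and W is column-stochastic, the spectral radius of (1 - c)W is at most 1 - c, so the iteration converges geometrically to the fixed point R*.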
Parameters:
- graph : torch_geometric.data.Data
The input graph.
- anchors : torch.Tensor
A tensor containing the indices of anchor nodes in the graph.
- restart_prob : float, optional
The probability of restarting the random walk at each step. Default is 0.15.
- max_iters : int, optional
The maximum number of iterations for RWR. Default is 1000.
- tol : float, optional
The tolerance for convergence. Default is 1e-6.
- connect_isolated : bool, optional
Whether to connect isolated nodes to all other nodes. Default is False.
- dtype : torch.dtype, optional
The data type for computations. Default is torch.float32.
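To show how these parameters could interact, here is a minimal pure-Python sketch of RWR by power iteration. It is not this library's implementation: `rwr_embeddings_sketch` is a hypothetical name, it takes a plain edge list and a list of anchor indices instead of a `Data` object and a tensor, it treats edges as undirected, it uses an L1 stopping criterion, and it omits `dtype`. Its reading of `connect_isolated` (adding undirected edges from each isolated node to every other node) is also an assumption.

```python
def rwr_embeddings_sketch(num_nodes, edges, anchors, restart_prob=0.15,
                          max_iters=1000, tol=1e-6, connect_isolated=False):
    """Hypothetical RWR sketch; returns a num_nodes x len(anchors) nested list."""
    # Assumption: edges are undirected; self-loops are ignored.
    neighbors = [set() for _ in range(num_nodes)]
    for u, v in edges:
        if u != v:
            neighbors[u].add(v)
            neighbors[v].add(u)

    # Assumption: connect_isolated links each isolated node to every other
    # node (both directions), so a walk can leave it instead of losing mass.
    if connect_isolated:
        isolated = [i for i in range(num_nodes) if not neighbors[i]]
        for i in isolated:
            for j in range(num_nodes):
                if j != i:
                    neighbors[i].add(j)
                    neighbors[j].add(i)

    k = len(anchors)
    # E[i][j] = 1.0 if node i is the j-th anchor (the restart distribution).
    E = [[1.0 if anchors[j] == i else 0.0 for j in range(k)]
         for i in range(num_nodes)]
    R = [row[:] for row in E]  # every walk starts at its own anchor

    for _ in range(max_iters):
        # Restart term: c * E.
        new = [[restart_prob * E[i][j] for j in range(k)]
               for i in range(num_nodes)]
        # Walk term: (1 - c) * W @ R; node u spreads its score evenly over
        # its neighbors, i.e. W is column-stochastic.
        for u in range(num_nodes):
            deg = len(neighbors[u])
            if deg == 0:
                continue  # isolated node: the mass it holds is dropped
            share = (1.0 - restart_prob) / deg
            for i in neighbors[u]:
                for j in range(k):
                    new[i][j] += share * R[u][j]
        # L1 change between successive iterates, over all anchors at once.
        diff = sum(abs(new[i][j] - R[i][j])
                   for i in range(num_nodes) for j in range(k))
        R = new
        if diff < tol:
            break
    return R


# Path graph 0 - 1 - 2 with a single anchor at node 0.
R = rwr_embeddings_sketch(3, [(0, 1), (1, 2)], [0])
print([round(row[0], 3) for row in R])
```

On the path 0 - 1 - 2 with anchor 0, the converged column is approximately [0.345, 0.459, 0.195] and sums to 1. The middle node outscores the anchor itself because it collects mass from both ends. In this sketch, an isolated anchor with `connect_isolated=False` keeps only its restart mass (0.15), since the walk has nowhere to go. That is one way an isolated node can distort the embeddings.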
Returns:
- rwr_embeddings : torch.Tensor
A tensor of shape (num_nodes, num_anchors) holding the RWR embeddings: row i is the embedding of node i, and column j corresponds to the j-th anchor in anchors.