get_rwr_embeddings

class get_rwr_embeddings(graph: Data, anchors: Tensor, restart_prob: float = 0.15, max_iters: int = 1000, tol: float = 1e-06, connect_isolated: bool = False, dtype: torch.dtype = torch.float32)

Compute Random Walk with Restart (RWR) embeddings for all nodes in a graph relative to a given set of anchor nodes. Each node receives one embedding dimension per anchor.
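For reference, the textbook RWR iteration is written out below in matrix form. This is the standard formulation, not necessarily the exact one this function implements: the transition normalisation and the symbols are assumptions. Here A is the adjacency matrix (A_{uv} = 1 for an edge u → v), deg⁺(u) is the out-degree of u, c is the restart probability, N is the number of nodes, and e_{a_j} is the one-hot vector of anchor a_j among k anchors. Column j of R then holds the long-run visiting probabilities of a walk that keeps restarting at anchor a_j.

```latex
\begin{aligned}
P_{vu} &= \frac{A_{uv}}{\deg^{+}(u)}, \qquad
E = \bigl[\, e_{a_1} \;\cdots\; e_{a_k} \,\bigr] \in \{0,1\}^{N \times k} \\
R^{(0)} &= E, \qquad
R^{(t+1)} = (1 - c)\, P\, R^{(t)} + c\, E \\
R &= \lim_{t \to \infty} R^{(t)} = c\,\bigl(I - (1 - c)\, P\bigr)^{-1} E \in \mathbb{R}^{N \times k}
\end{aligned}
```

In practice the iteration is cut off after a fixed number of steps or once successive iterates change by less than a tolerance. A node with zero out-degree leaves its column of P all zero under the usual convention, so probability mass that reaches it is lost.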

Parameters:

graph : torch_geometric.data.Data

The input graph.

anchors : torch.Tensor

A tensor containing the indices of anchor nodes in the graph.

restart_prob : float, optional

The probability of restarting the random walk at each step. Default is 0.15.

max_iters : int, optional

The maximum number of iterations for RWR. Default is 1000.

tol : float, optional

The tolerance for convergence. Default is 1e-6.

connect_isolated : bool, optional

Whether to connect isolated nodes to all other nodes. Default is False.

dtype : torch.dtype, optional

The data type for computations. Default is torch.float32.
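The parameters above map naturally onto a dense power iteration. The sketch below is a hypothetical re-implementation in plain PyTorch, not the library's code. It takes a COO `edge_index` and `num_nodes` instead of a `Data` object, uses a dense adjacency matrix (suitable for small graphs only), measures convergence as the summed absolute change over all entries, and assumes that `connect_isolated` adds edges in both directions between each isolated node and every other node.

```python
import torch


def rwr_sketch(
    edge_index: torch.Tensor,
    num_nodes: int,
    anchors: torch.Tensor,
    restart_prob: float = 0.15,
    max_iters: int = 1000,
    tol: float = 1e-6,
    connect_isolated: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Hypothetical dense RWR sketch; returns a (num_nodes, num_anchors) tensor."""
    # Dense adjacency from a COO edge list: adj[u, v] = 1 for each edge u -> v.
    adj = torch.zeros(num_nodes, num_nodes, dtype=dtype)
    adj[edge_index[0], edge_index[1]] = 1.0

    if connect_isolated:
        # Assumed behaviour: an isolated node (no in- or out-edges) gets edges
        # to and from every other node, but no self-loop.
        isolated = (adj.sum(dim=0) + adj.sum(dim=1)) == 0
        adj[isolated, :] = 1.0
        adj[:, isolated] = 1.0
        idx = isolated.nonzero(as_tuple=True)[0]
        adj[idx, idx] = 0.0

    # Column-stochastic transition matrix: column u spreads node u's mass
    # uniformly over its out-neighbours. Zero out-degree columns stay all-zero.
    out_deg = adj.sum(dim=1)
    trans = adj.t() / out_deg.clamp(min=1.0)

    # Restart matrix: column j is the one-hot indicator of anchor j.
    num_anchors = anchors.numel()
    restart = torch.zeros(num_nodes, num_anchors, dtype=dtype)
    restart[anchors, torch.arange(num_anchors)] = 1.0

    scores = restart.clone()
    for _ in range(max_iters):
        # Follow an edge with probability 1 - restart_prob,
        # otherwise jump back to the column's anchor.
        new_scores = (1 - restart_prob) * (trans @ scores) + restart_prob * restart
        # Summed absolute change over all nodes and anchors.
        delta = (new_scores - scores).abs().sum()
        scores = new_scores
        if delta < tol:
            break
    return scores
```

On an undirected path 0-1-2 with anchor 0, the middle node ends up with a higher score (about 0.459) than the anchor itself (about 0.345), because nodes 0 and 2 can only step to node 1. Without `connect_isolated`, an isolated anchor keeps only the restart mass `restart_prob`, so its column no longer sums to 1.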

Returns:

rwr_embeddings : torch.Tensor

A tensor of shape (num_nodes, num_anchors) representing the RWR embeddings.