Skip to content

Commit cc4696b

Browse files
authored
Allow returning edge indices from random walk (#139)
This commit adds an optional argument in the `random_walk` function, namely `return_edge_indices`. The default behaviour is not changed, but if a user wants to directly use the edges visited by the random walker, we can return the indices of those edges by setting `return_edge_indices` to `True`. New cases are also added to the test suite.
1 parent c77ed13 commit cc4696b

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

test/test_rw.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@pytest.mark.parametrize('device', devices)
9-
def test_rw(device):
9+
def test_rw_large(device):
1010
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
1111
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
1212
start = tensor([0, 1, 2, 3, 4], torch.long, device)
@@ -21,10 +21,59 @@ def test_rw(device):
2121
assert out[n, i].item() in col[row == cur].tolist()
2222
cur = out[n, i].item()
2323

24+
25+
@pytest.mark.parametrize('device', devices)
26+
def test_rw_small(device):
2427
row = tensor([0, 1], torch.long, device)
2528
col = tensor([1, 0], torch.long, device)
2629
start = tensor([0, 1, 2], torch.long, device)
2730
walk_length = 4
2831

2932
out = random_walk(row, col, start, walk_length, num_nodes=3)
3033
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
34+
35+
36+
@pytest.mark.parametrize('device', devices)
37+
def test_rw_large_with_edge_indices(device):
38+
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
39+
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
40+
start = tensor([0, 1, 2, 3, 4], torch.long, device)
41+
walk_length = 10
42+
43+
node_seq, edge_seq = random_walk(
44+
row, col, start, walk_length,
45+
return_edge_indices=True,
46+
)
47+
assert node_seq[:, 0].tolist() == start.tolist()
48+
49+
for n in range(start.size(0)):
50+
cur = start[n].item()
51+
for i in range(1, walk_length):
52+
assert node_seq[n, i].item() in col[row == cur].tolist()
53+
cur = node_seq[n, i].item()
54+
55+
assert (edge_seq != -1).all()
56+
57+
58+
@pytest.mark.parametrize('device', devices)
59+
def test_rw_small_with_edge_indices(device):
60+
row = tensor([0, 1], torch.long, device)
61+
col = tensor([1, 0], torch.long, device)
62+
start = tensor([0, 1, 2], torch.long, device)
63+
walk_length = 4
64+
65+
node_seq, edge_seq = random_walk(
66+
row, col, start, walk_length,
67+
num_nodes=3,
68+
return_edge_indices=True,
69+
)
70+
assert node_seq.tolist() == [
71+
[0, 1, 0, 1, 0],
72+
[1, 0, 1, 0, 1],
73+
[2, 2, 2, 2, 2],
74+
]
75+
assert edge_seq.tolist() == [
76+
[0, 1, 0, 1],
77+
[1, 0, 1, 0],
78+
[-1, -1, -1, -1],
79+
]

torch_cluster/rw.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple, Union
22

33
import torch
44
from torch import Tensor
55

66

77
@torch.jit.script
8-
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
9-
p: float = 1, q: float = 1, coalesced: bool = True,
10-
num_nodes: Optional[int] = None) -> Tensor:
8+
def random_walk(
9+
row: Tensor,
10+
col: Tensor,
11+
start: Tensor,
12+
walk_length: int,
13+
p: float = 1,
14+
q: float = 1,
15+
coalesced: bool = True,
16+
num_nodes: Optional[int] = None,
17+
return_edge_indices: bool = False,
18+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1119
"""Samples random walks of length :obj:`walk_length` from all node indices
1220
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
1321
`"node2vec: Scalable Feature Learning for Networks"
@@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
2836
the graph given by :obj:`(row, col)` according to :obj:`row`.
2937
(default: :obj:`True`)
3038
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
39+
return_edge_indices (bool, optional): Whether to additionally return
40+
the indices of edges traversed during the random walk.
41+
(default: :obj:`False`)
3142
3243
:rtype: :class:`LongTensor`
3344
"""
@@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
4354
rowptr = row.new_zeros(num_nodes + 1)
4455
torch.cumsum(deg, 0, out=rowptr[1:])
4556

46-
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
47-
p, q)[0]
57+
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
58+
rowptr, col, start, walk_length, p, q,
59+
)
60+
61+
if return_edge_indices:
62+
return node_seq, edge_seq
63+
64+
return node_seq

0 commit comments

Comments
 (0)