
Commit 94d8593

Expanded test cases

1 parent cbbc903

2 files changed (+57, -7 lines)


modules/transforms/liftings/hypergraph2simplicial/heat_lifting.py

Lines changed: 7 additions & 7 deletions
@@ -69,7 +69,7 @@ def weighted_simplex(simplex: tuple) -> dict:
         simplex = tuple of vertex labels
 
     Returns:
-        dictionary mapping simplices to strictly positive topological.
+        dictionary mapping simplices to strictly positive topological weights.
 
     """
     weights = defaultdict(float)
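
To make the updated contract concrete, a minimal illustration (the simplex tuple here is arbitrary; nothing beyond the documented signature and return type is assumed):

    ## Illustrative sketch only: exercises the contract stated in the docstring above
    w = weighted_simplex((0, 1, 2))                   # dict keyed by simplices (tuples of vertex labels)
    assert all(weight > 0 for weight in w.values())   # every returned weight is strictly positive
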
@@ -98,7 +98,7 @@ def unit_simplex(sigma: tuple, c: float = 1.0, closure: bool = False) -> dict:
 ## From: https://stackoverflow.com/questions/42138681/faster-numpy-solution-instead-of-itertools-combinations
 @cache
 def _combs(n: int, k: int) -> np.ndarray:
-    """Faster numpy-version of itertools.combinations over the standard indest set {0, 1, ..., n}"""
+    """Faster numpy-version of itertools.combinations over the standard index set {0, 1, ..., n}"""
     if n < k:
         return np.empty(shape=(0,), dtype=int)
     a = np.ones((k, n - k + 1), dtype=int)
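
For orientation, a reference sketch of what _combs is assumed to enumerate, written with itertools instead of the vectorized trick from the linked answer (the name combs_reference and the exclusive upper bound 0..n-1 are assumptions; the array layout of _combs itself may differ):

    from itertools import combinations

    import numpy as np


    def combs_reference(n: int, k: int) -> np.ndarray:
        """All k-combinations of the index set {0, ..., n-1}, one combination per row."""
        if n < k:
            return np.empty(shape=(0,), dtype=int)
        return np.array(list(combinations(range(n), k)), dtype=int)
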
@@ -114,12 +114,12 @@ def _combs(n: int, k: int) -> np.ndarray:
 
 
 def downward_closure(H: list, d: int = 1, coeffs: bool = False):
-    """Constructs a simplicial complex from a hypergraph by taking its downward closure, optionally counting higher order interactions.
+    """Constructs the d-simplices of the downward closure of a hypergraph, optionally counting higher order interactions.
 
     This function implicitly converts a hypergraph into a simplicial complex by taking the downward closure of each hyperedge
     and collecting the corresponding d-simplices. Note that only the d-simplices are returned (maximal p-simplices for p < d
-    won't be included!). If coeffs = True, a n x D sparse matrix is returned whose non-zero values at index (i,j) count the number of
-    times the corresponding i-th d-simplex appeared in a j-dimensional hyperedge.
+    won't be included!). If coeffs = True, a n x D sparse matrix is returned whose (i,j) non-zero values count the number of
+    times the i-th d-simplex appeared in a j-dimensional hyperedge.
 
     The output of this function is meant to be used in conjunction with top_weights to compute topological weights.
 
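
A short usage sketch of the documented pipeline (the toy hyperedge list and d=1 are illustrative; the call pattern mirrors the tests further down):

    ## Illustrative only: the documented downward_closure -> top_weights pipeline
    H = [(0, 1, 2), (1, 2, 3), (2, 3)]                     # toy hypergraph as a list of hyperedges
    edges, coeffs = downward_closure(H, d=1, coeffs=True)  # all 1-simplices in the closure + count matrix
    edge_weights = top_weights(edges, coeffs)              # dict-like: 1-simplex -> strictly positive weight
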
@@ -145,6 +145,7 @@ def downward_closure(H: list, d: int = 1, coeffs: bool = False):
         return S
 
     ## Extract the lengths of the hyperedges and how many d-simplices we may need
+    ## NOTE: The use of hirola here speeds up the computation tremendously
     from hirola import HashTable
 
     H_sizes = np.array([len(he) for he in H])
@@ -203,7 +204,7 @@ def top_weights(simplices: np.ndarray, coeffs: sparray, normalize: bool = False)
 
 
 def vertex_counts(H: list) -> np.ndarray:
-    """Returns the number of times a"""
+    """Counts the vertex cardinalities of a set of hyperedges."""
     N = np.max([np.max(he) for he in normalize_hg(H)]) + 1
     v_counts = np.zeros(N)
     for he in normalize_hg(H):
@@ -236,7 +237,6 @@ def lift_topology(self, data: torch_geometric.data) -> dict:
         dict
             The lifted topology.
         """
-        print("Lifting to weighted simplicial complex")
 
         ## Convert incidence to simple list of hyperedges
         R, C = data.incidence_hyperedges.coalesce().indices().detach().numpy()
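
The comment above hints at how the lift consumes its input; a sketch of that conversion, assuming incidence_hyperedges is a (vertices x hyperedges) matrix as the new test constructs it:

    ## Sketch only: group the row (vertex) indices R by their column (hyperedge) index C
    import numpy as np

    hyperedges = [tuple(R[C == j]) for j in np.unique(C)]
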

test/transforms/liftings/hypergraph2simplicial/test_heat_lifting.py

Lines changed: 50 additions & 0 deletions
@@ -1,8 +1,12 @@
 from collections import Counter
 
 import numpy as np
+import toponetx as tnx
+import torch
+from torch_geometric.data import Data
 
 from modules.transforms.liftings.hypergraph2simplicial.heat_lifting import (
+    HypergraphHeatLifting,
     downward_closure,
     top_weights,
     unit_simplex,
@@ -110,3 +114,49 @@ def test_downward_closure(self):
         for d in range(3):
             d_map = top_weights(*downward_closure(H, d=d, coeffs=True))
             assert np.all([np.isclose(sc_lift[s], w) for s, w in d_map.items()])
+
+    def test_lift_api(self):
+        H = [
+            (0,),
+            (0, 1),
+            (1, 3),
+            (1, 2, 3),
+            (0, 1, 2, 3),
+            (0, 1, 4),
+            (0, 1, 3),
+            (2, 5),
+            (0, 2, 5),
+            (0, 2, 4, 5),
+        ]
+        ## Testing the actual lifting API
+        lifting = HypergraphHeatLifting(complex_dim=2)
+        hg = tnx.ColoredHyperGraph()
+        hg.add_cells_from(H)
+        B = hg.incidence_matrix(0, 1).tocsr()
+        B = torch.sparse_coo_tensor(np.array(B.nonzero()), B.data, B.shape)
+
+        ## Note the only requirement for the lift is the hyperedges
+        lifted_dataset = lifting.lift_topology(Data(incidence_hyperedges=B))
+
+        assert isinstance(lifted_dataset, Data)
+        assert (
+            hasattr(lifted_dataset, "incidence_0")
+            and lifted_dataset.incidence_0.shape[1] == 6
+        )
+        assert (
+            hasattr(lifted_dataset, "incidence_1")
+            and lifted_dataset.incidence_1.shape[1] == 12
+        )
+        assert (
+            hasattr(lifted_dataset, "incidence_2")
+            and lifted_dataset.incidence_2.shape[1] == 9
+        )
+        assert (
+            hasattr(lifted_dataset, "weights_0") and len(lifted_dataset.weights_0) == 6
+        )
+        assert (
+            hasattr(lifted_dataset, "weights_1") and len(lifted_dataset.weights_1) == 12
+        )
+        assert (
+            hasattr(lifted_dataset, "weights_2") and len(lifted_dataset.weights_2) == 9
+        )
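
For reference, the asserted sizes match the downward closure of H above: its hyperedges cover 6 vertices, and their faces contain 12 distinct edges and 9 distinct triangles.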
