Skip to content

Commit 37897eb

Browse files
Implementation of the graph induced lifting (graph to simplicial complex)
1 parent f42a978 commit 37897eb

File tree

7 files changed

+585
-485
lines changed

7 files changed

+585
-485
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
transform_type: 'lifting'
2-
transform_name: "SimplicialVietorisRipsLifting"
2+
transform_name: "SimplicialGraphInducedLifting"
33
complex_dim: 3
44
preserve_edge_attr: False
55
signed: True
6-
distance_threshold: 2.0
76
feature_lifting: ProjectionSum

modules/transforms/data_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
1616
SimplicialCliqueLifting,
1717
)
18-
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
19-
SimplicialVietorisRipsLifting,
18+
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
19+
SimplicialGraphInducedLifting,
2020
)
2121

2222
TRANSFORMS = {
2323
# Graph -> Hypergraph
2424
"HypergraphKNNLifting": HypergraphKNNLifting,
2525
# Graph -> Simplicial Complex
2626
"SimplicialCliqueLifting": SimplicialCliqueLifting,
27-
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
27+
"SimplicialGraphInducedLifting": SimplicialGraphInducedLifting,
2828
# Graph -> Cell Complex
2929
"CellCycleLifting": CellCycleLifting,
3030
# Feature Liftings

modules/transforms/liftings/graph2simplicial/vietoris_rips_lifting.py renamed to modules/transforms/liftings/graph2simplicial/graph_induced_lifting.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,20 @@
77
from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting
88

99

10-
class SimplicialVietorisRipsLifting(Graph2SimplicialLifting):
11-
r"""Lifts graphs to simplicial complex domain using the Vietoris-Rips complex based on pairwise distances.
10+
class SimplicialGraphInducedLifting(Graph2SimplicialLifting):
11+
r"""Lifts graphs to simplicial complex domain by identifying connected subgraphs as simplices.
1212
1313
Parameters
1414
----------
15-
distance_threshold : float
16-
The maximum distance between vertices to form a simplex.
1715
**kwargs : optional
1816
Additional arguments for the class.
1917
"""
2018

21-
def __init__(self, distance_threshold=1.0, **kwargs):
19+
def __init__(self, **kwargs):
2220
super().__init__(**kwargs)
23-
self.distance_threshold = distance_threshold
2421

2522
def lift_topology(self, data: torch_geometric.data.Data) -> dict:
26-
r"""Lifts the topology of a graph to a simplicial complex using the Vietoris-Rips complex.
23+
r"""Lifts the topology of a graph to a simplicial complex by identifying connected subgraphs as simplices.
2724
2825
Parameters
2926
----------
@@ -40,15 +37,10 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
4037
all_nodes = list(graph.nodes)
4138
simplices = [set() for _ in range(2, self.complex_dim + 1)]
4239

43-
# Calculate pairwise shortest path distances
44-
path_lengths = dict(nx.all_pairs_shortest_path_length(graph))
45-
4640
for k in range(2, self.complex_dim + 1):
4741
for combination in combinations(all_nodes, k + 1):
48-
if all(
49-
path_lengths[u][v] <= self.distance_threshold
50-
for u, v in combinations(combination, 2)
51-
):
42+
subgraph = graph.subgraph(combination)
43+
if nx.is_connected(subgraph):
5244
simplices[k - 2].add(tuple(sorted(combination)))
5345

5446
for set_k_simplices in simplices:
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Test the message passing module."""
2+
3+
import torch
4+
5+
from modules.data.utils.utils import load_manual_graph
6+
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
7+
SimplicialGraphInducedLifting,
8+
)
9+
10+
11+
class TestSimplicialCliqueLifting:
12+
"""Test the SimplicialCliqueLifting class."""
13+
14+
def setup_method(self):
15+
# Load the graph
16+
self.data = load_manual_graph()
17+
18+
# Initialise the SimplicialCliqueLifting class
19+
self.lifting_signed = SimplicialGraphInducedLifting(complex_dim=3, signed=True)
20+
self.lifting_unsigned = SimplicialGraphInducedLifting(
21+
complex_dim=3, signed=False
22+
)
23+
24+
def test_lift_topology(self):
25+
"""Test the lift_topology method."""
26+
27+
# Test the lift_topology method
28+
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
29+
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())
30+
31+
expected_incidence_1_singular_values_unsigned = torch.tensor(
32+
[3.7417, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495]
33+
)
34+
35+
expected_incidence_1_singular_values_signed = torch.tensor(
36+
[
37+
2.8284e00,
38+
2.8284e00,
39+
2.8284e00,
40+
2.8284e00,
41+
2.8284e00,
42+
2.8284e00,
43+
2.8284e00,
44+
6.8993e-08,
45+
]
46+
)
47+
48+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_1.to_dense())
49+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_1.to_dense())
50+
51+
assert torch.allclose(
52+
expected_incidence_1_singular_values_unsigned, S_unsigned, atol=1.0e-04
53+
), "Something is wrong with unsigned incidence_1 (nodes to edges)."
54+
assert torch.allclose(
55+
expected_incidence_1_singular_values_signed, S_signed, atol=1.0e-04
56+
), "Something is wrong with signed incidence_1 (nodes to edges)."
57+
58+
expected_incidence_2_singular_values_unsigned = torch.tensor(
59+
[
60+
4.1190,
61+
3.1623,
62+
3.1623,
63+
3.1623,
64+
3.0961,
65+
3.0000,
66+
3.0000,
67+
2.7564,
68+
2.0000,
69+
2.0000,
70+
2.0000,
71+
2.0000,
72+
2.0000,
73+
2.0000,
74+
2.0000,
75+
2.0000,
76+
2.0000,
77+
2.0000,
78+
2.0000,
79+
2.0000,
80+
2.0000,
81+
2.0000,
82+
2.0000,
83+
1.7321,
84+
1.6350,
85+
1.4142,
86+
1.4142,
87+
1.0849,
88+
]
89+
)
90+
91+
expected_incidence_2_singular_values_signed = torch.tensor(
92+
[
93+
2.8284e00,
94+
2.8284e00,
95+
2.8284e00,
96+
2.8284e00,
97+
2.8284e00,
98+
2.8284e00,
99+
2.8284e00,
100+
2.8284e00,
101+
2.8284e00,
102+
2.8284e00,
103+
2.8284e00,
104+
2.8284e00,
105+
2.8284e00,
106+
2.8284e00,
107+
2.8284e00,
108+
2.8284e00,
109+
2.6458e00,
110+
2.6458e00,
111+
2.2361e00,
112+
1.7321e00,
113+
1.7321e00,
114+
9.3758e-07,
115+
4.7145e-07,
116+
4.3417e-07,
117+
4.0241e-07,
118+
3.1333e-07,
119+
2.2512e-07,
120+
1.9160e-07,
121+
]
122+
)
123+
124+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_2.to_dense())
125+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_2.to_dense())
126+
assert torch.allclose(
127+
expected_incidence_2_singular_values_unsigned, S_unsigned, atol=1.0e-04
128+
), "Something is wrong with unsigned incidence_2 (edges to triangles)."
129+
assert torch.allclose(
130+
expected_incidence_2_singular_values_signed, S_signed, atol=1.0e-04
131+
), "Something is wrong with signed incidence_2 (edges to triangles)."
132+
133+
expected_incidence_3_singular_values_unsigned = torch.tensor(
134+
[
135+
3.8466,
136+
3.1379,
137+
3.0614,
138+
2.8749,
139+
2.8392,
140+
2.8125,
141+
2.5726,
142+
2.3709,
143+
2.2858,
144+
2.2369,
145+
2.1823,
146+
2.0724,
147+
2.0000,
148+
2.0000,
149+
2.0000,
150+
1.8937,
151+
1.7814,
152+
1.7321,
153+
1.7256,
154+
1.5469,
155+
1.5340,
156+
1.4834,
157+
1.4519,
158+
1.4359,
159+
1.4142,
160+
1.0525,
161+
1.0000,
162+
1.0000,
163+
1.0000,
164+
1.0000,
165+
0.9837,
166+
0.9462,
167+
0.8853,
168+
0.7850,
169+
]
170+
)
171+
172+
expected_incidence_3_singular_values_signed = torch.tensor(
173+
[
174+
2.8284e00,
175+
2.8284e00,
176+
2.8284e00,
177+
2.8284e00,
178+
2.8284e00,
179+
2.8284e00,
180+
2.8284e00,
181+
2.8284e00,
182+
2.8284e00,
183+
2.6933e00,
184+
2.6458e00,
185+
2.6458e00,
186+
2.6280e00,
187+
2.4495e00,
188+
2.3040e00,
189+
1.9475e00,
190+
1.7321e00,
191+
1.7321e00,
192+
1.7321e00,
193+
1.4823e00,
194+
1.0000e00,
195+
1.0000e00,
196+
1.0000e00,
197+
1.0000e00,
198+
1.0000e00,
199+
1.0000e00,
200+
1.0000e00,
201+
1.0000e00,
202+
1.0000e00,
203+
7.3584e-01,
204+
2.7959e-07,
205+
2.1776e-07,
206+
1.4498e-07,
207+
5.5373e-08,
208+
]
209+
)
210+
211+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_3.to_dense())
212+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_3.to_dense())
213+
214+
assert torch.allclose(
215+
expected_incidence_3_singular_values_unsigned, S_unsigned, atol=1.0e-04
216+
), "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)."
217+
assert torch.allclose(
218+
expected_incidence_3_singular_values_signed, S_signed, atol=1.0e-04
219+
), "Something is wrong with signed incidence_3 (triangles to tetrahedrons)."

test/transforms/liftings/graph2simplicial/test_vietoris_rips_lifting.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

0 commit comments

Comments
 (0)