Skip to content

Commit 9ede4b1

Browse files
committed
solving ruff stuff
1 parent 5bf2cd9 commit 9ede4b1

File tree

6 files changed

+89
-257
lines changed

6 files changed

+89
-257
lines changed

modules/models/combinatorial/spcc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
2-
from torch.nn.parameter import Parameter
32
from topomodelx.base.aggregation import Aggregation
4-
from topomodelx.nn.combinatorial.hmc_layer import HBS, HBNS
3+
from topomodelx.nn.combinatorial.hmc_layer import HBNS, HBS
54

65

76
class SPCCLayer(torch.nn.Module):

modules/transforms/feature_liftings/feature_liftings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def lift_features(
2929
torch_geometric.data.Data | dict
3030
The lifted data."""
3131
keys = sorted(
32-
[key.split("_")[1] for key in data.keys() if "incidence" in key]
33-
) # noqa : SIM118
32+
[key.split("_")[1] for key in data if "incidence" in key]
33+
) # : SIM118
3434
for elem in keys:
3535
if f"x_{elem}" not in data:
3636
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1

modules/transforms/liftings/graph2combinatorial/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from modules.transforms.liftings.lifting import GraphLifting
22

33

4-
5-
64
class Graph2CombinatorialLifting(GraphLifting):
75
r"""Abstract class for lifting graphs to combinatorial complexes.
86

modules/transforms/liftings/graph2combinatorial/sp_lifting.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import torch
2-
import numpy as np
31
import networkx as nx
4-
import torch_geometric
2+
import numpy as np
53
import pyflagsercount as pfc
4+
import torch
5+
import torch_geometric
66
from toponetx.classes import CombinatorialComplex
7+
78
from modules.transforms.liftings.graph2combinatorial.base import (
89
Graph2CombinatorialLifting,
910
)
@@ -39,23 +40,23 @@ def _get_complex_connectivity(
3940
if adj[0] < adj[1]:
4041
connectivity[f"{connectivity_info}_{adj[0]}_{adj[1]}"] = (
4142
torch.from_numpy(
42-
(
43+
4344
combinatorial_complex.adjacency_matrix(
4445
adj[0], adj[1]
4546
).todense()
46-
)
47+
4748
)
4849
.to_sparse()
4950
.float()
5051
)
5152
else:
5253
connectivity[f"{connectivity_info}_{adj[0]}_{adj[1]}"] = (
5354
torch.from_numpy(
54-
(
55+
5556
combinatorial_complex.coadjacency_matrix(
5657
adj[0], adj[1]
5758
).todense()
58-
)
59+
5960
)
6061
.to_sparse()
6162
.float()
@@ -64,7 +65,7 @@ def _get_complex_connectivity(
6465
connectivity_info = "incidence"
6566
connectivity[f"{connectivity_info}_{inc[0]}_{inc[1]}"] = (
6667
torch.from_numpy(
67-
(combinatorial_complex.incidence_matrix(inc[0], inc[1]).todense())
68+
combinatorial_complex.incidence_matrix(inc[0], inc[1]).todense()
6869
)
6970
.to_sparse()
7071
.float()
@@ -107,9 +108,8 @@ def _create_flag_complex_from_dataset(self, dataset, complex_dim=2):
107108
list(zip(dataset.edge_index[0].tolist(), dataset.edge_index[1].tolist(), strict=False))
108109
)
109110

110-
dfc = DirectedQConnectivity(dataset_digraph, complex_dim)
111+
return DirectedQConnectivity(dataset_digraph, complex_dim)
111112

112-
return dfc
113113

114114
def lift_topology(self, data: torch_geometric.data.Data) -> dict:
115115

@@ -239,8 +239,7 @@ def _d_i_batched(self, i: int, simplices: torch.tensor) -> torch.tensor:
239239
mask = indices != min(i, n_vertices - 1)
240240
# Use advanced indexing to select vertices while preserving the
241241
# batch structure
242-
d_i = simplices[:, mask]
243-
return d_i
242+
return simplices[:, mask]
244243

245244
def _gen_q_faces_batched(self, simplices: torch.tensor, c: int) -> torch.tensor:
246245
r"""Compute the :math:`q`-dimensional faces of the simplices in the
@@ -359,14 +358,13 @@ def _multiple_contained_chunked(
359358
else:
360359
indices = torch.empty([2, 0], dtype=torch.long)
361360

362-
A = torch.sparse_coo_tensor(
361+
return torch.sparse_coo_tensor(
363362
indices,
364363
torch.ones(indices.size(1), dtype=torch.bool),
365364
size=(Ns, Nt),
366365
device="cpu",
367366
)
368367

369-
return A
370368

371369
def _alpha_q_contained_sparse(
372370
self, sigmas: torch.Tensor, taus: torch.Tensor, q: int, chunk_size: int = 1024
@@ -411,14 +409,13 @@ def _alpha_q_contained_sparse(
411409

412410
values = torch.ones(intersect._indices().size(1))
413411

414-
A = torch.sparse_coo_tensor(
412+
return torch.sparse_coo_tensor(
415413
intersect._indices(),
416414
values,
417415
dtype=torch.bool,
418416
size=(sigmas.size(0), taus.size(0)),
419417
)
420418

421-
return A
422419

423420
def qij_adj(
424421
self,
@@ -473,16 +470,14 @@ def qij_adj(
473470
di_sigmas, dj_taus, q, chunk_size
474471
)
475472

476-
indices = (
473+
return (
477474
torch.cat(
478475
(contained._indices().t(), alpha_q_contained._indices().t()), dim=0
479476
)
480477
.unique(dim=0)
481478
.t()
482479
)
483480

484-
return indices
485-
486481
def find_paths(self, indices: torch.tensor, threshold: int):
487482
r"""Find the paths in the adjacency matrix associated with the
488483
:math:`(q,
@@ -494,15 +489,13 @@ def find_paths(self, indices: torch.tensor, threshold: int):
494489
threshold : int
495490
The length threshold to select paths
496491
497-
498492
Returns
499493
-------
500494
paths : List[List]
501495
List of selected paths.
502496
"""
503497

504498
def dfs(node, adj_list, all_paths, path):
505-
506499
if node not in adj_list: # end of recursion
507500
if len(path) > threshold:
508501
all_paths = add_path(path.copy(), all_paths)
@@ -516,9 +509,9 @@ def dfs(node, adj_list, all_paths, path):
516509
dfs(new_node, adj_list, all_paths, path)
517510
path.pop()
518511

519-
if only_loops: # then we have another longest path
520-
if len(path) > threshold:
521-
all_paths = add_path(path.copy(), all_paths)
512+
if only_loops and len(
513+
path) > threshold: # then we have another longest path
514+
all_paths = add_path(path.copy(), all_paths)
522515

523516
return
524517

@@ -537,12 +530,8 @@ def is_subpath(p1, p2):
537530
return False
538531
if len(p1) == len(p2):
539532
return p1 == p2
540-
else:
541-
diff = len(p2) - len(p1)
542-
for i in range(diff + 1):
543-
if p2[i:i + len(p1)] == p1:
544-
return True
545-
return False
533+
diff = len(p2) - len(p1)
534+
return any(p2[i:i + len(p1)] == p1 for i in range(diff + 1))
546535

547536
def add_path(new_path, all_paths):
548537
for path in all_paths:
@@ -565,3 +554,4 @@ def add_path(new_path, all_paths):
565554
dfs(src, adj_list, all_paths, path)
566555

567556
return all_paths
557+

test/transforms/liftings/graph2combinatorial/test_sp_lifting.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from modules.transforms.liftings.graph2combinatorial.sp_lifting import (
66
DirectedQConnectivity,
7-
SimplicialPathsLifting,
87
)
98

109

0 commit comments

Comments
 (0)