1- import torch
2- import numpy as np
31import networkx as nx
4- import torch_geometric
2+ import numpy as np
53import pyflagsercount as pfc
4+ import torch
5+ import torch_geometric
66from toponetx .classes import CombinatorialComplex
7+
78from modules .transforms .liftings .graph2combinatorial .base import (
89 Graph2CombinatorialLifting ,
910)
@@ -39,23 +40,23 @@ def _get_complex_connectivity(
3940 if adj [0 ] < adj [1 ]:
4041 connectivity [f"{ connectivity_info } _{ adj [0 ]} _{ adj [1 ]} " ] = (
4142 torch .from_numpy (
42- (
43+
4344 combinatorial_complex .adjacency_matrix (
4445 adj [0 ], adj [1 ]
4546 ).todense ()
46- )
47+
4748 )
4849 .to_sparse ()
4950 .float ()
5051 )
5152 else :
5253 connectivity [f"{ connectivity_info } _{ adj [0 ]} _{ adj [1 ]} " ] = (
5354 torch .from_numpy (
54- (
55+
5556 combinatorial_complex .coadjacency_matrix (
5657 adj [0 ], adj [1 ]
5758 ).todense ()
58- )
59+
5960 )
6061 .to_sparse ()
6162 .float ()
@@ -64,7 +65,7 @@ def _get_complex_connectivity(
6465 connectivity_info = "incidence"
6566 connectivity [f"{ connectivity_info } _{ inc [0 ]} _{ inc [1 ]} " ] = (
6667 torch .from_numpy (
67- ( combinatorial_complex .incidence_matrix (inc [0 ], inc [1 ]).todense () )
68+ combinatorial_complex .incidence_matrix (inc [0 ], inc [1 ]).todense ()
6869 )
6970 .to_sparse ()
7071 .float ()
@@ -107,9 +108,8 @@ def _create_flag_complex_from_dataset(self, dataset, complex_dim=2):
107108 list (zip (dataset .edge_index [0 ].tolist (), dataset .edge_index [1 ].tolist (), strict = False ))
108109 )
109110
110- dfc = DirectedQConnectivity (dataset_digraph , complex_dim )
111+ return DirectedQConnectivity (dataset_digraph , complex_dim )
111112
112- return dfc
113113
114114 def lift_topology (self , data : torch_geometric .data .Data ) -> dict :
115115
@@ -239,8 +239,7 @@ def _d_i_batched(self, i: int, simplices: torch.tensor) -> torch.tensor:
239239 mask = indices != min (i , n_vertices - 1 )
240240 # Use advanced indexing to select vertices while preserving the
241241 # batch structure
242- d_i = simplices [:, mask ]
243- return d_i
242+ return simplices [:, mask ]
244243
245244 def _gen_q_faces_batched (self , simplices : torch .tensor , c : int ) -> torch .tensor :
246245 r"""Compute the :math:`q`-dimensional faces of the simplices in the
@@ -359,14 +358,13 @@ def _multiple_contained_chunked(
359358 else :
360359 indices = torch .empty ([2 , 0 ], dtype = torch .long )
361360
362- A = torch .sparse_coo_tensor (
361+ return torch .sparse_coo_tensor (
363362 indices ,
364363 torch .ones (indices .size (1 ), dtype = torch .bool ),
365364 size = (Ns , Nt ),
366365 device = "cpu" ,
367366 )
368367
369- return A
370368
371369 def _alpha_q_contained_sparse (
372370 self , sigmas : torch .Tensor , taus : torch .Tensor , q : int , chunk_size : int = 1024
@@ -411,14 +409,13 @@ def _alpha_q_contained_sparse(
411409
412410 values = torch .ones (intersect ._indices ().size (1 ))
413411
414- A = torch .sparse_coo_tensor (
412+ return torch .sparse_coo_tensor (
415413 intersect ._indices (),
416414 values ,
417415 dtype = torch .bool ,
418416 size = (sigmas .size (0 ), taus .size (0 )),
419417 )
420418
421- return A
422419
423420 def qij_adj (
424421 self ,
@@ -473,16 +470,14 @@ def qij_adj(
473470 di_sigmas , dj_taus , q , chunk_size
474471 )
475472
476- indices = (
473+ return (
477474 torch .cat (
478475 (contained ._indices ().t (), alpha_q_contained ._indices ().t ()), dim = 0
479476 )
480477 .unique (dim = 0 )
481478 .t ()
482479 )
483480
484- return indices
485-
486481 def find_paths (self , indices : torch .tensor , threshold : int ):
487482 r"""Find the paths in the adjacency matrix associated with the
488483 :math:`(q,
@@ -494,15 +489,13 @@ def find_paths(self, indices: torch.tensor, threshold: int):
494489 threshold : int
495490 The length threshold to select paths
496491
497-
498492 Returns
499493 -------
500494 paths : List[List]
501495 List of selected paths.
502496 """
503497
504498 def dfs (node , adj_list , all_paths , path ):
505-
506499 if node not in adj_list : # end of recursion
507500 if len (path ) > threshold :
508501 all_paths = add_path (path .copy (), all_paths )
@@ -516,9 +509,9 @@ def dfs(node, adj_list, all_paths, path):
516509 dfs (new_node , adj_list , all_paths , path )
517510 path .pop ()
518511
519- if only_loops : # then we have another longest path
520- if len ( path ) > threshold :
521- all_paths = add_path (path .copy (), all_paths )
512+ if only_loops and len (
513+ path ) > threshold : # then we have another longest path
514+ all_paths = add_path (path .copy (), all_paths )
522515
523516 return
524517
@@ -537,12 +530,8 @@ def is_subpath(p1, p2):
537530 return False
538531 if len (p1 ) == len (p2 ):
539532 return p1 == p2
540- else :
541- diff = len (p2 ) - len (p1 )
542- for i in range (diff + 1 ):
543- if p2 [i :i + len (p1 )] == p1 :
544- return True
545- return False
533+ diff = len (p2 ) - len (p1 )
534+ return any (p2 [i :i + len (p1 )] == p1 for i in range (diff + 1 ))
546535
547536 def add_path (new_path , all_paths ):
548537 for path in all_paths :
@@ -565,3 +554,4 @@ def add_path(new_path, all_paths):
565554 dfs (src , adj_list , all_paths , path )
566555
567556 return all_paths
557+
0 commit comments