Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/tdamapper/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def deprecated(msg: str) -> Callable[..., Any]:
"""

def deprecated_func(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return func(*args, **kwargs)

Expand Down Expand Up @@ -179,10 +179,12 @@ def __repr__(self) -> str:
obj_noargs = type(self)()
args_repr = []
for k, v in self.__dict__.items():
if not self._is_param_public(k):
continue
v_default = getattr(obj_noargs, k)
v_default_repr = repr(v_default)
v_repr = repr(v)
if self._is_param_public(k) and not v_repr == v_default_repr:
if not v_repr == v_default_repr:
args_repr.append(f"{k}={v_repr}")
return f"{self.__class__.__name__}({', '.join(args_repr)})"

Expand Down Expand Up @@ -211,7 +213,7 @@ def profile(n_lines: int = 10) -> Callable[..., Any]:
"""

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
profiler = cProfile.Profile()
profiler.enable()
result = func(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/tdamapper/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _umap(X: NDArray[np.float_]) -> NDArray[np.float_]:


def run_mapper(
df: pd.DataFrame, **kwargs: dict[str, Any]
df: pd.DataFrame, **kwargs: Any
) -> Optional[tuple[nx.Graph, pd.DataFrame]]:
"""
Runs the Mapper algorithm on the provided DataFrame and returns the Mapper
Expand Down Expand Up @@ -301,7 +301,7 @@ def create_mapper_figure(
df_y: pd.DataFrame,
df_target: pd.DataFrame,
mapper_graph: nx.Graph,
**kwargs: dict[str, Any],
**kwargs: Any,
) -> go.Figure:
"""
Renders the Mapper graph as a Plotly figure.
Expand Down
12 changes: 11 additions & 1 deletion src/tdamapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,24 @@ class TrivialCover(ParamsMixin, Generic[T]):
dataset.
"""

def fit(self, X: ArrayRead[T]) -> TrivialCover[T]:
"""
Fit the cover algorithm to the data.

:param X: A dataset of n points. Ignored.
:return: self
"""
return self

def apply(self, X: ArrayRead[T]) -> Iterator[list[int]]:
"""
Covers the dataset with a single open set.

:param X: A dataset of n points.
:return: A generator of lists of ids.
"""
yield list(range(0, len(X)))
if len(X) > 0:
yield list(range(0, len(X)))


class FailSafeClustering(ParamsMixin, Generic[T]):
Expand Down
2 changes: 2 additions & 0 deletions src/tdamapper/cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def fit(self, X: ArrayRead[NDArray[np.float_]]) -> BaseCubicalCover:
:param X: A dataset of n points.
:return: The object itself.
"""
if len(X) == 0:
return self
X_ = np.asarray(X).reshape(len(X), -1).astype(float)
if self.overlap_frac is None:
dim = 1 if X_.ndim == 1 else X_.shape[1]
Expand Down
12 changes: 6 additions & 6 deletions src/tdamapper/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,19 @@ def fit(
"""
y_ = X if y is None else y
X, y_ = self._validate_X_y(X, y_)
self._cover = TrivialCover() if self.cover is None else self.cover
self._clustering = (
TrivialClustering() if self.clustering is None else self.clustering
)
self._cover = TrivialCover()
if self.cover is not None:
self._cover = clone(self.cover)
self._clustering = TrivialClustering()
if self.clustering is not None:
self._clustering = clone(self.clustering)
self._verbose = self.verbose
self._failsafe = self.failsafe
if self._failsafe:
self._clustering = FailSafeClustering(
clustering=self._clustering,
verbose=self._verbose,
)
self._cover = clone(self._cover)
self._clustering = clone(self._clustering)
self._n_jobs = self.n_jobs
self.graph_ = mapper_graph(
X,
Expand Down
59 changes: 57 additions & 2 deletions src/tdamapper/utils/heap.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
This module implements a max-heap data structure that allows for efficient
retrieval and removal of the maximum element. The heap supports adding
elements, retrieving the maximum element, and removing the maximum element
while maintaining the heap property.
"""

from __future__ import annotations

from typing import Generic, Iterator, Optional, Protocol, TypeVar
Expand All @@ -16,6 +23,9 @@ def _parent(i: int) -> int:


class Comparable(Protocol):
"""
Protocol for comparison methods required for a key in the heap.
"""

def __lt__(self: K, other: K) -> bool: ...

Expand All @@ -32,6 +42,14 @@ def __ge__(self: K, other: K) -> bool: ...


class _HeapNode(Generic[K, V]):
"""
A node in the heap that holds a key-value pair.

The key is used for comparison, and the value is stored alongside it.

:param key: The key used for comparison.
:param value: The value associated with the key.
"""

_key: K
_value: V
Expand All @@ -41,6 +59,11 @@ def __init__(self, key: K, value: V) -> None:
self._value = value

def get(self) -> tuple[K, V]:
"""
Returns the key-value pair stored in the node.

:return: A tuple containing the key and value.
"""
return self._key, self._value

def __lt__(self, other: _HeapNode[K, V]) -> bool:
Expand All @@ -57,6 +80,12 @@ def __ge__(self, other: _HeapNode[K, V]) -> bool:


class MaxHeap(Generic[K, V]):
"""
A max-heap implementation that allows for efficient retrieval of the
maximum element. This heap supports adding elements, retrieving the maximum
element, and removing the maximum element while maintaining the heap
property.
"""

_heap: list[_HeapNode[K, V]]
_iter: Iterator[_HeapNode[K, V]]
Expand All @@ -75,12 +104,32 @@ def __next__(self) -> tuple[K, V]:
def __len__(self) -> int:
return len(self._heap)

def is_empty(self) -> bool:
"""
Check if the heap is empty.

:return: True if the heap is empty, False otherwise.
"""
return len(self._heap) == 0

def top(self) -> Optional[tuple[K, V]]:
"""
Returns the maximum element in the heap without removing it.

:return: A tuple containing the key and value of the maximum element,
or None if the heap is empty.
"""
if not self._heap:
return None
return self._heap[0].get()

def pop(self) -> Optional[tuple[K, V]]:
"""
Removes and returns the maximum element from the heap.

:return: A tuple containing the key and value of the maximum element,
or None if the heap is empty.
"""
if not self._heap:
return None
max_val = self._heap[0]
Expand All @@ -89,8 +138,14 @@ def pop(self) -> Optional[tuple[K, V]]:
self._bubble_down()
return max_val.get()

def add(self, key: K, val: V) -> None:
self._heap.append(_HeapNode(key, val))
def add(self, key: K, value: V) -> None:
"""
Adds a new key-value pair to the heap.

:param key: The key used for comparison.
:param value: The value associated with the key.
"""
self._heap.append(_HeapNode(key, value))
self._bubble_up()

def _get_local_max(self, i: int) -> int:
Expand Down
17 changes: 6 additions & 11 deletions src/tdamapper/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_supported_metrics() -> list[MetricLiteral]:
return list(get_args(MetricLiteral))


def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
def euclidean() -> Metric[Any]:
"""
Return the Euclidean distance function for vectors.

Expand All @@ -70,7 +70,7 @@ def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
return _metrics.euclidean


def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
def manhattan() -> Metric[Any]:
"""
Return the Manhattan distance function for vectors.

Expand All @@ -82,7 +82,7 @@ def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
return _metrics.manhattan


def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
def chebyshev() -> Metric[Any]:
"""
Return the Chebyshev distance function for vectors.

Expand All @@ -94,7 +94,7 @@ def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
return _metrics.chebyshev


def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
def minkowski(p: Union[int, float]) -> Metric[Any]:
"""
Return the Minkowski distance function for order p on vectors.

Expand All @@ -106,9 +106,6 @@ def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
:param p: The order of the Minkowski distance.
:return: The Minkowski distance function.
"""
p = kwargs.get("p", 2)
if not isinstance(p, (int, float)):
raise TypeError("p must be an integer or a float")
if p == 1:
return manhattan()
if p == 2:
Expand All @@ -122,7 +119,7 @@ def dist(x: Any, y: Any) -> float:
return dist


def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
def cosine() -> Metric[Any]:
"""
Return the cosine distance function for vectors.

Expand All @@ -145,9 +142,7 @@ def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
return _metrics.cosine


def get_metric(
metric: Union[MetricLiteral, Metric[Any]], **kwargs: dict[str, Any]
) -> Metric[Any]:
def get_metric(metric: Union[MetricLiteral, Metric[Any]], **kwargs: Any) -> Metric[Any]:
"""
Return a distance function based on the specified string or callable.

Expand Down
33 changes: 30 additions & 3 deletions src/tdamapper/utils/unionfind.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
"""
This module implements a Union-Find data structure that supports union and
find operations.
"""

from typing import Any, Iterable


class UnionFind:
"""
A Union-Find data structure that supports union and find operations.

This implementation uses path compression for efficient find operations
and union by size to keep the tree flat. It allows for efficient
determination of connected components in a set of elements.

:param X: An iterable of elements to initialize the Union-Find structure.
"""

_parent: dict[Any, Any]
_size: dict[Any, int]

def __init__(self, X: Iterable[Any]):
self._parent = {x: x for x in X}
self._size = {x: 1 for x in X}
def __init__(self, items: Iterable[Any]):
self._parent = {x: x for x in items}
self._size = {x: 1 for x in items}

def find(self, x: Any) -> Any:
"""
Finds the class of an element, applying path compression.

:param x: The element to find the class of.
:return: The representative of the class containing x.
"""
root = x
while root != self._parent[root]:
root = self._parent[root]
Expand All @@ -22,6 +42,13 @@ def find(self, x: Any) -> Any:
return root

def union(self, x: Any, y: Any) -> Any:
"""
Unites the classes of two elements.

:param x: The first element.
:param y: The second element.
:return: The representative of the class after the union operation.
"""
x, y = self.find(x), self.find(y)
if x != y:
x_size, y_size = self._size[x], self._size[y]
Expand Down
3 changes: 2 additions & 1 deletion src/tdamapper/utils/vptree_flat/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def build(self) -> VPArray[T]:

:return: A tuple containing the constructed vp-tree and the VPArray.
"""
self._build_iter()
if self._array.size() > 0:
self._build_iter()
return self._array

def _build_iter(self) -> None:
Expand Down
9 changes: 7 additions & 2 deletions src/tdamapper/utils/vptree_hier/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,15 @@ def build(self) -> tuple[Tree[T], VPArray[T]]:

:return: A tuple containing the constructed vp-tree and the VPArray.
"""
tree = self._build_rec(0, self._array.size())
if self._array.size() > 0:
tree = self._build_rec(0, self._array.size())
else:
tree = Leaf(0, 0)
return tree, self._array

def _build_rec(self, start: int, end: int) -> Tree[T]:
if end - start <= self._leaf_capacity:
return Leaf(start, end)
mid = _mid(start, end)
self._update(start, end)
v_point = self._array.get_point(start)
Expand All @@ -106,7 +111,7 @@ def _build_rec(self, start: int, end: int) -> Tree[T]:
self._array.set_distance(start, v_radius)
left: Tree[T]
right: Tree[T]
if (end - start <= 2 * self._leaf_capacity) or (v_radius <= self._leaf_radius):
if v_radius <= self._leaf_radius:
left = Leaf(start + 1, mid)
right = Leaf(mid, end)
else:
Expand Down
Loading