Skip to content

Commit 5ff0566

Browse files
committed
Fixed edge cases
1 parent b8e0f4e commit 5ff0566

File tree

8 files changed

+107
-4
lines changed

8 files changed

+107
-4
lines changed

src/tdamapper/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,24 @@ class TrivialCover(ParamsMixin, Generic[T]):
279279
dataset.
280280
"""
281281

282+
def fit(self, X: ArrayRead[T]) -> TrivialCover[T]:
283+
"""
284+
Fit the cover algorithm to the data.
285+
286+
:param X: A dataset of n points. Ignored.
287+
:return: self
288+
"""
289+
return self
290+
282291
def apply(self, X: ArrayRead[T]) -> Iterator[list[int]]:
283292
"""
284293
Covers the dataset with a single open set.
285294
286295
:param X: A dataset of n points.
287296
:return: A generator of lists of ids.
288297
"""
289-
yield list(range(0, len(X)))
298+
if len(X) > 0:
299+
yield list(range(0, len(X)))
290300

291301

292302
class FailSafeClustering(ParamsMixin, Generic[T]):

src/tdamapper/cover.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,8 @@ def fit(self, X: ArrayRead[NDArray[np.float_]]) -> BaseCubicalCover:
377377
:param X: A dataset of n points.
378378
:return: The object itself.
379379
"""
380+
if len(X) == 0:
381+
return self
380382
X_ = np.asarray(X).reshape(len(X), -1).astype(float)
381383
if self.overlap_frac is None:
382384
dim = 1 if X_.ndim == 1 else X_.shape[1]

src/tdamapper/utils/vptree_flat/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def build(self) -> VPArray[T]:
9191
9292
:return: A tuple containing the constructed vp-tree and the VPArray.
9393
"""
94-
self._build_iter()
94+
if self._array.size() > 0:
95+
self._build_iter()
9596
return self._array
9697

9798
def _build_iter(self) -> None:

src/tdamapper/utils/vptree_hier/builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ def build(self) -> tuple[Tree[T], VPArray[T]]:
9494
9595
:return: A tuple containing the constructed vp-tree and the VPArray.
9696
"""
97-
tree = self._build_rec(0, self._array.size())
97+
if self._array.size() > 0:
98+
tree = self._build_rec(0, self._array.size())
99+
else:
100+
tree = Leaf(0, 0)
98101
return tree, self._array
99102

100103
def _build_rec(self, start: int, end: int) -> Tree[T]:

tests/test_unit_cover.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import math
66

7+
import numpy as np
78
import pytest
89

910
from tdamapper.core import TrivialCover
@@ -62,6 +63,26 @@ def count_components(charts):
6263
return len(unique_components)
6364

6465

66+
@pytest.mark.parametrize(
67+
"cover",
68+
[
69+
TrivialCover(),
70+
BallCover(radius=0.1, metric="euclidean"),
71+
KNNCover(neighbors=1, metric="euclidean"),
72+
StandardCubicalCover(n_intervals=2, overlap_frac=0.5),
73+
ProximityCubicalCover(n_intervals=2, overlap_frac=0.5),
74+
],
75+
)
76+
def test_cover_empty(cover):
77+
"""
78+
Test that the cover algorithms handle empty datasets correctly.
79+
"""
80+
empty_data = np.array([])
81+
cover.fit(empty_data)
82+
charts = cover.apply(empty_data)
83+
assert len(list(charts)) == 0
84+
85+
6586
@pytest.mark.parametrize(
6687
"dataset, cover, num_charts, num_components",
6788
[
@@ -115,7 +136,20 @@ def count_components(charts):
115136
(GRID, KNNCover(neighbors=1, metric="euclidean"), 100, 100),
116137
(GRID, KNNCover(neighbors=10, metric="euclidean"), None, 1),
117138
(GRID, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1),
139+
(GRID, StandardCubicalCover(n_intervals=2), 4, 1),
118140
(GRID, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1),
141+
(
142+
GRID,
143+
CubicalCover(n_intervals=2, overlap_frac=0.5, algorithm="proximity"),
144+
4,
145+
1,
146+
),
147+
(
148+
GRID,
149+
CubicalCover(n_intervals=2, overlap_frac=0.5, algorithm="standard"),
150+
4,
151+
1,
152+
),
119153
],
120154
)
121155
def test_cover(dataset, cover, num_charts, num_components):

tests/test_unit_heap.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,8 @@ def test_max_heap(data):
3434
for x in data:
3535
m.add(x, x)
3636
_check_heap_property(list(m))
37+
assert len(m) == len(data)
38+
if not data:
39+
assert m.is_empty()
40+
assert m.top() is None
41+
assert m.pop() is None

tests/test_unit_metrics.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Unit tests for the metrics module.
3+
"""
4+
15
import math
26

37
import numpy as np
@@ -66,8 +70,11 @@ def _check_values(m1, m2, a, b):
6670
(euclidean(), get_metric("euclidean")),
6771
(manhattan(), get_metric("manhattan")),
6872
(chebyshev(), get_metric("chebyshev")),
69-
(minkowski(p=3), get_metric("minkowski", p=3)),
73+
(manhattan(), get_metric("minkowski", p=1)),
74+
(euclidean(), get_metric("minkowski", p=2)),
7075
(minkowski(p=2.5), get_metric("minkowski", p=2.5)),
76+
(minkowski(p=3), get_metric("minkowski", p=3)),
77+
(chebyshev(), get_metric("minkowski", p=float("inf"))),
7178
(cosine(), get_metric("cosine")),
7279
],
7380
)
@@ -87,3 +94,8 @@ def test_supported_metrics():
8794
]
8895
supported_metrics = get_supported_metrics()
8996
assert set(supported_metrics) == set(expected_metrics)
97+
98+
99+
def test_non_existent_metric():
100+
with pytest.raises(ValueError):
101+
get_metric("non_existent_metric")

tests/test_unit_vptree.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from tdamapper.utils.metrics import get_metric
10+
from tdamapper.utils.vptree import VPTree
1011
from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT
1112
from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT
1213
from tests.ball_tree import SkBallTree
@@ -117,6 +118,17 @@ def _check_rec(start, end):
117118
_check_rec(0, len(data))
118119

119120

121+
@pytest.mark.parametrize("builder", [HVPT, FVPT])
122+
@pytest.mark.parametrize("dataset", [[], [1], [1, 2]])
123+
def test_vptree_small_dataset(builder, dataset):
124+
"""
125+
Test the vp-tree implementations with an empty dataset.
126+
"""
127+
vpt = builder(dataset, metric=lambda x, y: abs(x - y))
128+
array = vpt.array
129+
assert array.size() == len(dataset)
130+
131+
120132
@pytest.mark.parametrize("pivoting", ["disabled", "random", "furthest"])
121133
@pytest.mark.parametrize("eps", [0.1, 0.5])
122134
@pytest.mark.parametrize("neighbors", [2, 10])
@@ -141,6 +153,30 @@ def test_vptree(builder, dataset, metric, eps, neighbors, pivoting):
141153
_test_nn_search(dataset, metric, vpt)
142154

143155

156+
@pytest.mark.parametrize("pivoting", ["disabled", "random", "furthest"])
157+
@pytest.mark.parametrize("eps", [0.1, 0.5])
158+
@pytest.mark.parametrize("neighbors", [2, 10])
159+
@pytest.mark.parametrize("kind", ["flat", "hierarchical"])
160+
@pytest.mark.parametrize("metric", ["euclidean", "manhattan"])
161+
@pytest.mark.parametrize("dataset", [SIMPLE, TWO_LINES])
162+
def test_vptree_public(kind, dataset, metric, eps, neighbors, pivoting):
163+
"""
164+
Test the vp-tree implementations with various datasets and metrics.
165+
"""
166+
metric = get_metric(metric)
167+
vpt = VPTree(
168+
dataset,
169+
kind=kind,
170+
metric=metric,
171+
leaf_radius=eps,
172+
leaf_capacity=neighbors,
173+
pivoting=pivoting,
174+
)
175+
_test_ball_search(dataset, metric, vpt, eps)
176+
_test_knn_search(dataset, metric, vpt, neighbors)
177+
_test_nn_search(dataset, metric, vpt)
178+
179+
144180
@pytest.mark.parametrize("pivoting", ["disabled", "random", "furthest"])
145181
@pytest.mark.parametrize("eps", [0.1, 0.5])
146182
@pytest.mark.parametrize("neighbors", [2, 10])

0 commit comments

Comments
 (0)