Skip to content

Commit a8f365c

Browse files
committed
Added parametric tests
1 parent cf9e14d commit a8f365c

File tree

3 files changed

+108
-106
lines changed

3 files changed

+108
-106
lines changed

tests/test_unit_cover.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Unit tests for the cover algorithms.
33
"""
44

5-
import numpy as np
65
import pytest
76

87
from tdamapper.core import TrivialCover
@@ -14,51 +13,7 @@
1413
StandardCubicalCover,
1514
)
1615
from tdamapper.utils.unionfind import UnionFind
17-
18-
19-
def dataset_simple():
20-
"""
21-
Create a simple dataset of points in a 2D space.
22-
23-
This dataset consists of four points forming the corners of a rectangle
24-
such that two sides are longer than the other two.
25-
"""
26-
return [
27-
np.array([0.0, 1.0]),
28-
np.array([1.1, 0.0]),
29-
np.array([0.0, 0.0]),
30-
np.array([1.1, 1.0]),
31-
]
32-
33-
34-
def dataset_random(dim=1, num=1000):
35-
"""
36-
Create a random dataset of `num` points drawn uniformly from the unit hypercube [0, 1)^dim.
37-
"""
38-
return [np.random.rand(dim) for _ in range(num)]
39-
40-
41-
def dataset_two_lines(num=1000):
42-
"""
43-
Create a dataset consisting of two lines in the unit square.
44-
Both lines are horizontal: one at y=0, the other at y=1.
45-
"""
46-
t = np.linspace(0.0, 1.0, num)
47-
line1 = np.array([[x, 0.0] for x in t])
48-
line2 = np.array([[x, 1.0] for x in t])
49-
return np.concatenate((line1, line2), axis=0)
50-
51-
52-
def dataset_grid(num=1000):
53-
"""
54-
Create a grid dataset in the unit square.
55-
The grid consists of points evenly spaced in both dimensions.
56-
"""
57-
t = np.linspace(0.0, 1.0, num)
58-
s = np.linspace(0.0, 1.0, num)
59-
grid = np.array([[x, y] for x in t for y in s])
60-
return grid
61-
16+
from tests.test_utils import dataset_grid, dataset_simple, dataset_two_lines
6217

6318
SIMPLE = dataset_simple()
6419

tests/test_unit_vptree.py

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,58 @@
1+
"""
2+
Unit tests for the vp-tree implementations.
3+
"""
4+
15
import random
26

37
import numpy as np
8+
import pytest
49

510
from tdamapper.utils.metrics import get_metric
611
from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT
712
from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT
813
from tests.ball_tree import SkBallTree
14+
from tests.test_utils import (
15+
dataset_grid,
16+
dataset_random,
17+
dataset_simple,
18+
dataset_two_lines,
19+
)
20+
21+
22+
def distance(metric):
23+
"""
24+
Get the distance function for the specified metric.
25+
"""
26+
return get_metric(metric)
27+
928

10-
distance = get_metric("euclidean")
29+
def distance_refs(metric, data):
30+
"""
31+
Get the distance function for the specified metric, using data references.
32+
This is useful for testing trees built over integer indices into a 2D numpy array (rows are points) rather than over the points themselves.
33+
"""
34+
d = get_metric(metric)
1135

36+
def dist_refs(i, j):
37+
return d(data[i, :], data[j, :])
38+
39+
return dist_refs
1240

13-
def dataset(dim=10, num=1000):
14-
return [np.random.rand(dim) for _ in range(num)]
1541

42+
SIMPLE = dataset_simple()
43+
SIMPLE_REFS = np.array(list(range(len(SIMPLE))))
1644

17-
eps = 0.25
45+
TWO_LINES = dataset_two_lines(100)
46+
TWO_LINES_REFS = np.array(list(range(len(TWO_LINES))))
1847

19-
neighbors = 5
48+
GRID = dataset_grid(10)
49+
GRID_REFS = np.array(list(range(len(GRID))))
2050

51+
RANDOM = dataset_random(2, 100)
52+
RANDOM_REFS = np.array(list(range(len(RANDOM))))
2153

22-
def _test_ball_search(data, dist, vpt):
54+
55+
def _test_ball_search(data, dist, vpt, eps):
2356
for _ in range(len(data) // 10):
2457
point = random.choice(data)
2558
ball = vpt.ball_search(point, eps)
@@ -31,7 +64,7 @@ def _test_ball_search(data, dist, vpt):
3164
assert any(d(x, y) == 0.0 for y in ball)
3265

3366

34-
def _test_knn_search(data, dist, vpt):
67+
def _test_knn_search(data, dist, vpt, neighbors):
3568
for _ in range(len(data) // 10):
3669
point = random.choice(data)
3770
neigh = vpt.knn_search(point, neighbors)
@@ -54,65 +87,28 @@ def _test_nn_search(data, dist, vpt):
5487
assert 0.0 == d(val, neigh[0])
5588

5689

57-
def _test_vptree(builder, data, dist):
58-
vpt = builder(data, metric=dist, leaf_radius=eps, leaf_capacity=neighbors)
59-
_test_ball_search(data, dist, vpt)
60-
_test_knn_search(data, dist, vpt)
61-
_test_nn_search(data, dist, vpt)
62-
vpt = builder(
63-
data,
64-
metric=dist,
65-
leaf_radius=eps,
66-
leaf_capacity=neighbors,
67-
pivoting="random",
68-
)
69-
_test_ball_search(data, dist, vpt)
70-
_test_knn_search(data, dist, vpt)
71-
_test_nn_search(data, dist, vpt)
90+
def _test_vptree(builder, data, dist, eps, neighbors, pivoting):
7291
vpt = builder(
7392
data,
7493
metric=dist,
7594
leaf_radius=eps,
7695
leaf_capacity=neighbors,
77-
pivoting="furthest",
96+
pivoting=pivoting,
7897
)
79-
_test_ball_search(data, dist, vpt)
80-
_test_knn_search(data, dist, vpt)
98+
_test_ball_search(data, dist, vpt, eps)
99+
_test_knn_search(data, dist, vpt, neighbors)
81100
_test_nn_search(data, dist, vpt)
82101

83102

84-
def test_vptree_hier_refs():
85-
data = dataset()
86-
data_refs = list(range(len(data)))
87-
d = get_metric(distance)
88-
89-
def dist_refs(i, j):
90-
return d(data[i], data[j])
91-
92-
_test_vptree(HVPT, data_refs, dist_refs)
93-
94-
95-
def test_vptree_hier_data():
96-
data = dataset()
97-
_test_vptree(HVPT, data, distance)
98-
99-
100-
def test_vptree_flat_refs():
101-
data = dataset()
102-
data_refs = list(range(len(data)))
103-
d = get_metric(distance)
104-
105-
def dist_refs(i, j):
106-
return d(data[i], data[j])
107-
108-
_test_vptree(FVPT, data_refs, dist_refs)
109-
110-
111-
def test_vptree_flat_data():
112-
data = dataset()
113-
_test_vptree(FVPT, data, distance)
114-
115-
116-
def test_ball_tree_data():
117-
data = dataset()
118-
_test_vptree(SkBallTree, data, distance)
103+
@pytest.mark.parametrize("pivoting", ["disabled", "random", "furthest"])
104+
@pytest.mark.parametrize("eps", [0.1, 0.25, 0.5])
105+
@pytest.mark.parametrize("neighbors", [1, 5, 10])
106+
@pytest.mark.parametrize("builder", [HVPT, FVPT])
107+
@pytest.mark.parametrize("metric", ["euclidean"])
108+
@pytest.mark.parametrize("dataset", [SIMPLE, TWO_LINES, GRID, RANDOM])
109+
def test_vptree(builder, dataset, metric, eps, neighbors, pivoting):
110+
"""
111+
Test the vp-tree implementations with various datasets and metrics.
112+
"""
113+
metric = get_metric(metric)
114+
_test_vptree(builder, dataset, metric, eps, neighbors, pivoting)

tests/test_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
Test utilities.
3+
"""
4+
5+
import numpy as np
6+
7+
8+
def dataset_simple():
9+
"""
10+
Create a simple dataset of points in a 2D space.
11+
12+
This dataset consists of four points forming the corners of a rectangle
13+
such that two sides are longer than the other two.
14+
"""
15+
return np.array(
16+
[
17+
[0.0, 1.0],
18+
[1.1, 0.0],
19+
[0.0, 0.0],
20+
[1.1, 1.0],
21+
]
22+
)
23+
24+
25+
def dataset_random(dim=1, num=1000):
26+
"""
27+
Create a random dataset of `num` points drawn uniformly from the unit hypercube [0, 1)^dim.
28+
"""
29+
return np.array([np.random.rand(dim) for _ in range(num)])
30+
31+
32+
def dataset_two_lines(num=1000):
33+
"""
34+
Create a dataset consisting of two lines in the unit square.
35+
Both lines are horizontal: one at y=0, the other at y=1.
36+
"""
37+
t = np.linspace(0.0, 1.0, num)
38+
line1 = np.array([[x, 0.0] for x in t])
39+
line2 = np.array([[x, 1.0] for x in t])
40+
return np.concatenate((line1, line2), axis=0)
41+
42+
43+
def dataset_grid(num=1000):
44+
"""
45+
Create a grid dataset in the unit square.
46+
The grid consists of points evenly spaced in both dimensions.
47+
"""
48+
t = np.linspace(0.0, 1.0, num)
49+
s = np.linspace(0.0, 1.0, num)
50+
grid = np.array([[x, y] for x in t for y in s])
51+
return grid

0 commit comments

Comments
 (0)