1+ """
2+ Unit tests for the vp-tree implementations.
3+ """
4+
15import random
26
37import numpy as np
8+ import pytest
49
510from tdamapper .utils .metrics import get_metric
611from tdamapper .utils .vptree_flat .vptree import VPTree as FVPT
712from tdamapper .utils .vptree_hier .vptree import VPTree as HVPT
813from tests .ball_tree import SkBallTree
14+ from tests .test_utils import (
15+ dataset_grid ,
16+ dataset_random ,
17+ dataset_simple ,
18+ dataset_two_lines ,
19+ )
20+
21+
22+ def distance (metric ):
23+ """
24+ Get the distance function for the specified metric.
25+ """
26+ return get_metric (metric )
27+
928
10- distance = get_metric ("euclidean" )
29+ def distance_refs (metric , data ):
30+ """
31+ Get the distance function for the specified metric, using data references.
32+ This is useful for testing with datasets that are not numpy arrays.
33+ """
34+ d = get_metric (metric )
1135
36+ def dist_refs (i , j ):
37+ return d (data [i , :], data [j , :])
38+
39+ return dist_refs
1240
13- def dataset (dim = 10 , num = 1000 ):
14- return [np .random .rand (dim ) for _ in range (num )]
1541
42+ SIMPLE = dataset_simple ()
43+ SIMPLE_REFS = np .array (list (range (len (SIMPLE ))))
1644
17- eps = 0.25
45+ TWO_LINES = dataset_two_lines (100 )
46+ TWO_LINES_REFS = np .array (list (range (len (TWO_LINES ))))
1847
19- neighbors = 5
48+ GRID = dataset_grid (10 )
49+ GRID_REFS = np .array (list (range (len (GRID ))))
2050
51+ RANDOM = dataset_random (2 , 100 )
52+ RANDOM_REFS = np .array (list (range (len (RANDOM ))))
2153
22- def _test_ball_search (data , dist , vpt ):
54+
55+ def _test_ball_search (data , dist , vpt , eps ):
2356 for _ in range (len (data ) // 10 ):
2457 point = random .choice (data )
2558 ball = vpt .ball_search (point , eps )
@@ -31,7 +64,7 @@ def _test_ball_search(data, dist, vpt):
3164 assert any (d (x , y ) == 0.0 for y in ball )
3265
3366
34- def _test_knn_search (data , dist , vpt ):
67+ def _test_knn_search (data , dist , vpt , neighbors ):
3568 for _ in range (len (data ) // 10 ):
3669 point = random .choice (data )
3770 neigh = vpt .knn_search (point , neighbors )
@@ -54,65 +87,28 @@ def _test_nn_search(data, dist, vpt):
5487 assert 0.0 == d (val , neigh [0 ])
5588
5689
57- def _test_vptree (builder , data , dist ):
58- vpt = builder (data , metric = dist , leaf_radius = eps , leaf_capacity = neighbors )
59- _test_ball_search (data , dist , vpt )
60- _test_knn_search (data , dist , vpt )
61- _test_nn_search (data , dist , vpt )
62- vpt = builder (
63- data ,
64- metric = dist ,
65- leaf_radius = eps ,
66- leaf_capacity = neighbors ,
67- pivoting = "random" ,
68- )
69- _test_ball_search (data , dist , vpt )
70- _test_knn_search (data , dist , vpt )
71- _test_nn_search (data , dist , vpt )
90+ def _test_vptree (builder , data , dist , eps , neighbors , pivoting ):
7291 vpt = builder (
7392 data ,
7493 metric = dist ,
7594 leaf_radius = eps ,
7695 leaf_capacity = neighbors ,
77- pivoting = "furthest" ,
96+ pivoting = pivoting ,
7897 )
79- _test_ball_search (data , dist , vpt )
80- _test_knn_search (data , dist , vpt )
98+ _test_ball_search (data , dist , vpt , eps )
99+ _test_knn_search (data , dist , vpt , neighbors )
81100 _test_nn_search (data , dist , vpt )
82101
83102
84- def test_vptree_hier_refs ():
85- data = dataset ()
86- data_refs = list (range (len (data )))
87- d = get_metric (distance )
88-
89- def dist_refs (i , j ):
90- return d (data [i ], data [j ])
91-
92- _test_vptree (HVPT , data_refs , dist_refs )
93-
94-
95- def test_vptree_hier_data ():
96- data = dataset ()
97- _test_vptree (HVPT , data , distance )
98-
99-
100- def test_vptree_flat_refs ():
101- data = dataset ()
102- data_refs = list (range (len (data )))
103- d = get_metric (distance )
104-
105- def dist_refs (i , j ):
106- return d (data [i ], data [j ])
107-
108- _test_vptree (FVPT , data_refs , dist_refs )
109-
110-
111- def test_vptree_flat_data ():
112- data = dataset ()
113- _test_vptree (FVPT , data , distance )
114-
115-
116- def test_ball_tree_data ():
117- data = dataset ()
118- _test_vptree (SkBallTree , data , distance )
103+ @pytest .mark .parametrize ("pivoting" , ["disabled" , "random" , "furthest" ])
104+ @pytest .mark .parametrize ("eps" , [0.1 , 0.25 , 0.5 ])
105+ @pytest .mark .parametrize ("neighbors" , [1 , 5 , 10 ])
106+ @pytest .mark .parametrize ("builder" , [HVPT , FVPT ])
107+ @pytest .mark .parametrize ("metric" , ["euclidean" ])
108+ @pytest .mark .parametrize ("dataset" , [SIMPLE , TWO_LINES , GRID , RANDOM ])
109+ def test_vptree (builder , dataset , metric , eps , neighbors , pivoting ):
110+ """
111+ Test the vp-tree implementations with various datasets and metrics.
112+ """
113+ metric = get_metric (metric )
114+ _test_vptree (builder , dataset , metric , eps , neighbors , pivoting )
0 commit comments