77import pytest
88
99from tdamapper .utils .metrics import get_metric
10+ from tdamapper .utils .vptree import VPTree
1011from tdamapper .utils .vptree_flat .vptree import VPTree as FVPT
1112from tdamapper .utils .vptree_hier .vptree import VPTree as HVPT
1213from tests .ball_tree import SkBallTree
@@ -117,6 +118,17 @@ def _check_rec(start, end):
117118 _check_rec (0 , len (data ))
118119
119120
121+ @pytest .mark .parametrize ("builder" , [HVPT , FVPT ])
122+ @pytest .mark .parametrize ("dataset" , [[], [1 ], [1 , 2 ]])
123+ def test_vptree_small_dataset (builder , dataset ):
124+ """
125+ Test the vp-tree implementations with an empty dataset.
126+ """
127+ vpt = builder (dataset , metric = lambda x , y : abs (x - y ))
128+ array = vpt .array
129+ assert array .size () == len (dataset )
130+
131+
120132@pytest .mark .parametrize ("pivoting" , ["disabled" , "random" , "furthest" ])
121133@pytest .mark .parametrize ("eps" , [0.1 , 0.5 ])
122134@pytest .mark .parametrize ("neighbors" , [2 , 10 ])
@@ -141,6 +153,30 @@ def test_vptree(builder, dataset, metric, eps, neighbors, pivoting):
141153 _test_nn_search (dataset , metric , vpt )
142154
143155
156+ @pytest .mark .parametrize ("pivoting" , ["disabled" , "random" , "furthest" ])
157+ @pytest .mark .parametrize ("eps" , [0.1 , 0.5 ])
158+ @pytest .mark .parametrize ("neighbors" , [2 , 10 ])
159+ @pytest .mark .parametrize ("kind" , ["flat" , "hierarchical" ])
160+ @pytest .mark .parametrize ("metric" , ["euclidean" , "manhattan" ])
161+ @pytest .mark .parametrize ("dataset" , [SIMPLE , TWO_LINES ])
162+ def test_vptree_public (kind , dataset , metric , eps , neighbors , pivoting ):
163+ """
164+ Test the vp-tree implementations with various datasets and metrics.
165+ """
166+ metric = get_metric (metric )
167+ vpt = VPTree (
168+ dataset ,
169+ kind = kind ,
170+ metric = metric ,
171+ leaf_radius = eps ,
172+ leaf_capacity = neighbors ,
173+ pivoting = pivoting ,
174+ )
175+ _test_ball_search (dataset , metric , vpt , eps )
176+ _test_knn_search (dataset , metric , vpt , neighbors )
177+ _test_nn_search (dataset , metric , vpt )
178+
179+
144180@pytest .mark .parametrize ("pivoting" , ["disabled" , "random" , "furthest" ])
145181@pytest .mark .parametrize ("eps" , [0.1 , 0.5 ])
146182@pytest .mark .parametrize ("neighbors" , [2 , 10 ])
0 commit comments