@@ -833,8 +833,10 @@ def test_min_pivot_sliced():
833833 for c in costs :
834834 assert c > 0
835835
836- # test with the minkowski metric
836+ # test with different metrics
837837 ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "minkowski" )
838+ ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "euclidean" )
839+ ot .sliced .min_pivot_sliced (x , y , thetas = thetas , metric = "cityblock" )
838840
839841 # test with an unsupported metric
840842 with pytest .raises (ValueError ):
@@ -872,6 +874,7 @@ def test_expected_sliced():
872874
873875 # test without provided thetas
874876 ot .sliced .expected_sliced (x , y , n_proj = n_proj , log = True )
877+ ot .sliced .expected_sliced (x , y , a , b , n_proj = n_proj , log = True )
875878
876879 # test with invalid shapes
877880 with pytest .raises (AssertionError ):
@@ -921,8 +924,8 @@ def test_sliced_plans_backends(nx):
921924
922925 x_b , y_b , a_b , b_b = nx .from_numpy (x , y , a , b )
923926
924- thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 , backend = nx ).T
925- thetas_t = nx .to_numpy (thetas )
927+ thetas_b = ot .sliced .get_random_projections (d , n_proj , seed = 0 , backend = nx ).T
928+ thetas = nx .to_numpy (thetas_b )
926929
927930 context = (
928931 nullcontext ()
@@ -932,7 +935,7 @@ def test_sliced_plans_backends(nx):
932935
933936 with context :
934937 _ , expected_cost_b = ot .sliced .expected_sliced (
935- x_b , y_b , a_b , b_b , dense = True , thetas = thetas_t
938+ x_b , y_b , a_b , b_b , dense = True , thetas = thetas_b
936939 )
937940 # result should be the same than numpy version
938941 _ , expected_cost = ot .sliced .expected_sliced (
@@ -942,7 +945,7 @@ def test_sliced_plans_backends(nx):
942945
943946 # for min_pivot
944947 _ , min_cost_b = ot .sliced .min_pivot_sliced (
945- x_b , y_b , a_b , b_b , dense = True , thetas = thetas_t
948+ x_b , y_b , a_b , b_b , dense = True , thetas = thetas_b
946949 )
947950 # result should be the same than numpy version
948951 _ , min_cost = ot .sliced .min_pivot_sliced (x , y , a , b , dense = True , thetas = thetas )
0 commit comments