Skip to content

Commit 2f4b675

Browse files
committed
update tests with backend and improve code coverage
1 parent 7561530 commit 2f4b675

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/test_sliced.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,10 @@ def test_min_pivot_sliced():
833833
for c in costs:
834834
assert c > 0
835835

836-
# test with the minkowski metric
836+
# test with different metrics
837837
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski")
838+
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="euclidean")
839+
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="cityblock")
838840

839841
# test with an unsupported metric
840842
with pytest.raises(ValueError):
@@ -872,6 +874,7 @@ def test_expected_sliced():
872874

873875
# test without provided thetas
874876
ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True)
877+
ot.sliced.expected_sliced(x, y, a, b, n_proj=n_proj, log=True)
875878

876879
# test with invalid shapes
877880
with pytest.raises(AssertionError):
@@ -921,8 +924,8 @@ def test_sliced_plans_backends(nx):
921924

922925
x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b)
923926

924-
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
925-
thetas_t = nx.to_numpy(thetas)
927+
thetas_b = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
928+
thetas = nx.to_numpy(thetas_b)
926929

927930
context = (
928931
nullcontext()
@@ -932,7 +935,7 @@ def test_sliced_plans_backends(nx):
932935

933936
with context:
934937
_, expected_cost_b = ot.sliced.expected_sliced(
935-
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t
938+
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b
936939
)
937940
# result should be the same than numpy version
938941
_, expected_cost = ot.sliced.expected_sliced(
@@ -942,7 +945,7 @@ def test_sliced_plans_backends(nx):
942945

943946
# for min_pivot
944947
_, min_cost_b = ot.sliced.min_pivot_sliced(
945-
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t
948+
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b
946949
)
947950
# result should be the same than numpy version
948951
_, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, dense=True, thetas=thetas)

0 commit comments

Comments
 (0)