Skip to content

Commit 7561530

Browse files
committed
update tests with backend
1 parent 85056c6 commit 7561530

File tree

2 files changed

+146
-87
lines changed

2 files changed

+146
-87
lines changed

ot/sliced.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,6 @@ def sliced_plans(
695695
n_proj=None,
696696
dense=False,
697697
log=False,
698-
backend=None,
699698
):
700699
r"""
701700
Computes all the permutations that sort the projections of two `(n, d)`
@@ -731,9 +730,6 @@ def sliced_plans(
731730
The number of projection directions. Required if thetas is None.
732731
log : bool, optional
733732
If True, returns additional logging information. Default is False.
734-
backend : ot.backend, optional
735-
Backend to use for computations. If None, the backend is inferred from
736-
the input arrays. Default is None.
737733
738734
Returns
739735
-------
@@ -746,7 +742,17 @@ def sliced_plans(
746742
Returned only if `log` is True.
747743
"""
748744

749-
nx = get_backend(X, Y) if backend is None else backend
745+
X, Y = list_to_array(X, Y)
746+
747+
if a is not None and b is not None and thetas is None:
748+
nx = get_backend(X, Y, a, b)
749+
elif a is not None and b is not None and thetas is not None:
750+
nx = get_backend(X, Y, a, b, thetas)
751+
elif a is None and b is None and thetas is not None:
752+
nx = get_backend(X, Y, thetas)
753+
else:
754+
nx = get_backend(X, Y)
755+
750756
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
751757
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
752758

@@ -870,7 +876,6 @@ def min_pivot_sliced(
870876
dense=True,
871877
log=False,
872878
warm_theta=None,
873-
backend=None,
874879
):
875880
r"""
876881
Computes the cost and permutation associated to the min-Pivot Sliced
@@ -924,9 +929,6 @@ def min_pivot_sliced(
924929
If True, returns additional logging information. Default is False.
925930
warm_theta : array-like, shape (d,), optional
926931
A theta to add to the list of thetas. Default is None.
927-
backend : ot.backend, optional
928-
Backend to use for computations. If None, the backend is inferred from
929-
the input arrays. Default is None.
930932
931933
Returns
932934
-------
@@ -961,16 +963,24 @@ def min_pivot_sliced(
961963
2.125
962964
"""
963965

964-
nx = get_backend(X, Y) if backend is None else backend
966+
X, Y = list_to_array(X, Y)
967+
968+
if a is not None and b is not None and thetas is None:
969+
nx = get_backend(X, Y, a, b)
970+
elif a is not None and b is not None and thetas is not None:
971+
nx = get_backend(X, Y, a, b, thetas)
972+
elif a is None and b is None and thetas is not None:
973+
nx = get_backend(X, Y, thetas)
974+
else:
975+
nx = get_backend(X, Y)
976+
965977
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
966978
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
967979

968980
assert (
969981
X.shape[1] == Y.shape[1]
970982
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
971983

972-
nx = get_backend(X, Y) if backend is None else backend
973-
974984
log_dict = {}
975985
G, costs, log_dict_plans = sliced_plans(
976986
X,
@@ -983,7 +993,6 @@ def min_pivot_sliced(
983993
n_proj=n_proj,
984994
warm_theta=warm_theta,
985995
log=True,
986-
backend=nx,
987996
)
988997
pos_min = nx.argmin(costs)
989998
cost = costs[pos_min]
@@ -1024,7 +1033,6 @@ def expected_sliced(
10241033
n_proj=None,
10251034
dense=True,
10261035
log=False,
1027-
backend=None,
10281036
beta=0.0,
10291037
):
10301038
r"""
@@ -1072,9 +1080,6 @@ def expected_sliced(
10721080
format.
10731081
log : bool, optional
10741082
If True, returns additional logging information. Default is False.
1075-
backend : ot.backend, optional
1076-
Backend to use for computations. If None, the backend is inferred from
1077-
the input arrays. Default is None.
10781083
beta : float, optional
10791084
Inverse-temperature parameter which weights each projection's
10801085
contribution to the expected plan. Default is 0 (uniform weighting).
@@ -1110,7 +1115,16 @@ def expected_sliced(
11101115
2.625
11111116
"""
11121117

1113-
nx = get_backend(X, Y) if backend is None else backend
1118+
X, Y = list_to_array(X, Y)
1119+
1120+
if a is not None and b is not None and thetas is None:
1121+
nx = get_backend(X, Y, a, b)
1122+
elif a is not None and b is not None and thetas is not None:
1123+
nx = get_backend(X, Y, a, b, thetas)
1124+
elif a is None and b is None and thetas is not None:
1125+
nx = get_backend(X, Y, thetas)
1126+
else:
1127+
nx = get_backend(X, Y)
11141128

11151129
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
11161130
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
@@ -1130,7 +1144,7 @@ def expected_sliced(
11301144

11311145
log_dict = {}
11321146
G, costs, log_dict_plans = sliced_plans(
1133-
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, backend=nx
1147+
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True
11341148
)
11351149
if log:
11361150
log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G}

test/test_sliced.py

Lines changed: 113 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def test_linear_sliced_sphere_backend_type_devices(nx):
729729
np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))
730730

731731

732-
def test_sliced_permutations(nx):
732+
def test_sliced_permutations():
733733
n = 4
734734
n_proj = 10
735735
d = 2
@@ -738,15 +738,7 @@ def test_sliced_permutations(nx):
738738
x = rng.randn(n, 2)
739739
y = rng.randn(n, 2)
740740

741-
x_b, y_b = nx.from_numpy(x, y)
742741
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
743-
thetas_b = nx.from_numpy(thetas)
744-
745-
plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True)
746-
plan_b, _, _ = ot.sliced.sliced_plans(
747-
x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx
748-
)
749-
np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b))
750742

751743
# test without provided thetas
752744
_, _ = ot.sliced.sliced_plans(x, y, n_proj=n_proj)
@@ -756,7 +748,7 @@ def test_sliced_permutations(nx):
756748
ot.sliced.sliced_plans(x[:, 1:], y, thetas=thetas)
757749

758750

759-
def test_sliced_plans(nx):
751+
def test_sliced_plans():
760752
x = [1, 2]
761753
with pytest.raises(AssertionError):
762754
ot.sliced.min_pivot_sliced(x, x, n_proj=2)
@@ -775,39 +767,23 @@ def test_sliced_plans(nx):
775767
b = rng.uniform(0, 1, m)
776768
b /= b.sum()
777769

778-
x_b, y_b = nx.from_numpy(x, y)
779-
print(x_b)
780-
t_X = torch.tensor(x_b)
781-
t_Y = torch.tensor(y_b)
782770
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
783-
thetas_b = nx.from_numpy(thetas)
771+
772+
# test with a and b not uniform
773+
ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True)
784774

785775
# test with the minkowski metric
786-
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski")
776+
ot.sliced.sliced_plans(x, y, thetas=thetas, metric="minkowski")
787777

788778
# test with an unsupported metric
789779
with pytest.raises(ValueError):
790-
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis")
780+
ot.sliced.sliced_plans(x, y, thetas=thetas, metric="mahalanobis")
791781

792782
# test with a warm theta
793-
ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1])
783+
ot.sliced.sliced_plans(x, y, n_proj=10, warm_theta=thetas[-1])
794784

795-
# test with a and b uniform
796-
plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True)
797-
plan_b, _, _ = ot.sliced.sliced_plans(
798-
x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx
799-
)
800-
np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b))
801785

802-
# test with a and b not uniform
803-
plan, _ = ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True)
804-
plan_b, _, _ = ot.sliced.sliced_plans(
805-
x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True, backend=nx
806-
)
807-
np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b))
808-
809-
810-
def test_min_pivot_sliced(nx):
786+
def test_min_pivot_sliced():
811787
x = [1, 2]
812788
with pytest.raises(AssertionError):
813789
ot.sliced.min_pivot_sliced(x, x, n_proj=2)
@@ -825,17 +801,13 @@ def test_min_pivot_sliced(nx):
825801
b = rng.uniform(0, 1, m)
826802
b /= b.sum()
827803

828-
x_b, y_b = nx.from_numpy(x, y)
829804
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
830-
thetas_b = nx.from_numpy(thetas)
831805

832-
G, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True)
833-
G_b, min_cost_b, _ = ot.sliced.min_pivot_sliced(
834-
x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True
835-
)
806+
# identity of the indiscernibles
807+
_, min_cost = ot.min_pivot_sliced(x, x, a, a, n_proj=10)
808+
np.testing.assert_almost_equal(min_cost, 0.0)
836809

837-
np.testing.assert_almost_equal(G, nx.to_numpy(G_b))
838-
np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b))
810+
_, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True)
839811

840812
# result should be an upper-bound of W2 and relatively close
841813
w2 = ot.emd2(a, b, ot.dist(x, y))
@@ -849,8 +821,30 @@ def test_min_pivot_sliced(nx):
849821
with pytest.raises(AssertionError):
850822
ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas)
851823

824+
# test the logs
825+
_, min_cost, log = ot.sliced.min_pivot_sliced(
826+
x, y, a, b, thetas=thetas, dense=False, log=True
827+
)
828+
assert len(log) == 5
829+
costs = log["costs"]
830+
assert len(costs) == thetas.shape[0]
831+
assert len(log["min_theta"]) == d
832+
assert (log["thetas"] == thetas).all()
833+
for c in costs:
834+
assert c > 0
835+
836+
# test with the minkowski metric
837+
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski")
838+
839+
# test with an unsupported metric
840+
with pytest.raises(ValueError):
841+
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis")
852842

853-
def test_expected_sliced(nx):
843+
# test with a warm theta
844+
ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1])
845+
846+
847+
def test_expected_sliced():
854848
x = [1, 2]
855849
with pytest.raises(AssertionError):
856850
ot.sliced.min_pivot_sliced(x, x, n_proj=2)
@@ -868,9 +862,67 @@ def test_expected_sliced(nx):
868862
b = rng.uniform(0, 1, m)
869863
b /= b.sum()
870864

871-
x_b, y_b = nx.from_numpy(x, y)
872865
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
873-
thetas_b = nx.from_numpy(thetas)
866+
867+
_, expected_cost = ot.sliced.expected_sliced(x, y, a, b, dense=True, thetas=thetas)
868+
# result should be a coarse upper-bound of W2
869+
w2 = ot.emd2(a, b, ot.dist(x, y))
870+
assert expected_cost >= w2
871+
assert expected_cost <= 3 * w2
872+
873+
# test without provided thetas
874+
ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True)
875+
876+
# test with invalid shapes
877+
with pytest.raises(AssertionError):
878+
ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas)
879+
880+
# with a small temperature (i.e. large beta), the cost should be close
881+
# to min_pivot
882+
_, expected_cost = ot.sliced.expected_sliced(
883+
x, y, a, b, thetas=thetas, dense=True, beta=100.0
884+
)
885+
_, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True)
886+
np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)
887+
888+
# test the logs
889+
_, min_cost, log = ot.sliced.expected_sliced(
890+
x, y, a, b, thetas=thetas, dense=False, log=True
891+
)
892+
assert len(log) == 4
893+
costs = log["costs"]
894+
assert len(costs) == thetas.shape[0]
895+
assert len(log["weights"]) == thetas.shape[0]
896+
assert (log["thetas"] == thetas).all()
897+
for c in costs:
898+
assert c > 0
899+
900+
# test with the minkowski metric
901+
ot.sliced.expected_sliced(x, y, thetas=thetas, metric="minkowski")
902+
903+
# test with an unsupported metric
904+
with pytest.raises(ValueError):
905+
ot.sliced.expected_sliced(x, y, thetas=thetas, metric="mahalanobis")
906+
907+
908+
def test_sliced_plans_backends(nx):
909+
n = 10
910+
m = 24
911+
n_proj = 10
912+
d = 2
913+
rng = np.random.RandomState(0)
914+
915+
x = rng.randn(n, 2)
916+
y = rng.randn(m, 2)
917+
a = rng.uniform(0, 1, n)
918+
a /= a.sum()
919+
b = rng.uniform(0, 1, m)
920+
b /= b.sum()
921+
922+
x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b)
923+
924+
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
925+
thetas_t = nx.to_numpy(thetas)
874926

875927
context = (
876928
nullcontext()
@@ -879,32 +931,25 @@ def test_expected_sliced(nx):
879931
)
880932

881933
with context:
882-
expected_plan, expected_cost = ot.sliced.expected_sliced(
883-
x, y, a, b, dense=True, thetas=thetas
934+
_, expected_cost_b = ot.sliced.expected_sliced(
935+
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t
884936
)
885-
expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced(
886-
x_b, y_b, a, b, thetas=thetas_b, dense=True, log=True
937+
# result should be the same than numpy version
938+
_, expected_cost = ot.sliced.expected_sliced(
939+
x, y, a, b, dense=True, thetas=thetas
887940
)
941+
np.testing.assert_almost_equal(expected_cost_b, expected_cost)
888942

889-
np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b))
890-
np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b))
891-
892-
# result should be a coarse upper-bound of W2
893-
w2 = ot.emd2(a, b, ot.dist(x, y))
894-
assert expected_cost >= w2
895-
assert expected_cost <= 3 * w2
896-
897-
# test without provided thetas
898-
ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True)
943+
# for min_pivot
944+
_, min_cost_b = ot.sliced.min_pivot_sliced(
945+
x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t
946+
)
947+
# result should be the same than numpy version
948+
_, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, dense=True, thetas=thetas)
949+
np.testing.assert_almost_equal(min_cost_b, min_cost)
899950

900-
# test with invalid shapes
901-
with pytest.raises(AssertionError):
902-
ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas)
951+
# for sliced_plans
952+
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
903953

904-
# with a small temperature (i.e. large beta), the cost should be close
905-
# to min_pivot
906-
_, expected_cost = ot.sliced.expected_sliced(
907-
x, y, a, b, thetas=thetas, dense=True, beta=100.0
908-
)
909-
_, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True)
910-
np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)
954+
# test with the minkowski metric
955+
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski")

0 commit comments

Comments
 (0)