Skip to content

Commit 282ac99

Browse files
committed
update sliced plans with sparse matrix for tf compatibility
1 parent 58f7e11 commit 282ac99

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

ot/sliced.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,8 @@ def sliced_plans(
825825
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
826826
)
827827

828-
plan = plan.tocsr().tocoo() # especially for tensorflow compatibility
828+
if str(nx) == "tensorflow": # tf does not support duplicate entries
829+
plan = [plan[k].tocsr().tocoo() for k in range(n_proj)]
829830

830831
if str(nx) == "jax":
831832
plan = [nx.todense(plan[k]) for k in range(n_proj)]
@@ -853,26 +854,30 @@ def sliced_plans(
853854
)
854855
for k in range(n_proj)
855856
]
856-
857-
if metric in ("minkowski", "euclidean", "cityblock"):
858-
costs = [
859-
nx.sum(
860-
(
861-
(nx.sum(nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1))
862-
** (1 / p)
857+
else:
858+
if metric in ("minkowski", "euclidean", "cityblock"):
859+
costs = [
860+
nx.sum(
861+
(
862+
(
863+
nx.sum(
864+
nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1
865+
)
866+
)
867+
** (1 / p)
868+
)
869+
* plan[k].data
863870
)
864-
* plan[k].data
865-
)
866-
for k in range(n_proj)
867-
]
868-
else: # metric = "sqeuclidean"
869-
costs = [
870-
nx.sum(
871-
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
872-
* plan[k].data
873-
)
874-
for k in range(n_proj)
875-
]
871+
for k in range(n_proj)
872+
]
873+
else: # metric = "sqeuclidean"
874+
costs = [
875+
nx.sum(
876+
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
877+
* plan[k].data
878+
)
879+
for k in range(n_proj)
880+
]
876881

877882
if dense:
878883
plan = [nx.todense(plan[k]) for k in range(n_proj)]

test/test_sliced.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def test_sliced_plans():
776776
ot.sliced.sliced_plans(x, y, thetas=thetas, metric="minkowski")
777777

778778
# test with an unsupported metric
779-
with pytest.raises(ValueError):
779+
with pytest.raises(AssertionError):
780780
ot.sliced.sliced_plans(x, y, thetas=thetas, metric="mahalanobis")
781781

782782
# test with a warm theta
@@ -798,7 +798,7 @@ def test_sliced_plans():
798798
b /= b.sum()
799799

800800
# test with the minkowski metric
801-
ot.sliced.sliced_plans(x, y, metric="minkowski")
801+
ot.sliced.sliced_plans(x, y, n_proj=10, metric="minkowski")
802802

803803

804804
def test_min_pivot_sliced():
@@ -857,7 +857,7 @@ def test_min_pivot_sliced():
857857
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="cityblock")
858858

859859
# test with an unsupported metric
860-
with pytest.raises(ValueError):
860+
with pytest.raises(AssertionError):
861861
ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis")
862862

863863
# test with a warm theta
@@ -922,7 +922,7 @@ def test_expected_sliced():
922922
ot.sliced.expected_sliced(x, y, thetas=thetas, metric="minkowski")
923923

924924
# test with an unsupported metric
925-
with pytest.raises(ValueError):
925+
with pytest.raises(AssertionError):
926926
ot.sliced.expected_sliced(x, y, thetas=thetas, metric="mahalanobis")
927927

928928

0 commit comments

Comments
 (0)