Skip to content

Commit 5b9d51c

Browse files
committed
modif dense vs coo_matrix
1 parent 858ef9f commit 5b9d51c

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

ot/sliced.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ def sliced_plans(
766766
assert (
767767
X.shape[1] == Y.shape[1]
768768
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
769+
769770
if metric == "euclidean":
770771
p = 2
771772
elif metric == "cityblock":
@@ -818,13 +819,18 @@ def dist(i, j):
818819
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
819820
)
820821

821-
if str(nx) == "jax": # dense computation for jax
822+
if str(nx) == "jax" or str(nx) == "tensorflow":
822823
if not dense:
823-
warnings.warn(
824-
"JAX does not support sparse matrices, converting to dense"
825-
)
824+
if str(nx) == "jax":
825+
warnings.warn(
826+
"JAX does not support sparse matrices, converting to dense"
827+
)
828+
else:
829+
warnings.warn(
830+
"TensorFlow sparse indexing is limited, converting to dense"
831+
)
826832
plan = [nx.todense(plan[k]) for k in range(n_proj)]
827-
idx_non_zeros = [np.nonzero(plan[k]) for k in range(n_proj)]
833+
idx_non_zeros = [nx.nonzero(plan[k]) for k in range(n_proj)]
828834
costs = [
829835
nx.sum(
830836
dist(idx_non_zeros[k][0], idx_non_zeros[k][1])
@@ -833,25 +839,22 @@ def dist(i, j):
833839
for k in range(n_proj)
834840
]
835841
else:
836-
if str(nx) == "tensorflow": # tf does not support multiple indexing
837-
plan = [plan[k].tocsr().tocoo() for k in range(n_proj)]
838-
839842
costs = [
840843
nx.sum(dist(plan[k].row, plan[k].col) * plan[k].data)
841844
for k in range(n_proj)
842845
]
843846

844-
if dense and not str(nx) == "jax":
847+
if dense and not (str(nx) == "jax" or str(nx) == "tensorflow"):
845848
plan = [nx.todense(plan[k]) for k in range(n_proj)]
846-
elif str(nx) == "jax":
849+
elif str(nx) == "jax" and not is_perm:
847850
warnings.warn("JAX does not support sparse matrices, converting to dense")
848851
plan = [nx.todense(plan[k]) for k in range(n_proj)]
849852

850853
if log:
851854
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}
852-
return plan, costs, log_dict
855+
return plan, nx.stack(costs), log_dict
853856
else:
854-
return plan, costs
857+
return plan, nx.stack(costs)
855858

856859

857860
def min_pivot_sliced(
@@ -945,7 +948,7 @@ def min_pivot_sliced(
945948
>>> x=np.array([[3.,3.], [1.,1.]])
946949
>>> y=np.array([[2.,2.5], [3.,2.]])
947950
>>> thetas=np.array([[1, 0], [0, 1]])
948-
>>> plan, cost = ot.expected_sliced(x, y, thetas)
951+
>>> plan, cost = min_pivot_sliced(x, y, thetas)
949952
>>> plan
950953
[[0 0.5]
951954
[0.5 0]]
@@ -984,6 +987,7 @@ def min_pivot_sliced(
984987
warm_theta=warm_theta,
985988
log=True,
986989
)
990+
987991
pos_min = nx.argmin(costs)
988992
cost = costs[pos_min]
989993
plan = G[pos_min]
@@ -1097,7 +1101,7 @@ def expected_sliced(
10971101
>>> x=np.array([[3.,3.], [1.,1.]])
10981102
>>> y=np.array([[2.,2.5], [3.,2.]])
10991103
>>> thetas=np.array([[1, 0], [0, 1]])
1100-
>>> plan, cost = ot.expected_sliced(x, y, thetas)
1104+
>>> plan, cost = expected_sliced(x, y, thetas)
11011105
>>> plan
11021106
[[0.25 0.25]
11031107
[0.25 0.25]]

0 commit comments

Comments
 (0)