Skip to content

Commit 39b30d9

Browse files
committed
change nonzeros by where
1 parent 42e3353 commit 39b30d9

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

ot/sliced.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -829,12 +829,12 @@ def dist(i, j):
829829
warnings.warn(
830830
"TensorFlow sparse indexing is limited, converting to dense"
831831
)
832-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
833-
idx_non_zeros = [nx.nonzero(plan[k]) for k in range(n_proj)]
832+
plan_dense = [nx.todense(plan[k]) for k in range(n_proj)]
833+
idx_non_zeros = [nx.where(plan_dense[k] != 0) for k in range(n_proj)]
834834
costs = [
835835
nx.sum(
836836
dist(idx_non_zeros[k][0], idx_non_zeros[k][1])
837-
* plan[k][idx_non_zeros[k][0], idx_non_zeros[k][1]]
837+
* plan_dense[k][idx_non_zeros[k][0], idx_non_zeros[k][1]]
838838
)
839839
for k in range(n_proj)
840840
]
@@ -844,8 +844,8 @@ def dist(i, j):
844844
for k in range(n_proj)
845845
]
846846

847-
if dense and not (str(nx) == "jax" or str(nx) == "tensorflow"):
848-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
847+
if dense and not (str(nx) == "jax"):
848+
plan = plan_dense.copy()
849849
elif str(nx) == "jax" and not is_perm:
850850
warnings.warn("JAX does not support sparse matrices, converting to dense")
851851
plan = [nx.todense(plan[k]) for k in range(n_proj)]
@@ -1101,7 +1101,7 @@ def expected_sliced(
11011101
>>> x=np.array([[3.,3.], [1.,1.]])
11021102
>>> y=np.array([[2.,2.5], [3.,2.]])
11031103
>>> thetas=np.array([[1, 0], [0, 1]])
1104-
>>> plan, cost = expected_sliced(x, y, thetas)
1104+
>>> plan, cost = expected_sliced(x, y, thetas=thetas)
11051105
>>> plan
11061106
[[0.25 0.25]
11071107
[0.25 0.25]]
@@ -1138,7 +1138,7 @@ def expected_sliced(
11381138

11391139
log_dict = {}
11401140
G, costs, log_dict_plans = sliced_plans(
1141-
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True
1141+
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, dense=False
11421142
)
11431143
if log:
11441144
log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G}
@@ -1154,7 +1154,6 @@ def expected_sliced(
11541154
weights = nx.ones(n_proj) / n_proj
11551155

11561156
log_dict["weights"] = weights
1157-
11581157
weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))])
11591158
X_idx = nx.concatenate([G[i].row for i in range(len(G))])
11601159
Y_idx = nx.concatenate([G[i].col for i in range(len(G))])

0 commit comments

Comments
 (0)