Skip to content

Commit dba0dec

Browse files
committed
correct bug with sparse matrices
1 parent a19753e commit dba0dec

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/sliced-wasserstein/plot_sliced_plans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
alpha = 0.3
4343

4444

45-
proj_X = X @ thetas.T
46-
proj_Y = Y @ thetas.T
45+
# proj_X = X @ thetas.T
46+
# proj_Y = Y @ thetas.T
4747

4848

4949
##############################################################################

ot/sliced.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,15 +873,15 @@ def sliced_plans(
873873
)
874874
** (1 / p)
875875
)
876-
* plan[k]
876+
* plan[k].data
877877
)
878878
for k in range(n_proj)
879879
]
880880
else: # metric == "sqeuclidean"
881881
costs = [
882882
nx.sum(
883883
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
884-
* plan[k]
884+
* plan[k].data
885885
)
886886
for k in range(n_proj)
887887
]

0 commit comments

Comments
 (0)