Skip to content

Commit 55f861c

Browse files
committed
fix device
1 parent 9412193 commit 55f861c

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
1111
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
1212
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
1313
- Add test for build from source (PR #772, Issue #764)
14+
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)
1415

1516
## 0.9.6.post1
1617

ot/batch/_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,9 @@ def solve_batch(
310310
B, n, m = M.shape
311311

312312
if a is None:
313-
a = nx.ones((B, n)) / n
313+
a = nx.ones((B, n), type_as=M) / n
314314
if b is None:
315-
b = nx.ones((B, m)) / m
315+
b = nx.ones((B, m), type_as=M) / m
316316

317317
if solver == "log_sinkhorn":
318318
K = -M / reg

0 commit comments

Comments
 (0)