File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed
Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -11,6 +11,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
1111- Add support for sparse cost matrices in EMD solver (PR #778 , Issue #397 )
1212- Fix deprecated JAX function in ` ot.backend.JaxBackend ` (PR #771 , Issue #770 )
1313- Add test for build from source (PR #772 , Issue #764 )
14+ - Fix device for batch Ot solver in ` ot.batch ` (PR #784 , Issue #783 )
1415
1516## 0.9.6.post1
1617
Original file line number Diff line number Diff line change @@ -310,9 +310,9 @@ def solve_batch(
310310 B , n , m = M .shape
311311
312312 if a is None :
313- a = nx .ones ((B , n )) / n
313+ a = nx .ones ((B , n ), type_as = M ) / n
314314 if b is None :
315- b = nx .ones ((B , m )) / m
315+ b = nx .ones ((B , m ), type_as = M ) / m
316316
317317 if solver == "log_sinkhorn" :
318318 K = - M / reg
You can’t perform that action at this time.
0 commit comments