diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index d30e1cf31..6575510c5 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -15,11 +15,16 @@ jobs: steps: - uses: actions/checkout@v4 # Standard drop-in approach that should work for most people. - + - name: Free Disk Space (Ubuntu) + uses: insightsengineering/disk-space-reclaimer@v1 + with: + android: true + dotnet: true - name: Set up Python 3.10 uses: actions/setup-python@v5 with: python-version: "3.10" + cache: 'pip' - name: Get Python running run: | diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 35a26fc6b..28f5edbb3 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -64,11 +64,17 @@ jobs: python-version: ["3.10", "3.11", "3.12", "3.13"] steps: + - name: Free Disk Space (Ubuntu) + uses: insightsengineering/disk-space-reclaimer@v1 + with: + android: true + dotnet: true - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + cache: 'pip' - name: Install POT run: | pip install -e . @@ -93,6 +99,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.13" + cache: 'pip' - name: Install dependencies run: | python -m pip install --upgrade pip setuptools @@ -121,6 +128,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + cache: 'pip' - name: Install POT run: | pip install -e . @@ -148,6 +156,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + cache: 'pip' - name: RC.exe run: | function Invoke-VSDevEnvironment { diff --git a/RELEASES.md b/RELEASES.md index 4d73da648..ad75950eb 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver. - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) - Add test for build from source (PR #772, Issue #764) +- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783) ## 0.9.6.post1 diff --git a/ot/batch/_linear.py b/ot/batch/_linear.py index d6ac6cea4..a63fcb404 100644 --- a/ot/batch/_linear.py +++ b/ot/batch/_linear.py @@ -310,9 +310,9 @@ def solve_batch( B, n, m = M.shape if a is None: - a = nx.ones((B, n)) / n + a = nx.ones((B, n), type_as=M) / n if b is None: - b = nx.ones((B, m)) / m + b = nx.ones((B, m), type_as=M) / m if solver == "log_sinkhorn": K = -M / reg