[BUG] Custom root fails for scalar inputs (jaxopt example) #222

@marvinfriede

Description

Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] on linux
TorchOpt 0.7.3, PyTorch 2.1.0, functorch 2.1.0

Problem description

I tried to recreate the jaxopt example for root finding with implicit differentiation, using a very simple iterative solver. In JAX, it just works. With TorchOpt, however, the gradient is zero for scalar inputs.

The problem stems from:

```python
if maxiter is None:
    size = sum(cat_shapes(b))
    maxiter = 10 * size  # copied from SciPy
```

Here, `size` becomes zero for a scalar `b`, so `maxiter` is wrongly set to 0. The same piece of code in JAX produces a size of 1.
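
For illustration, this is how the shape sum behaves for a 0-dim tensor (a minimal sketch, assuming `cat_shapes` reduces to summing the leaf shapes when `b` is a single tensor):

```python
import torch

b = torch.tensor(2.0)   # 0-dim (scalar) tensor
print(b.shape)          # torch.Size([]) -- an empty shape
print(sum(b.shape))     # 0 -> maxiter = 10 * 0 = 0, so the CG loop never runs
print(b.numel())        # 1 -> the element count that jax's `size` reports
```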

Reproducible example code

The Python snippet:

```python
import torch
import torchopt


# https://jaxopt.github.io/stable/root_finding.html
def F(x, factor):
    return factor * x**3 - x - 2


@torchopt.diff.implicit.custom_root(
    F, argnums=(1,), solve=torchopt.linear_solve.solve_cg()
)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad

    return x


def wrapper(fac):
    return custom_root_solver(init_x, fac)


init_x = torch.tensor(1.0)
fac = torch.tensor(2.0, requires_grad=True)

root = wrapper(fac)
root_grad = torch.autograd.grad(root, fac)
print(root_grad)
```

Traceback

No response

Expected behavior

No response

Additional context

It works upon making the tensors 1D:

```python
init_x = torch.tensor([1.0])
fac = torch.tensor([2.0], requires_grad=True)
```
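
Presumably this helps because a length-1 tensor has a non-empty shape, so the shape sum is 1 and `maxiter` becomes 10:

```python
import torch

b = torch.tensor([2.0])  # 1-D tensor of length 1
print(sum(b.shape))      # 1 -> maxiter = 10 * 1 = 10, so the solver actually iterates
```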

It just works in JAX:

```python
import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root
from jaxopt import Bisection

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
  return factor * x ** 3 - x - 2


def bisection_root_solver(init_x, factor):
  bisec = Bisection(optimality_fun=F, lower=1, upper=2)
  return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad

    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))
```
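
For reference, a minimal sketch of one possible fix for the default quoted above (`default_maxiter` is a hypothetical helper, not TorchOpt's actual API): count elements with `numel()`, the way JAX's `size` does, instead of summing the shape dims.

```python
import torch


def default_maxiter(b: torch.Tensor, maxiter: int | None = None) -> int:
    """Hypothetical stand-in for the maxiter default quoted above."""
    if maxiter is None:
        # Count elements (numel), not shape dims, so a 0-dim tensor gives 1.
        size = b.numel()
        maxiter = 10 * size  # same "10 * n" heuristic copied from SciPy
    return maxiter


print(default_maxiter(torch.tensor(2.0)))    # 10 (the shape-sum default yields 0 here)
print(default_maxiter(torch.tensor([2.0])))  # 10
```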
