Labels: bug
Description
Required prerequisites
- I have read the documentation https://torchopt.readthedocs.io.
- I have searched the Issue Tracker and Discussions to confirm this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
What version of TorchOpt are you using?
0.7.3
System information
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] on Linux
TorchOpt 0.7.3, PyTorch 2.1.0, functorch 2.1.0
Problem description
I tried to recreate the jaxopt example for root finding with implicit differentiation using a very simple iterative solver. In JAX, it just works. With TorchOpt, however, the gradient is zero for scalar inputs.
The problem stems from:
torchopt/torchopt/linalg/cg.py, lines 120 to 122 in a4cfc49:

if maxiter is None:
    size = sum(cat_shapes(b))
    maxiter = 10 * size  # copied from SciPy
Here, size becomes zero for a scalar b, and maxiter is wrongly set to 0. The same piece of code in JAX produces a size of 1.
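A minimal illustration of why (assuming cat_shapes concatenates the shapes of the tensors in b): a 0-d tensor has an empty shape, so summing its entries gives 0, whereas counting elements, which is what JAX's size effectively does, gives 1.

import torch

b = torch.tensor(2.0)  # 0-d (scalar) tensor
print(b.shape)         # torch.Size([]) -- an empty shape, so sum(b.shape) == 0
print(sum(b.shape))    # 0 -> maxiter = 10 * 0 = 0, and CG never iterates
print(b.numel())       # 1 -- the element count, which would give maxiter = 10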
Reproducible example code
The Python snippet:
import torch
import torchopt


# https://jaxopt.github.io/stable/root_finding.html
def F(x, factor):
    return factor * x**3 - x - 2


@torchopt.diff.implicit.custom_root(
    F, argnums=(1,), solve=torchopt.linear_solve.solve_cg()
)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1
    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad
    return x


def wrapper(fac):
    return custom_root_solver(init_x, fac)


init_x = torch.tensor(1.0)
fac = torch.tensor(2.0, requires_grad=True)

root = wrapper(fac)
root_grad = torch.autograd.grad(root, fac)
print(root_grad)
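My understanding (not verified against the internals) is that custom_root's backward pass solves a linear system whose right-hand side matches the shape of the solver output, so a 0-d root hands a 0-d b to CG. The solver can also be probed directly; this sketch assumes solve_cg() returns a callable taking (matvec, b):

import torch
import torchopt

# Hypothetical probe: solve 3 * x = 6 with a 0-d right-hand side.
def matvec(x):
    return 3.0 * x

b = torch.tensor(6.0)  # scalar b -> sum(cat_shapes(b)) == 0 -> maxiter == 0
x = torchopt.linear_solve.solve_cg()(matvec, b)
print(x)  # presumably the untouched zero initial guess rather than 2.0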
Traceback
No response
Expected behavior
No response
Additional context
It works upon making the tensors 1D:
init_x = torch.tensor([1.0])
fac = torch.tensor([2.0], requires_grad=True)
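A possible fix (just a sketch of the idea, not a patch; the real code would still need to handle pytrees of tensors via cat_shapes) would be to derive the default maxiter from the element count, which is 1 for scalars, matching JAX:

import torch

def default_maxiter(b: torch.Tensor) -> int:
    # numel() is 1 for a 0-d tensor, unlike sum(b.shape) == 0,
    # so a scalar b gets maxiter = 10 instead of 0.
    return 10 * b.numel()  # same 10x heuristic copied from SciPy

print(default_maxiter(torch.tensor(2.0)))    # 10
print(default_maxiter(torch.tensor([2.0])))  # 10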
It just works in JAX:
import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root
from jaxopt import Bisection

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
    return factor * x ** 3 - x - 2


def bisection_root_solver(init_x, factor):
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1
    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad
    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))
print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))