
Differentiating Implicit function inside implicit function with Enzyme #153

@benjaminfaber


I've run into an issue where I want to compute the Jacobian of an implicit function that itself depends on another implicit function. I can do this successfully with ForwardDiff, but I would like to use a faster backend such as Enzyme. When I run the MWE below with Enzyme, I get a segmentation fault. I think this is related to the fact that one needs to use autodiff_deferred here, but ADTypes doesn't seem to have a backend option for autodiff_deferred yet. Should I open an issue there, or is this something that can be implemented in ImplicitDifferentiation/DifferentiationInterface?
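For context, here is a sketch of what differentiating the nested construction involves, assuming the standard implicit function theorem (the notation is ours, not the package's): the inner solution z(x) is defined by c_1(x, z) = 0, and the outer solution y(x) by c_2(x, y, z(x)) = 0, so the chain rule through z enters the outer derivative.

```latex
\begin{aligned}
c_1\bigl(x, z(x)\bigr) = 0
&\;\Rightarrow\;
\frac{\partial z}{\partial x} = -\left(\frac{\partial c_1}{\partial z}\right)^{-1} \frac{\partial c_1}{\partial x}, \\
c_2\bigl(x, y(x), z(x)\bigr) = 0
&\;\Rightarrow\;
\frac{\partial y}{\partial x} = -\left(\frac{\partial c_2}{\partial y}\right)^{-1}
\left(\frac{\partial c_2}{\partial x} + \frac{\partial c_2}{\partial z}\,\frac{\partial z}{\partial x}\right)
\end{aligned}
```

In the MWE, c2 does not take z as an argument; it recomputes z = implicit_f1(x, method) internally. Whatever backend differentiates c2 with respect to x therefore also has to differentiate through implicit_f1. With Enzyme, that means an Enzyme call nested inside another Enzyme call, which is presumably why deferred mode comes up.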

Tagging @just-walk

using ImplicitDifferentiation, Optim, Enzyme, ForwardDiff

function f1(x, method)
    f(y) = sum(abs2, y .^ 2 .- x)
    y0 = ones(eltype(x), size(x))
    result = optimize(f, y0, method)
    return Optim.minimizer(result)
end;

function c1(x, y, method)
    ∇₂f = @. 4 * (y^2 - x) * y
    return ∇₂f
end;

implicit_f1 = ImplicitFunction(f1, c1)

function f2(x, method)
    z = implicit_f1(x, method)
    f(y) = sum(abs2, y .^ 2 .- z .* x)
    y0 = ones(eltype(x), size(x))
    result = optimize(f, y0, method)
    return Optim.minimizer(result)
end

function c2(x, y, method)
    z = implicit_f1(x, method)
    ∇₂f = @. 4 * (y^2 - z * x) * y
    return ∇₂f
end

implicit_f2 = ImplicitFunction(f2, c2)

x = [4., 9.]

dx = ([1., 0.], [0., 1.])

# Works
df1 = Enzyme.autodiff(Enzyme.Forward, implicit_f1, BatchDuplicatedNoNeed, BatchDuplicated(x, dx), Const(LBFGS()))[1]

┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
((var"1" = (var"1" = [0.25000000000806183, 0.0], var"2" = [0.0, 0.16666666667288246]),)

df2 = ForwardDiff.jacobian(_x -> implicit_f2(_x, LBFGS()), x)
2×2 Matrix{Float64}:
 0.53033  0.0
 0.0      0.433013

# Does not work

df2 = autodiff(Forward, implicit_f2, BatchDuplicatedNoNeed, BatchDuplicated(x, dx), Const(LBFGS()))[1]
[32468] signal (11.1): Segmentation fault
in expression starting at /home/bfaber/TempPkg/src/diff_test.jl:44
gc_mark_obj16 at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:1894 [inlined]
gc_mark_outrefs at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:2654 [inlined]
gc_mark_loop_serial_ at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:2697
gc_mark_loop_serial at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:2720
gc_mark_loop at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:2901 [inlined]
_jl_gc_collect at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:3234
ijl_gc_collect at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:3531
maybe_collect at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:937 [inlined]
jl_gc_pool_alloc_inner at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:1300
jl_gc_pool_alloc_noinline at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:1357 [inlined]
jl_gc_alloc_ at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/julia_internal.h:476 [inlined]
jl_gc_alloc at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/gc.c:3583
_new_array_ at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/array.c:134 [inlined]
_new_array at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/array.c:198 [inlined]
ijl_alloc_array_1d at /cache/build/builder-amdci4-0/julialang/julia-release-1-dot-10/src/array.c:436
Array at ./boot.jl:477 [inlined]
.
.
.
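For comparison, the Jacobian that the failing Enzyme call should reproduce can be checked by hand. The sketch below is plain Julia with no AD. It assumes the optimizer converges to the positive roots z = sqrt(x) and y = x^(3/4), which are derived by hand and not stated in the report. The helper names (solve1, solve2, dsolve1, dsolve2) are hypothetical. It applies the implicit function theorem to c1 and then, chaining through z, to c2.

```julia
# Hand-coded reference for the Jacobians in this issue (illustrative sketch,
# not part of ImplicitDifferentiation.jl). All operations are elementwise,
# so each Jacobian is diagonal and only its diagonal is computed.

# Assumed closed-form solutions of the inner problems (positive roots,
# matching the y0 = ones starting point in the MWE):
#   f1: minimize sum((y.^2 .- x).^2)      =>  z = sqrt.(x)
#   f2: minimize sum((y.^2 .- z .* x).^2) =>  y = x.^(3/4)
solve1(x) = sqrt.(x)
solve2(x) = x .^ (3 / 4)

# Implicit function theorem for c1(x, z) = 4 (z^2 - x) z = 0:
#   dc1/dz = 4 (3z^2 - x),  dc1/dx = -4z,  dz/dx = -(dc1/dx) / (dc1/dz)
function dsolve1(x)
    z = solve1(x)
    dc_dz = @. 4 * (3 * z^2 - x)
    dc_dx = @. -4 * z
    return @. -dc_dx / dc_dz
end

# Nested case: c2(x, y, z) = 4 (y^2 - z x) y = 0 with z = solve1(x), so
#   dy/dx = -(dc2/dx + dc2/dz * dz/dx) / (dc2/dy)
function dsolve2(x)
    z = solve1(x)
    y = solve2(x)
    dz_dx = dsolve1(x)
    dc_dy = @. 4 * (3 * y^2 - z * x)
    dc_dx = @. -4 * z * y
    dc_dz = @. -4 * x * y
    return @. -(dc_dx + dc_dz * dz_dx) / dc_dy
end

x = [4.0, 9.0]
println(dsolve1(x))  # diagonal of the implicit_f1 Jacobian
println(dsolve2(x))  # diagonal of the implicit_f2 Jacobian
```

Both results should agree with the diagonals reported above: 0.25 and 1/6 for implicit_f1, and 0.53033 and 0.433013 for implicit_f2, since (3/4) x^(-1/4) gives those values at x = 4 and x = 9.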

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests