Skip to content
16 changes: 16 additions & 0 deletions ext/LinearSolveMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,20 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
end
end

function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
f.fields.A .+= t.A
f.fields.b .+= t.b
f.fields.u .+= t.u

return NoRData()
end

# rrules for LinearCache
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode

# rrule for solve!
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode

end
78 changes: 78 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,81 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end

function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
assump = OperatorAssumptions(issquare(prob.A))
alg = defaultalg(prob.A, prob.b, assump)
CRC.rrule(T, prob, alg, args...; kwargs...)
end

function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...)
init_res = LinearSolve.init(prob, alg)
function init_adjoint(∂init)
∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
return NoTangent(), ∂prob, NoTangent(), ntuple((_ -> NoTangent(), length(args))...)
end

return init_res, init_adjoint
end

function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...)
assump = OperatorAssumptions()
alg = defaultalg(cache.A, cache.b, assump)
CRC.rrule(T, cache, alg, args...; kwargs)
end

function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A(
alg, cache.A, cache.b), kwargs...)
(; A, sensealg) = cache
@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve
if sensealg.linsolve === missing
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
alg isa DefaultLinearSolver)
A_ = alias_A ? deepcopy(A) : A
end
else
A_ = deepcopy(A)
end

sol = solve!(cache)
function solve!_adjoint(∂sol)
∂∅ = NoTangent()
∂u = ∂sol.u

if sensealg.linsolve === missing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(adjoint(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

tu = adjoint(sol.u)
∂A = BroadcastArray(@~ .-(λ .* tu))
∂b = λ

if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
end

∂prob = LinearProblem(∂A, ∂b, ∂∅)
∂cache = LinearSolve.init(∂prob, u=∂u)
return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
end

return sol, solve!_adjoint
end
132 changes: 132 additions & 0 deletions test/nopre/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,135 @@ for alg in (
@test results[1] ≈ fA(A)
@test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
f, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f2(A, b1, b2; alg=RFLUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f2(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
f2, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f3(A, b1, b2; alg=KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f3(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
f3, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2 atol = 5e-5
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f4(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
solve!(cache)
s1 = copy(cache.u)
cache.b = b2
solve!(cache)
s2 = copy(cache.u)
norm(s1 + s2)
end

A = rand(n, n);
b1 = rand(n);
b2 = rand(n);
# f_primal = f4(copy(A), copy(b1), copy(b2))

rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
rule, 1.0,
f4, copy(A), copy(b1), copy(b2)
)

# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

function fnice(A, b, alg)
prob = LinearProblem(A, b)
sol1 = solve(prob, alg)
return sum(sol1.u)
end

@testset for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES()
)
# for B
fb_closure = b -> fnice(A, b, alg)
fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

val, en_jac = Mooncake.value_and_gradient!!(
prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
fnice, copy(A), copy(b1), alg
)
@test en_jac[3] ≈ fd_jac_b rtol = 1e-5

# For A
fA_closure = A -> fnice(A, b1, alg)
fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
A_grad = en_jac[2] |> vec
@test A_grad ≈ fd_jac_A rtol = 1e-5
end
Loading