From cd01b79c397ee468951c575b0252f14d07361b6b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 20 Dec 2020 02:18:09 -0800 Subject: [PATCH 1/4] Add ChainRulesCore as dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index c51aa40..dc8fc02 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Chris Rackauckas "] version = "1.8.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] +ChainRulesCore = "0.9" Requires = "1.0" julia = "1" From f45d87072ba09487a217dd6e065289daae829d7c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 20 Dec 2020 02:18:28 -0800 Subject: [PATCH 2/4] Add chainrules for expv --- src/ExponentialUtilities.jl | 3 ++- src/krylov_phiv_chainrules.jl | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/krylov_phiv_chainrules.jl diff --git a/src/ExponentialUtilities.jl b/src/ExponentialUtilities.jl index 7a584b1..4cddebf 100644 --- a/src/ExponentialUtilities.jl +++ b/src/ExponentialUtilities.jl @@ -1,5 +1,5 @@ module ExponentialUtilities -using LinearAlgebra, SparseArrays, Printf, Requires +using LinearAlgebra, SparseArrays, Printf, Requires, ChainRulesCore """ @diagview(A,d) -> view of the `d`th diagonal of `A`. @@ -20,6 +20,7 @@ include("krylov_phiv_adaptive.jl") include("kiops.jl") include("StegrWork.jl") include("krylov_phiv_error_estimate.jl") +include("krylov_phiv_chainrules.jl") export phi, phi!, KrylovSubspace, arnoldi, arnoldi!, lanczos!, ExpvCache, PhivCache, expv, expv!, exp_generic, phiv, phiv!, kiops, expv_timestep, expv_timestep!, phiv_timestep, phiv_timestep!, diff --git a/src/krylov_phiv_chainrules.jl b/src/krylov_phiv_chainrules.jl new file mode 100644 index 0000000..f44e6f0 --- /dev/null +++ b/src/krylov_phiv_chainrules.jl @@ -0,0 +1,42 @@ +function ChainRulesCore.frule((_, Δt, ΔA, Δb), ::typeof(expv), t, A, b; kwargs...) + w = expv(t, A, b; kwargs...) + ∂w = similar(w) + mul!(∂w, A, w) + ∂w .*= Δt + if !isa(Δb, AbstractZero) + ∂w .+= expv(t, A, Δb; kwargs...) + end + # TODO: handle ΔA + ΔA isa AbstractZero || error("ΔA currently cannot be pushed forward") + return w, ∂w +end + +function ChainRulesCore.rrule(::typeof(expv), t, A, b; kwargs...) + w = expv(t, A, b; kwargs...) + function expv_pullback(Δw) + ∂t = Thunk() do + t̄ = A isa AbstractMatrix ? conj(dot(Δw, A, w)) : dot(mul!(similar(w), A, w), Δw) + return t isa Real ? real(t̄) : t̄ + end + # TODO: handle ∂A + ∂A = @thunk error("Adjoint wrt A not yet implemented") + ∂b = Thunk() do + # using similar is necessary to ensure type-stability + b̄ = similar(b) + _copyto!(b̄, expv(t', A', Δw; kwargs...)) + return b̄ + end + return (NO_FIELDS, ∂t, ∂A, ∂b) + end + expv_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + return w, expv_pullback +end + +function _copyto!(x, y) + if eltype(x) <: Real && !(eltype(y) <: Real) + x .= real.(y) + else + copyto!(x, y) + end + return x +end From 48e6b65f13f60953a9a420302a649e52ed2b494c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 20 Dec 2020 02:18:54 -0800 Subject: [PATCH 3/4] Test new rules --- Project.toml | 4 +++- test/runtests.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dc8fc02..9d29ebe 100644 --- a/Project.toml +++ b/Project.toml @@ -12,14 +12,16 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] ChainRulesCore = "0.9" +FiniteDifferences = "0.11" Requires = "1.0" julia = "1" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ForwardDiff", "Test", "SafeTestsets", "Random"] +test = ["FiniteDifferences", "ForwardDiff", "Test", "SafeTestsets", "Random"] diff --git a/test/runtests.jl b/test/runtests.jl index 4172aff..63afa48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test, LinearAlgebra, Random, SparseArrays, ExponentialUtilities using ExponentialUtilities: getH, getV, _exp! +using FiniteDifferences using ForwardDiff @testset "Exp" begin @@ -228,3 +229,50 @@ struct OpnormFunctor end @test pv ≈ pv′ atol=1e-12 end end + +@testset "expv chain rules" begin + n = 30 + @testset "frule for T=$T" for T in (Float64, ComplexF64) + t = rand(T) + A = randn(T, n, n) + b = randn(T, n) + Δt = FiniteDifferences.rand_tangent(t) + Δb = FiniteDifferences.rand_tangent(b) + + w = expv(t, A, b) + w_ad, ∂w_ad = frule((NO_FIELDS, Δt, Zero(), Δb), expv, t, A, b) + @test w_ad == w + ∂w_fd = jvp(central_fdm(5, 1), (t, b) -> expv(t, A, b), (t, Δt), (b, Δb)) + @test ∂w_ad ≈ ∂w_fd + + w_ad, ∂w_ad = frule((NO_FIELDS, Δt, Zero(), Zero()), expv, t, A, b) + @test w_ad == w + ∂w_fd = jvp(central_fdm(5, 1), t -> expv(t, A, b), (t, Δt)) + @test ∂w_ad ≈ ∂w_fd + + ΔA = FiniteDifferences.rand_tangent(A) + @test_throws ErrorException frule((NO_FIELDS, Δt, ΔA, Δb), expv, t, A, b) + end + + @testset "rrule for T=$T" for T in (Float64, ComplexF64) + t = rand(T) + A = randn(T, n, n) + b = randn(T, n) + w = expv(t, A, b) + Δw = FiniteDifferences.rand_tangent(w) + + w_ad, back = rrule(expv, t, A, b) + @test w_ad == w + ∂self, ∂t_ad, ∂A_ad, ∂b_ad = @inferred back(Δw) + @test ∂self === NO_FIELDS + @test @inferred(extern(∂t_ad)) isa typeof(t) + @test @inferred(extern(∂b_ad)) isa typeof(b) + + ∂t_fd, ∂A_fd, ∂b_fd = j′vp(central_fdm(5, 1), expv, Δw, t, A, b) + @test extern(∂t_ad) ≈ ∂t_fd + @test extern(∂b_ad) ≈ ∂b_fd + @test_throws ErrorException unthunk(∂A_ad) + + @test @inferred(back(Zero())) === (NO_FIELDS, Zero(), Zero(), Zero()) + end +end From b47753973ab06eb7c113020fe3881dd60c9959a9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 20 Dec 2020 02:21:35 -0800 Subject: [PATCH 4/4] Load ChainRulesCore for tests --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 63afa48..fa9e841 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Test, LinearAlgebra, Random, SparseArrays, ExponentialUtilities using ExponentialUtilities: getH, getV, _exp! -using FiniteDifferences +using ChainRulesCore, FiniteDifferences using ForwardDiff @testset "Exp" begin