Commit 5c80f55

Merge pull request #1248 from FluxML/bc/no-cache-context
Improved type stability with explicit params
2 parents b9530c7 + 3433cdd commit 5c80f55

6 files changed: +109 additions, -25 deletions
src/compiler/interface.jl

Lines changed: 29 additions & 8 deletions
@@ -4,11 +4,14 @@ using Core: Typeof
 import Base: copy!, IdSet
 import Base.Broadcast: broadcasted, materialize!
 
-mutable struct Context <: AContext
+# Internal container used to track accumulated gradients of mutable types (including params).
+# Type param I ∈ (true, false) indicates whether implicit params are in use.
+# By default, this should be false unless pullback(f, ::Params) is called.
+mutable struct Context{I} <: AContext
   cache::Union{IdDict{Any,Any},Nothing}
 end
 
-Context() = Context(nothing)
+Context() = Context{false}(nothing)
 
 cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache
 
@@ -36,10 +39,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
 tailmemaybe(::Nothing) = nothing
 tailmemaybe(x::Tuple) = Base.tail(x)
 
-function pullback(f, args...)
-  y, back = _pullback(f, args...)
+@inline pullback(f, args...) = pullback(f, Context(), args...)
+function pullback(f, cx::AContext, args...)
+  y, back = _pullback(cx, f, args...)
   y, Δ -> tailmemaybe(back(Δ))
 end
+function pullback(cx::Context, f, args...)
+  ChainRulesCore.ignore_derivatives() do
+    @warn """
+      Incorrect argument order for pullback, please use:
+
+        pullback(f, __context__::Context, args)
+
+      instead of:
+
+        pullback(__context__::Context, f, args)
+
+      This is usually caused by a call to pullback in a higher-order @adjoint.
+      The above warning will become an error in Zygote 0.7.
+    """
+  end
+  return pullback(f, cx, args...)
+end
 
 sensitivity(y::Number) = one(y)
 sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
@@ -334,21 +355,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
 end
 
 function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
-  all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
+  all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
     throw(ArgumentError("map! expects Grads objects with the same Params."))
   for p in gsout.params
-    gsout[p] = f((_getformap(gs, p) for gs in gss)...)
+    gsout[p] = f((_getformap(gs, p) for gs in gss)...)
   end
   return gsout
 end
 
 function _getformap(gs, p)
   g = gs[p]
-  isnothing(g) ? fill!(similar(p), 0) : g
+  isnothing(g) ? fill!(similar(p), 0) : g
 end
 
 function pullback(f, ps::Params)
-  cx = Context()
+  cx = Context{true}(nothing)
   y, back = _pullback(cx, f)
   y, function (Δ)
     for p in ps
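
For orientation, a minimal sketch (not part of this commit; the input values and printed gradients are illustrative) of how the two paths above behave from the caller's side:

    using Zygote
    using Zygote: Context, pullback

    # Explicit arguments: Context() now defaults to Context{false}(nothing),
    # so accum_param can bail out early and the pullback stays type stable.
    y, back = pullback(x -> sum(abs2, x), [1.0, 2.0, 3.0])
    back(1.0)   # ([2.0, 4.0, 6.0],)

    # Implicit parameters: pullback(f, ::Params) builds Context{true}(nothing),
    # so gradients of globals are accumulated into the IdDict cache.
    w = [1.0, 2.0, 3.0]
    gs = gradient(() -> sum(abs2, w), Params([w]))
    gs[w]       # [2.0, 4.0, 6.0]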

src/lib/array.jl

Lines changed: 2 additions & 2 deletions
@@ -310,15 +310,15 @@ end
 
 @adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
   @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
-  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+  return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
 end
 
 @adjoint function sum(xs::AbstractArray{Bool}; dims = :)
   sum(xs, dims = dims), Δ -> (nothing,)
 end
 
 function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
-  y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
+  y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
   y, ȳ -> (nothing, back(ȳ)...)
 end
 
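
The same argument order applies to downstream code: inside a higher-order `@adjoint`, the closure now comes first and `__context__` second. A hypothetical rule sketch (the `mymap` function is made up for illustration and is not part of Zygote):

    using Zygote
    using Zygote: @adjoint, pullback

    mymap(f, xs) = map(f, xs)

    @adjoint function mymap(f, xs)
      # closure first, context second, then the actual arguments
      y, back = pullback((f, xs) -> map(f, xs), __context__, f, xs)
      y, ȳ -> back(ȳ)
    end

    gradient(xs -> sum(mymap(abs2, xs)), [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],)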

src/lib/broadcast.jl

Lines changed: 8 additions & 4 deletions
@@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
 # Utilities
 # =========
 
+# ChainRules already marks this non-differentiable,
+# But inference can still give up because of the Zygote -> CR wrapper layer
+@nograd Broadcast.combine_styles
+
 accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)
 
 # Work around reducedim_init issue
@@ -82,16 +86,16 @@ _minus(::Nothing) = nothing
 @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
   Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
 @adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
-  _pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
+  _pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
 @adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
-  _pullback(*, x, y)
+  _pullback(__context__, *, x, y)
 
 @adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
   res = x ./ y
   res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
 end
 @adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
-  _pullback(/, x, y)
+  _pullback(__context__, /, x, y)
 
 @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
   y = Base.literal_pow.(^, x, exp)
@@ -273,7 +277,7 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
 # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
 @adjoint function sum(f, xs::AbstractGPUArray; kws...)
   @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
-  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+  return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
 end
 
 @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
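
As a quick sanity check on the scalar-times-array rule touched above (numbers worked out by hand, not taken from the test suite):

    using Zygote

    # d/dx sum(x .* y) = sum(y);  d/dy sum(x .* y) = fill(x, size(y))
    gradient((x, y) -> sum(x .* y), 3.0, [1.0, 2.0])  # (3.0, [3.0, 3.0])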

src/lib/lib.jl

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ accum(x, y) =
 
 accum(x, y, zs...) = accum(accum(x, y), zs...)
 
-accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
+accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
 accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
 
 @generated function accum(x::NamedTuple, y::NamedTuple)
@@ -48,6 +48,7 @@ end
 
 @adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
 
+accum_param(::Context{false}, _, Δ) = Δ
 @generated function accum_param(cx::Context, x, Δ)
   isbitstype(x) && return :(Δ)
   quote
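
For context, `accum` merges gradient fragments and treats `nothing` as a neutral element; the switch from broadcasting to `map` sidesteps the tuple broadcasting machinery, presumably to help inference. A short sketch using the internal `Zygote.accum` (not public API):

    using Zygote: accum

    accum((1.0, nothing), (2.0, 3.0))  # (3.0, 3.0)
    accum(nothing, 4.0)                # 4.0
    accum(1.0, 2.0, 3.0)               # 6.0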

test/compiler.jl

Lines changed: 11 additions & 6 deletions
@@ -1,5 +1,5 @@
 using Zygote, Test
-using Zygote: pullback, @adjoint
+using Zygote: pullback, @adjoint, Context
 
 macro test_inferred(ex)
   :(let res = nothing
@@ -160,13 +160,18 @@ end
 @testset "inference for `getproperty`" begin
   Gaussian = _Gaussian(:getproperty)
   g = Gaussian(randn(3), randn(3, 3))
-  y, back = @inferred pullback(x -> x.m, g)
-  @test y == getfield(g, :m)
-  # This type instability is due to the handling of non-bitstypes in `accum_param`
+  y_explicit, back_explicit = @inferred pullback(x -> x.m, g)
+  y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g)
+  @test y_explicit == y_implicit == getfield(g, :m)
+
+  ∇args = ((m = [1.0, 0.0, 0.0], P = nothing),)
   if VERSION > v"1.7-"
-    @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
+    # This type instability is due to the handling of non-bitstypes in `accum_param`
+    @test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}]
+    # But the same should infer if implicit parameters are disabled
+    @test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)]
   end
-  @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
+  @test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args
 
   Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
   y, back = pullback(x -> x.m, g)

test/features.jl

Lines changed: 57 additions & 4 deletions
@@ -476,7 +476,7 @@ end
   @test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
   @test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20
 
-  @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0
+  @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41
 
   # Array of mutables:
   @test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@@ -490,6 +490,59 @@ end
   @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
 end
 
+@testset "mutable accum_param bugs" begin
+  mutable struct Mut{T}; x::T; end
+  struct Imm{T}; x::T; end
+
+  # Indexing a tuple containing a mutable struct gave `nothing`
+  x1 = (Mut(3.0),)
+  x2 = (Imm(3.0),)
+  x3 = (Ref(3.0),)
+  @test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),)
+  @test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  i1 = 1
+  @test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),)
+  @test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+
+  @test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)]
+  @test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+
+  # When `getfield` returns a mutable struct, it gave `nothing`:
+  x4 = Imm(Mut(4.0))
+  x5 = Mut(Mut(4.0))
+  x6 = Imm(Imm(4.0))
+  @test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0
+  @test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41
+
+  @test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0
+  @test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41
+
+  # Check when using implicit parameters, Params cases used to pass:
+  y1 = [3.0]
+  y2 = (Mut(y1),)
+  y3 = (Imm(y1),)
+  @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0]
+  @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),)
+  @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0]
+
+  @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+  @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+  @test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+
+  i1 = 1
+  @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+  @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+  @test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+end
+
 @testset "NamedTuples" begin
   @test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
   @test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)
@@ -517,7 +570,7 @@ end
   @test (x->10*(x => 2)[2])'(100) === nothing
 
   @test gradient(x-> (:x => x)[2], 17) == (1,)
-
+
   d = Dict(:x=>1.0, :y=>3.0);
   @test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
 end
@@ -546,7 +599,7 @@ end
   # zip
   if VERSION >= v"1.5"
     # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch,
-    # while on 1.5 - 1.7 it stops early.
+    # while on 1.5 - 1.7 it stops early.
 
     @test gradient(10:14, 1:10) do xs, ys
       sum([x/y for (x,y) in zip(xs, ys)])
@@ -608,7 +661,7 @@ end
 
   # Iterators.Product with enumerate
   @test gradient([2 3; 4 5]) do xs
-    sum([x^i+y for (i,x) in enumerate(xs), y in xs])
+    sum([x^i+y for (i,x) in enumerate(xs), y in xs])
   end == ([8 112; 36 2004],)
 end
