|
47 | 47 | end |
48 | 48 | end |
49 | 49 |
|
| 50 | +# This rule behaves much like the getindex adjoint, |
| 51 | +# just with an (internal) ordinal index instead of a key. |
| 52 | +function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i) |
| 53 | + iter = iterate(d, i) |
| 54 | + function dict_iterate_pullback(Δ) |
| 55 | + (iter === nothing || Δ === nothing) && return |
| 56 | + k, v = iter[1] |
| 57 | + _, dv = Δ[1] |
| 58 | + accum_param(cx, v, dv) === nothing && return |
| 59 | + grad = grad_mut(cx, d) |
| 60 | + grad[k] = accum(get(grad, k, nothing), dv) |
| 61 | + return (nothing, grad, nothing) |
| 62 | + end |
| 63 | + return iter, dict_iterate_pullback |
| 64 | +end |
| 65 | + |
| 66 | +# ...while this one is to avoid duplicating code or differentiating skip_deleted. |
| 67 | +# The alternative would be to write a rule for the private _iterate(::Dict, i). |
| 68 | +function _pullback(cx::AContext, ::typeof(iterate), d::Dict) |
| 69 | + # Calculation of i is the same used in iterate(::Dict) |
| 70 | + return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor)) |
| 71 | +end |
| 72 | + |
| 73 | +function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int) |
| 74 | + iter = iterate(vi, i) |
| 75 | + function values_iterate_pullback(Δ) |
| 76 | + (iter === nothing || Δ === nothing) && return |
| 77 | + v, dv = iter[1], Δ[1] |
| 78 | + accum_param(cx, v, dv) === nothing && return |
| 79 | + # Same as vi.dict.keys[i], but without reaching into Dict internals. |
| 80 | + # Iterating the dict instead of keys() is to hit the rules above in nested AD. |
| 81 | + k = iterate(vi.dict, i)[1][1] |
| 82 | + grad = grad_mut(cx, vi.dict) |
| 83 | + grad[k] = accum(get(grad, k, nothing), dv) |
| 84 | + return (nothing, (; dict = grad), nothing) |
| 85 | + end |
| 86 | + return iter, values_iterate_pullback |
| 87 | +end |
| 88 | + |
50 | 89 | # Channels |
51 | 90 |
|
52 | 91 | grad_mut(ch::Channel) = Channel(ch.sz_max) |
|
0 commit comments