Skip to content

Commit 83cdacc

Browse files
committed
Add rule for Dict iteration
1 parent 99d5a38 commit 83cdacc

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

src/lib/base.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,45 @@ end
4747
end
4848
end
4949

50+
# This rule behaves much like the getindex adjoint,
51+
# just with an (internal) ordinal index instead of a key.
52+
function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i)
53+
iter = iterate(d, i)
54+
function dict_iterate_pullback(Δ)
55+
(iter === nothing || Δ === nothing) && return
56+
k, v = iter[1]
57+
_, dv = Δ[1]
58+
accum_param(cx, v, dv) === nothing && return
59+
grad = grad_mut(cx, d)
60+
grad[k] = accum(get(grad, k, nothing), dv)
61+
return (nothing, grad, nothing)
62+
end
63+
return iter, dict_iterate_pullback
64+
end
65+
66+
# ...while this one is to avoid duplicating code or differentiating skip_deleted.
67+
# The alternative would be to write a rule for the private _iterate(::Dict, i).
68+
function _pullback(cx::AContext, ::typeof(iterate), d::Dict)
69+
# Calculation of i is the same used in iterate(::Dict)
70+
return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor))
71+
end
72+
73+
function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int)
74+
iter = iterate(vi, i)
75+
function values_iterate_pullback(Δ)
76+
(iter === nothing || Δ === nothing) && return
77+
v, dv = iter[1], Δ[1]
78+
accum_param(cx, v, dv) === nothing && return
79+
# Same as vi.dict.keys[i], but without reaching into Dict internals.
80+
# Iterating the dict instead of keys() is to hit the rules above in nested AD.
81+
k = iterate(vi.dict, i)[1][1]
82+
grad = grad_mut(cx, vi.dict)
83+
grad[k] = accum(get(grad, k, nothing), dv)
84+
return (nothing, (; dict = grad), nothing)
85+
end
86+
return iter, values_iterate_pullback
87+
end
88+
5089
# Channels
5190

5291
grad_mut(ch::Channel) = Channel(ch.sz_max)

test/lib/base.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,36 @@
1010

1111
@test result1 == result2
1212
end
13+
14+
@testset "Dict iteration" begin
15+
# https://github.com/FluxML/Zygote.jl/issues/1065
16+
function sumkv(d)
17+
s = zero(d["c"])
18+
for (k, v) in d
19+
s += v
20+
k == :b && (s += v)
21+
end
22+
return sum(s)
23+
end
24+
25+
function sumvals(d)
26+
s = zero(d["c"])
27+
for v in values(d)
28+
s += v
29+
end
30+
return sum(s)
31+
end
32+
33+
d_num = Dict(:a => 3, :b => 4, "c" => 5)
34+
d_arr = Dict(:a => [3], :b => [4], "c" => [5])
35+
ps = d_arr |> values |> collect |> Params
36+
37+
@test gradient(sumkv, d_num)[1] == Dict(:a => 1, :b => 2, "c" => 1)
38+
grads = gradient(() -> sumkv(d_arr), ps)
39+
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [2], [1])
40+
41+
@test gradient(sumvals, d_num)[1] == Dict(:a => 1, :b => 1, "c" => 1)
42+
grads = gradient(() -> sumvals(d_arr), ps)
43+
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [1], [1])
44+
end
1345
end

0 commit comments

Comments
 (0)