Skip to content

Commit 033f137

Browse files
committed
Add rule for Dict iteration
1 parent 99d5a38 commit 033f137

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/lib/base.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,29 @@ end
4747
end
4848
end
4949

50+
# This rule behaves much like the getindex adjoint,
51+
# just with an (internal) ordinal index instead of a key.
52+
function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i)
53+
iter = iterate(d, i)
54+
function dict_iterate_pullback(Δ)
55+
(iter === nothing || Δ === nothing) && return
56+
k, v = iter[1]
57+
_, dv = Δ[1]
58+
accum_param(cx, v, dv) === nothing && return
59+
grad = grad_mut(cx, d)
60+
grad[k] = accum(get(grad, k, nothing), dv)
61+
return (nothing, grad, nothing)
62+
end
63+
return iter, dict_iterate_pullback
64+
end
65+
66+
# ...while this one is to avoid duplicating code or differentiating skip_deleted.
67+
# The alternative would be to write a rule for the private _iterate(::Dict, i).
68+
function _pullback(cx::AContext, ::typeof(iterate), d::Dict)
69+
# Calculation of i is the same used in iterate(::Dict)
70+
return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor))
71+
end
72+
5073
# Channels
5174

5275
grad_mut(ch::Channel) = Channel(ch.sz_max)

test/lib/base.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,23 @@
1010

1111
@test result1 == result2
1212
end
13+
14+
@testset "Dict iteration" begin
15+
# https://github.com/FluxML/Zygote.jl/issues/1065
16+
function sumvals(d)
17+
s = zero(d["c"])
18+
for (k, v) in d
19+
s += v
20+
k == :b && (s += v)
21+
end
22+
return sum(s)
23+
end
24+
25+
d = Dict(:a => 3, :b => 4, "c" => 5)
26+
@test gradient(sumvals, d)[1] == Dict(:a => 1, :b => 2, "c" => 1)
27+
28+
d = Dict(:a => [3], :b => [4], "c" => [5])
29+
grads = gradient(() -> sumvals(d), d |> values |> collect |> Params)
30+
@test (grads[d[:a]], grads[d[:b]], grads[d["c"]]) == ([1], [2], [1])
31+
end
1332
end

0 commit comments

Comments
 (0)