Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)

function rrule(::typeof(+), arrs::AbstractArray...)
y = +(arrs...)
arr_axs = map(axes, arrs)
projs = map(ProjectTo, arrs)
function add_pullback(dy_raw)
dy = unthunk(dy_raw) # reshape will otherwise unthunk N times
return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
dy = unthunk(dy_raw) # projs will otherwise unthunk N times
return (NoTangent(), map(proj -> proj(dy), projs)...)
end
return y, add_pullback
end
2 changes: 2 additions & 0 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,7 @@
# rev
@gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
@gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1))
test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3)))
end
end