@@ -46,7 +46,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
4646 (ET === Union{} || ET === Any) &&
4747 error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
4848
49- init = neutral_element (op, ET)
49+ init = AK . neutral_element (op, ET)
5050 else
5151 ET = typeof (init)
5252 end
@@ -98,7 +98,7 @@ Base.any(f::Function, A::AnyGPUArray) = AK.any(f, A)
9898Base. all (f:: Function , A:: AnyGPUArray ) = AK. all (f, A)
9999
100100Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
101- AK. count (pred, A; init, dims= dims isa Colon ? nothing : dims)# mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
101+ AK. count (pred, A; init, dims= dims isa Colon ? nothing : dims)
102102
103103# avoid calling into `initarray!`
104104for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
@@ -107,7 +107,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
107107 fname! = Symbol (fname, ' !' )
108108 @eval begin
109109 Base.$ (fname!)(f:: Function , r:: AnyGPUArray , A:: AnyGPUArray{T} ) where T =
110- GPUArrays. mapreducedim! (f, $ (op), r, A; init= neutral_element ($ (op), T))
110+ GPUArrays. mapreducedim! (f, $ (op), r, A; init= AK . neutral_element ($ (op), T))
111111 end
112112end
113113
0 commit comments