1- #  ffts
1+ module  AbstractFFTsChainRulesCoreExt
2+ 
3+ using  AbstractFFTs
4+ import  ChainRulesCore
5+ 
26function  ChainRulesCore. frule ((_, Δx, _), :: typeof (fft), x:: AbstractArray , dims)
37    y =  fft (x, dims)
48    Δy =  fft (Δx, dims)
@@ -33,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3337
3438    project_x =  ChainRulesCore. ProjectTo (x)
3539    function  rfft_pullback (ȳ)
36-         x̄ =  project_x (brfft (ChainRulesCore. unthunk (ȳ) ./  scale, d, dims))
40+         ybar =  ChainRulesCore. unthunk (ȳ)
41+         _scale =  convert (typeof (ybar),scale)
42+         x̄ =  project_x (brfft (ybar ./  _scale, d, dims))
3743        return  ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
3844    end 
3945    return  y, rfft_pullback
@@ -46,7 +52,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
4652end 
4753function  ChainRulesCore. rrule (:: typeof (ifft), x:: AbstractArray , dims)
4854    y =  ifft (x, dims)
49-     invN =  normalization (y, dims)
55+     invN =  AbstractFFTs . normalization (y, dims)
5056    project_x =  ChainRulesCore. ProjectTo (x)
5157    function  ifft_pullback (ȳ)
5258        x̄ =  project_x (invN .*  fft (ChainRulesCore. unthunk (ȳ), dims))
@@ -66,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
6672    #  compute scaling factors
6773    halfdim =  first (dims)
6874    n =  size (x, halfdim)
69-     invN =  normalization (y, dims)
75+     invN =  AbstractFFTs . normalization (y, dims)
7076    twoinvN =  2  *  invN
7177    scale =  reshape (
7278        [i ==  1  ||  (i ==  n &&  2  *  (i -  1 ) ==  d) ?  invN :  twoinvN for  i in  1 : n],
@@ -75,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7581
7682    project_x =  ChainRulesCore. ProjectTo (x)
7783    function  irfft_pullback (ȳ)
78-         x̄ =  project_x (scale .*  rfft (real .(ChainRulesCore. unthunk (ȳ)), dims))
84+         ybar =  ChainRulesCore. unthunk (ȳ)
85+         _scale =  convert (typeof (ybar),scale)
86+         x̄ =  project_x (_scale .*  rfft (real .(ybar), dims))
7987        return  ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
8088    end 
8189    return  y, irfft_pullback
@@ -152,12 +160,12 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
152160end 
153161
154162#  plans
155- function  ChainRulesCore. frule ((_, _, Δx), :: typeof (* ), P:: Plan , x:: AbstractArray ) 
163+ function  ChainRulesCore. frule ((_, _, Δx), :: typeof (* ), P:: AbstractFFTs. Plan:: AbstractArray ) 
156164    y =  P *  x 
157165    Δy =  P *  Δx
158166    return  y, Δy
159167end 
160- function  ChainRulesCore. rrule (:: typeof (* ), P:: Plan , x:: AbstractArray )
168+ function  ChainRulesCore. rrule (:: typeof (* ), P:: AbstractFFTs. Plan:: AbstractArray )
161169    y =  P *  x
162170    project_x =  ChainRulesCore. ProjectTo (x)
163171    Pt =  P' 
@@ -168,22 +176,25 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
168176    return  y, mul_plan_pullback
169177end 
170178
171- function  ChainRulesCore. frule ((_, ΔP, Δx), :: typeof (* ), P:: ScaledPlan , x:: AbstractArray ) 
179+ function  ChainRulesCore. frule ((_, ΔP, Δx), :: typeof (* ), P:: AbstractFFTs. ScaledPlan:: AbstractArray ) 
172180    y =  P *  x 
173181    Δy =  P *  Δx .+  (ΔP. scale /  P. scale) .*  y
174182    return  y, Δy
175183end 
176- function  ChainRulesCore. rrule (:: typeof (* ), P:: ScaledPlan , x:: AbstractArray )
184+ function  ChainRulesCore. rrule (:: typeof (* ), P:: AbstractFFTs. ScaledPlan:: AbstractArray )
177185    y =  P *  x
178186    Pt =  P' 
179187    scale =  P. scale
180188    project_x =  ChainRulesCore. ProjectTo (x)
181189    project_scale =  ChainRulesCore. ProjectTo (scale)
182190    function  mul_scaledplan_pullback (ȳ)
183191        x̄ =  ChainRulesCore. @thunk (project_x (Pt *  ȳ))
184-         scale_tangent =  ChainRulesCore. @thunk (project_scale (dot (y, ȳ) /  conj (scale)))
192+         scale_tangent =  ChainRulesCore. @thunk (project_scale (AbstractFFTs . dot (y, ȳ) /  conj (scale)))
185193        plan_tangent =  ChainRulesCore. Tangent {typeof(P)} (;p= ChainRulesCore. NoTangent (), scale= scale_tangent)
186194        return  ChainRulesCore. NoTangent (), plan_tangent, x̄ 
187195    end 
188196    return  y, mul_scaledplan_pullback
189197end 
198+ 
199+ end  #  module
200+ 
0 commit comments