Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 78 additions & 18 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,55 +58,95 @@ to1(x::AbstractArray) = _to1(axes(x), x)
_to1(::Tuple{Base.OneTo,Vararg{Base.OneTo}}, x) = x
_to1(::Tuple, x) = copy1(eltype(x), x)

# Abstract FFT Backend
export AbstractFFTBackend
abstract type AbstractFFTBackend end
const ACTIVE_BACKEND = Ref{Union{Missing, AbstractFFTBackend}}(missing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A major caveat is that this is thread-unsafe. Maybe a task-local storage or a scope value would be better options? They may come at a performance penalty though, not sure. Alternatively, not having a global state would be better.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I also tried playing around a bit with ScopedValues. I just didn't see a way where I could preserve the "default" behaviour of plan_fft with it, since the user would always need to do:

with(fft_backend=>FFTW.backend()) do
    p = plan_fft(x)
    # ...
end

I did see ScopedSettings.jl, which would allow for the desired behaviour but I'm not sure if such an additional non-base dependency is justified.

Another option would be to have both a "global" active backend and a dynamic backend with the later being a ScopedValue. Then the "high-level" API would need to first check if the dynamic scoped value is set, if not then check the global one and if neither is set throw the no-backend-error.

I'm not sure if packages that want to actively switch between backends, should target the high-level interface. The current behaviour if multiple FFT backends are loaded is also not well-defined at the moment


"""
set_active_backend!(back::Union{Missing, Module, AbstractFFTBackend})

Set the default FFT plan backend. A module `back` must implement `back.backend()`.
"""
set_active_backend!(back::Module) = set_active_backend!(back.backend())
function set_active_backend!(back::Union{Missing, AbstractFFTBackend})
ACTIVE_BACKEND[] = back
end
active_backend() = ACTIVE_BACKEND[]
function no_backend_error()
error(
"""
No default backend available!
Make sure to also "import/using" an FFT backend such as FFTW, FFTA or RustFFT.
"""
)
end

for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft, :brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, args...; kws...) = $f(active_backend(), x, args...; kws...)
$pf(x::AbstractArray, args...; kws...) = $pf(active_backend(), x, args...; kws...)
$f(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
$pf(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
end
end
# implementations only need to provide plan_X(x, region)
# for X in (:fft, :bfft, ...):
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray) = $f(x, 1:ndims(x))
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
$f(b::AbstractFFTBackend, x::AbstractArray) = $f(b, x, 1:ndims(x))
$f(b::AbstractFFTBackend, x::AbstractArray, region) = (y = to1(x); $pf(b, y, region) * y)
$pf(b::AbstractFFTBackend, x::AbstractArray; kws...) = (y = to1(x); $pf(b, y, 1:ndims(y); kws...))
end
end

"""
plan_ifft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_ifft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but produces a plan that performs inverse transforms
[`ifft`](@ref).
[`ifft`](@ref). Uses active `backend` if no explicit `backend` is provided.
"""
plan_ifft

"""
plan_ifft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_ifft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_ifft`](@ref), but operates in-place on `A`.
"""
plan_ifft!

"""
plan_bfft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_bfft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_bfft`](@ref), but operates in-place on `A`.
"""
plan_bfft!

"""
plan_bfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_bfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but produces a plan that performs an unnormalized
backwards transform [`bfft`](@ref).
backwards transform [`bfft`](@ref). Uses active `backend` if no explicit `backend` is provided.
"""
plan_bfft

"""
plan_fft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_fft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized FFT along given dimensions (`dims`) of arrays matching the shape and
type of `A`. (The first two arguments have the same meaning as for [`fft`](@ref).)
Returns an object `P` which represents the linear operator computed by the FFT, and which
contains all of the information needed to compute `fft(A, dims)` quickly.

Uses active `backend` if no explicit `backend` is provided.

To apply `P` to an array `A`, use `P * A`; in general, the syntax for applying plans is much
like that of matrices. (A plan can only be applied to arrays of the same size as the `A`
for which the plan was created.) You can also apply a plan with a preallocated output array `Â`
Expand All @@ -132,34 +172,40 @@ plans that perform the equivalent of the inverse transforms [`ifft`](@ref) and s
plan_fft

"""
plan_fft!(backend A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_fft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but operates in-place on `A`.
"""
plan_fft!

"""
rfft(backend, A [, dims])
rfft(A [, dims])

Multidimensional FFT of a real array `A`, exploiting the fact that the transform has
conjugate symmetry in order to save roughly half the computational time and storage costs
compared with [`fft`](@ref). If `A` has size `(n_1, ..., n_d)`, the result has size
`(div(n_1,2)+1, ..., n_d)`.

Uses active `backend` if no explicit `backend` is provided.

The optional `dims` argument specifies an iterable subset of one or more dimensions of `A`
to transform, similar to [`fft`](@ref). Instead of (roughly) halving the first
dimension of `A` in the result, the `dims[1]` dimension is (roughly) halved in the same way.
"""
rfft

"""
ifft!(backend, A [, dims])
ifft!(A [, dims])

Same as [`ifft`](@ref), but operates in-place on `A`.
"""
ifft!

"""
ifft(backend, A [, dims])
ifft(A [, dims])

Multidimensional inverse FFT.
Expand All @@ -177,6 +223,7 @@ A multidimensional inverse FFT simply performs this operation along each transfo
ifft

"""
fft!(backend, A [, dims])
fft!(A [, dims])

Same as [`fft`](@ref), but operates in-place on `A`, which must be an array of
Expand All @@ -185,6 +232,7 @@ complex floating-point numbers.
fft!

"""
bfft(backend, A [, dims])
bfft(A [, dims])

Similar to [`ifft`](@ref), but computes an unnormalized inverse (backward)
Expand All @@ -200,6 +248,7 @@ computational steps elsewhere.)
bfft

"""
bfft!(backend, A [, dims])
bfft!(A [, dims])

Same as [`bfft`](@ref), but operates in-place on `A`.
Expand All @@ -215,10 +264,15 @@ for f in (:fft, :bfft, :ifft)
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
# These methods run into ambig. if a backend does not specialise T further
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region) = $f(b, complexfloat(x), region)
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(b, complexfloat(x), region)
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
end
end
rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region)
plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...)
rfft(b::AbstractFFTBackend, x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(b, realfloat(x), region)
plan_rfft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = plan_rfft(b, realfloat(x), region; kws...)

# only require implementation to provide *(::Plan{T}, ::Array{T})
*(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x)
Expand Down Expand Up @@ -279,10 +333,10 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
end
normalization(X, region) = normalization(real(eltype(X)), size(X), region)

plan_ifft(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft(x, region; kws...), normalization(x, region))
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))
plan_ifft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft(b, x, region; kws...), normalization(x, region))
plan_ifft!(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(b, x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
# Don't cache inverse of scaled plan (only inverse of inner plan)
Expand All @@ -302,20 +356,21 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
for f in (:brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer) = $f(b, x, d, 1:ndims(x))
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer, region) = $pf(b, x, d, region) * x
$pf(b::AbstractFFTBackend, x::AbstractArray, d::Integer;kws...) = $pf(b, x, d, 1:ndims(x);kws...)
end
end

for f in (:brfft, :irfft)
@eval begin
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, d::Integer, region) = $f(b, complexfloat(x), d, region)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(b, complexfloat(x), d, region)
end
end

"""
irfft(backend, A, d [, dims])
irfft(A, d [, dims])

Inverse of [`rfft`](@ref): for a complex array `A`, gives the corresponding real
Expand All @@ -330,6 +385,7 @@ transformed real array.)
irfft

"""
brfft(backend, A, d [, dims])
brfft(A, d [, dims])

Similar to [`irfft`](@ref) but computes an unnormalized inverse transform (similar
Expand All @@ -351,11 +407,12 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
return ntuple(i -> i == d1 ? d : sz[i], Val(N))
end

plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
ScaledPlan(plan_brfft(x, d, region; kws...),
plan_irfft(b::AbstractFFTBackend, x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
ScaledPlan(plan_brfft(b, x, d, region; kws...),
normalization(T, brfft_output_size(x, d, region), region))

"""
plan_irfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_irfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized inverse real-input FFT, similar to [`plan_rfft`](@ref)
Expand Down Expand Up @@ -543,6 +600,7 @@ fftshift(x::Frequencies) = (x.n_nonnegative-x.n:x.n_nonnegative-1)*x.multiplier
##############################################################################

"""
fft(backend, A [, dims])
fft(A [, dims])

Performs a multidimensional FFT of the array `A`. The optional `dims` argument specifies an
Expand Down Expand Up @@ -570,6 +628,7 @@ A multidimensional FFT simply performs this operation along each transformed dim
fft

"""
plan_rfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_rfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized real-input FFT, similar to [`plan_fft`](@ref) except for
Expand All @@ -579,6 +638,7 @@ size of the transformed result, are the same as for [`rfft`](@ref).
plan_rfft

"""
plan_brfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_brfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized real-input unnormalized transform, similar to
Expand Down
16 changes: 10 additions & 6 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using LinearAlgebra
using AbstractFFTs
using AbstractFFTs: Plan

struct TestBackend <: AbstractFFTBackend end
backend() = TestBackend()
activate!() = AbstractFFTs.set_active_backend!(TestPlans)

mutable struct TestPlan{T,N,G} <: Plan{T}
region::G
sz::NTuple{N,Int}
Expand All @@ -30,10 +34,10 @@ Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
function AbstractFFTs.plan_fft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
end
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
function AbstractFFTs.plan_bfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T}
return InverseTestPlan{T}(region, size(x))
end

Expand Down Expand Up @@ -119,10 +123,10 @@ end
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
function AbstractFFTs.plan_rfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
end
function AbstractFFTs.plan_brfft(x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T}
function AbstractFFTs.plan_brfft(::TestBackend, x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T}
return InverseTestRPlan{T}(d, region, size(x))
end
function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
Expand Down Expand Up @@ -265,10 +269,10 @@ Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.fftdims(p::InplaceTestPlan) = fftdims(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
function AbstractFFTs.plan_fft!(::TestBackend, x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
end
function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...)
function AbstractFFTs.plan_bfft!(::TestBackend, x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_bfft(x, region; kwargs...))
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Random.seed!(1234)

# Load example plan implementation.
include("TestPlans.jl")
TestPlans.activate!()

# Run interface tests for TestPlans
AbstractFFTs.TestUtils.test_complex_ffts(Array)
Expand Down
Loading