Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/RustFFT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using .Internal
@reexport using AbstractFFTs

import Base: *, size
import AbstractFFTs: Plan, ScaledPlan, plan_fft, fft, plan_fft!, fft!, plan_bfft, plan_bfft!,
import AbstractFFTs: Plan, ScaledPlan, AbstractFFTBackend, plan_fft, fft, plan_fft!, fft!, plan_bfft, plan_bfft!,
ifft, ifft!, fftdims, plan_inv
import LinearAlgebra: mul!

Expand Down Expand Up @@ -87,6 +87,15 @@ struct IgnoreArrayChecks <: ArrayChecks end
# or `ComplexF32`.
const RustFFTNumber = Union{Complex{Float64},Complex{Float32}}

export RustFFTBackend
struct RustFFTBackend <: AbstractFFTBackend end
backend() = RustFFTBackend()
activate!() = AbstractFFTs.set_active_backend!(RustFFT)

function __init__()
activate!()
end

mutable struct RustFFTPlan{T<:RustFFTNumber,inplace,direction<:Direction,gcsafety<:GcSafety,arraychecks<:ArrayChecks} <: Plan{T}
plan::FftInstance{T}
pinv::ScaledPlan
Expand Down Expand Up @@ -116,7 +125,7 @@ end
return FftPlanner64()
end

function plan_fft(x::Vector{T}, region;
function plan_fft(::RustFFTBackend, x::Vector{T}, region;
rustfft_checks::arraychecks=IgnoreArrayTracking(),
rustfft_gc_safe::gcsafety=GcUnsafe(),
rustfft_planner::Union{FftPlanner{T},Nothing}=nothing,
Expand All @@ -133,7 +142,7 @@ function plan_fft(x::Vector{T}, region;
RustFFTPlan{T,false,Forward,gcsafety,arraychecks}(instance)
end

function plan_fft!(x::Vector{T}, region;
function plan_fft!(::RustFFTBackend, x::Vector{T}, region;
rustfft_checks::arraychecks=AllArrayChecks(),
rustfft_gc_safe::gcsafety=GcUnsafe(),
rustfft_planner::Union{FftPlanner{T},Nothing}=nothing,
Expand All @@ -150,7 +159,7 @@ function plan_fft!(x::Vector{T}, region;
RustFFTPlan{T,true,Forward,gcsafety,arraychecks}(instance)
end

function plan_bfft(x::Vector{T}, region;
function plan_bfft(::RustFFTBackend, x::Vector{T}, region;
rustfft_checks::arraychecks=IgnoreArrayTracking(),
rustfft_gc_safe::gcsafety=GcUnsafe(),
rustfft_planner::Union{FftPlanner{T},Nothing}=nothing,
Expand All @@ -167,7 +176,7 @@ function plan_bfft(x::Vector{T}, region;
RustFFTPlan{T,false,Backward,gcsafety,arraychecks}(instance)
end

function plan_bfft!(x::Vector{T}, region;
function plan_bfft!(::RustFFTBackend, x::Vector{T}, region;
rustfft_checks::arraychecks=AllArrayChecks(),
rustfft_gc_safe::gcsafety=GcUnsafe(),
rustfft_planner::Union{FftPlanner{T},Nothing}=nothing,
Expand Down