Skip to content

Replace __init__ with OncePerProcess on supported Julia versions #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FFTW"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.8.1"
version = "1.9.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
32 changes: 27 additions & 5 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!

include("providers.jl")

function __init__()
function initialize_library_paths()
# If someone is trying to set the provider via the old environment variable, warn them that they
# should instead use `set_provider!()` instead.
if haskey(ENV, "JULIA_FFTW_PROVIDER")
Expand All @@ -27,14 +27,36 @@ function __init__()
# libfftw3{,f} refs at runtime, since we may have relocated and
# changed the path to the library since the last time we precompiled.
@static if fftw_provider == "fftw"
libfftw3[] = FFTW_jll.libfftw3_path
libfftw3f[] = FFTW_jll.libfftw3f_path
libfftw3_path[] = FFTW_jll.libfftw3_path
libfftw3f_path[] = FFTW_jll.libfftw3f_path
fftw_init_threads()
end
@static if fftw_provider == "mkl"
libfftw3[] = MKL_jll.libmkl_rt_path
libfftw3f[] = MKL_jll.libmkl_rt_path
libfftw3_path[] = MKL_jll.libmkl_rt_path
libfftw3f_path[] = MKL_jll.libmkl_rt_path
end
return nothing
end

if VERSION >= v"1.12.0-beta1.29"
const initialize_library_paths_once = OncePerProcess{Nothing}() do
initialize_library_paths()
return
end
function libfftw3()
initialize_library_paths_once()
return libfftw3_path[]
end
function libfftw3f()
initialize_library_paths_once()
return libfftw3f_path[]
end
else
function __init__()
initialize_library_paths()
end
libfftw3() = libfftw3_path[]
libfftw3f() = libfftw3f_path[]
end

# most FFTW calls other than fftw_execute should be protected by a lock to be thread-safe
Expand Down
78 changes: 39 additions & 39 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
## FFT: Implement fft by calling fftw.

const version = VersionNumber(split(unsafe_string(cglobal(
(:fftw_version,libfftw3[]), UInt8)), ['-', ' '])[2])
(:fftw_version,libfftw3()), UInt8)), ['-', ' '])[2])

## Direction of FFT

Expand Down Expand Up @@ -141,32 +141,32 @@
@exclusive function export_wisdom(fname::AbstractString)
f = ccall(:fopen, Ptr{Cvoid}, (Cstring,Cstring), fname, :w)
systemerror("could not open wisdom file $fname for writing", f == C_NULL)
ccall((:fftw_export_wisdom_to_file,libfftw3[]), Cvoid, (Ptr{Cvoid},), f)
ccall((:fftw_export_wisdom_to_file,libfftw3()), Cvoid, (Ptr{Cvoid},), f)

Check warning on line 144 in src/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft.jl#L144

Added line #L144 was not covered by tests
ccall(:fputs, Int32, (Ptr{UInt8},Ptr{Cvoid}), " "^256, f) # no NUL, hence no Cstring
ccall((:fftwf_export_wisdom_to_file,libfftw3f[]), Cvoid, (Ptr{Cvoid},), f)
ccall((:fftwf_export_wisdom_to_file,libfftw3f()), Cvoid, (Ptr{Cvoid},), f)

Check warning on line 146 in src/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft.jl#L146

Added line #L146 was not covered by tests
ccall(:fclose, Cvoid, (Ptr{Cvoid},), f)
end

@exclusive function import_wisdom(fname::AbstractString)
f = ccall(:fopen, Ptr{Cvoid}, (Cstring,Cstring), fname, :r)
systemerror("could not open wisdom file $fname for reading", f == C_NULL)
if ccall((:fftw_import_wisdom_from_file,libfftw3[]),Int32,(Ptr{Cvoid},),f)==0||
ccall((:fftwf_import_wisdom_from_file,libfftw3f[]),Int32,(Ptr{Cvoid},),f)==0
if ccall((:fftw_import_wisdom_from_file,libfftw3()),Int32,(Ptr{Cvoid},),f)==0||

Check warning on line 153 in src/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft.jl#L153

Added line #L153 was not covered by tests
ccall((:fftwf_import_wisdom_from_file,libfftw3f()),Int32,(Ptr{Cvoid},),f)==0
error("failed to import wisdom from $fname")
end
ccall(:fclose, Cvoid, (Ptr{Cvoid},), f)
end

@exclusive function import_system_wisdom()
if ccall((:fftw_import_system_wisdom,libfftw3[]), Int32, ()) == 0 ||
ccall((:fftwf_import_system_wisdom,libfftw3f[]), Int32, ()) == 0
if ccall((:fftw_import_system_wisdom,libfftw3()), Int32, ()) == 0 ||

Check warning on line 161 in src/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft.jl#L161

Added line #L161 was not covered by tests
ccall((:fftwf_import_system_wisdom,libfftw3f()), Int32, ()) == 0
error("failed to import system wisdom")
end
end

@exclusive function forget_wisdom()
ccall((:fftw_forget_wisdom,libfftw3[]), Cvoid, ())
ccall((:fftwf_forget_wisdom,libfftw3f[]), Cvoid, ())
ccall((:fftw_forget_wisdom,libfftw3()), Cvoid, ())
ccall((:fftwf_forget_wisdom,libfftw3f()), Cvoid, ())

Check warning on line 169 in src/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft.jl#L168-L169

Added lines #L168 - L169 were not covered by tests
end

# Threads
Expand All @@ -176,15 +176,15 @@
@static if fftw_provider == "mkl"
_last_num_threads[] = num_threads
end
ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), num_threads)
ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), num_threads)
ccall((:fftw_plan_with_nthreads,libfftw3()), Cvoid, (Int32,), num_threads)
ccall((:fftwf_plan_with_nthreads,libfftw3f()), Cvoid, (Int32,), num_threads)
end

@exclusive set_num_threads(num_threads::Integer) = _set_num_threads(num_threads)

function get_num_threads()
@static if fftw_provider == "fftw"
ccall((:fftw_planner_nthreads,libfftw3[]), Cint, ())
ccall((:fftw_planner_nthreads,libfftw3()), Cint, ())
else
_last_num_threads[]
end
Expand All @@ -211,9 +211,9 @@

# only call these when fftwlock is held:
unsafe_set_timelimit(precision::fftwTypeDouble,seconds) =
ccall((:fftw_set_timelimit,libfftw3[]), Cvoid, (Float64,), seconds)
ccall((:fftw_set_timelimit,libfftw3()), Cvoid, (Float64,), seconds)
unsafe_set_timelimit(precision::fftwTypeSingle,seconds) =
ccall((:fftwf_set_timelimit,libfftw3f[]), Cvoid, (Float64,), seconds)
ccall((:fftwf_set_timelimit,libfftw3f()), Cvoid, (Float64,), seconds)
@exclusive set_timelimit(precision, seconds) = unsafe_set_timelimit(precision, seconds)

# Array alignment mod 16:
Expand All @@ -234,9 +234,9 @@
convert(Int32, convert(Int64, pointer(A)) % 16)
else
alignment_of(A::StridedArray{T}) where {T<:fftwDouble} =
ccall((:fftw_alignment_of, libfftw3[]), Int32, (Ptr{T},), A)
ccall((:fftw_alignment_of, libfftw3()), Int32, (Ptr{T},), A)
alignment_of(A::StridedArray{T}) where {T<:fftwSingle} =
ccall((:fftwf_alignment_of, libfftw3f[]), Int32, (Ptr{T},), A)
ccall((:fftwf_alignment_of, libfftw3f()), Int32, (Ptr{T},), A)
end

# FFTWPlan (low-level)
Expand Down Expand Up @@ -320,9 +320,9 @@

# these functions should only be called while the fftwlock is held
unsafe_destroy_plan(@nospecialize(plan::FFTWPlan{<:fftwDouble})) =
ccall((:fftw_destroy_plan,libfftw3[]), Cvoid, (PlanPtr,), plan)
ccall((:fftw_destroy_plan,libfftw3()), Cvoid, (PlanPtr,), plan)
unsafe_destroy_plan(@nospecialize(plan::FFTWPlan{<:fftwSingle})) =
ccall((:fftwf_destroy_plan,libfftw3f[]), Cvoid, (PlanPtr,), plan)
ccall((:fftwf_destroy_plan,libfftw3f()), Cvoid, (PlanPtr,), plan)

const deferred_destroy_lock = ReentrantLock() # lock protecting the deferred_destroy_plans list
const deferred_destroy_plans = FFTWPlan[]
Expand Down Expand Up @@ -388,19 +388,19 @@
#################################################################################################

cost(plan::FFTWPlan{<:fftwDouble}) =
ccall((:fftw_cost,libfftw3[]), Float64, (PlanPtr,), plan)
ccall((:fftw_cost,libfftw3()), Float64, (PlanPtr,), plan)
cost(plan::FFTWPlan{<:fftwSingle}) =
ccall((:fftwf_cost,libfftw3f[]), Float64, (PlanPtr,), plan)
ccall((:fftwf_cost,libfftw3f()), Float64, (PlanPtr,), plan)

@exclusive function arithmetic_ops(plan::FFTWPlan{<:fftwDouble})
add, mul, fma = Ref(0.0), Ref(0.0), Ref(0.0)
ccall((:fftw_flops,libfftw3[]), Cvoid,
ccall((:fftw_flops,libfftw3()), Cvoid,
(PlanPtr,Ref{Float64},Ref{Float64},Ref{Float64}), plan, add, mul, fma)
return (round(Int64, add[]), round(Int64, mul[]), round(Int64, fma[]))
end
@exclusive function arithmetic_ops(plan::FFTWPlan{<:fftwSingle})
add, mul, fma = Ref(0.0), Ref(0.0), Ref(0.0)
ccall((:fftwf_flops,libfftw3f[]), Cvoid,
ccall((:fftwf_flops,libfftw3f()), Cvoid,
(PlanPtr,Ref{Float64},Ref{Float64},Ref{Float64}), plan, add, mul, fma)
return (round(Int64, add[]), round(Int64, mul[]), round(Int64, fma[]))
end
Expand Down Expand Up @@ -431,9 +431,9 @@

@static if has_sprint_plan
sprint_plan_(plan::FFTWPlan{<:fftwDouble}) =
ccall((:fftw_sprint_plan,libfftw3[]), Ptr{UInt8}, (PlanPtr,), plan)
ccall((:fftw_sprint_plan,libfftw3()), Ptr{UInt8}, (PlanPtr,), plan)
sprint_plan_(plan::FFTWPlan{<:fftwSingle}) =
ccall((:fftwf_sprint_plan,libfftw3f[]), Ptr{UInt8}, (PlanPtr,), plan)
ccall((:fftwf_sprint_plan,libfftw3f()), Ptr{UInt8}, (PlanPtr,), plan)
function sprint_plan(plan::FFTWPlan)
p = sprint_plan_(plan)
str = unsafe_string(p)
Expand Down Expand Up @@ -515,49 +515,49 @@
# Execute

unsafe_execute!(plan::FFTWPlan{<:fftwDouble}) =
ccall((:fftw_execute,libfftw3[]), Cvoid, (PlanPtr,), plan)
ccall((:fftw_execute,libfftw3()), Cvoid, (PlanPtr,), plan)

unsafe_execute!(plan::FFTWPlan{<:fftwSingle}) =
ccall((:fftwf_execute,libfftw3f[]), Cvoid, (PlanPtr,), plan)
ccall((:fftwf_execute,libfftw3f()), Cvoid, (PlanPtr,), plan)

unsafe_execute!(plan::cFFTWPlan{T},
X::StridedArray{T}, Y::StridedArray{T}) where {T<:fftwDouble} =
ccall((:fftw_execute_dft,libfftw3[]), Cvoid,
ccall((:fftw_execute_dft,libfftw3()), Cvoid,
(PlanPtr,Ptr{T},Ptr{T}), plan, X, Y)

unsafe_execute!(plan::cFFTWPlan{T},
X::StridedArray{T}, Y::StridedArray{T}) where {T<:fftwSingle} =
ccall((:fftwf_execute_dft,libfftw3f[]), Cvoid,
ccall((:fftwf_execute_dft,libfftw3f()), Cvoid,
(PlanPtr,Ptr{T},Ptr{T}), plan, X, Y)

unsafe_execute!(plan::rFFTWPlan{Float64,FORWARD},
X::StridedArray{Float64}, Y::StridedArray{Complex{Float64}}) =
ccall((:fftw_execute_dft_r2c,libfftw3[]), Cvoid,
ccall((:fftw_execute_dft_r2c,libfftw3()), Cvoid,
(PlanPtr,Ptr{Float64},Ptr{Complex{Float64}}), plan, X, Y)

unsafe_execute!(plan::rFFTWPlan{Float32,FORWARD},
X::StridedArray{Float32}, Y::StridedArray{Complex{Float32}}) =
ccall((:fftwf_execute_dft_r2c,libfftw3f[]), Cvoid,
ccall((:fftwf_execute_dft_r2c,libfftw3f()), Cvoid,
(PlanPtr,Ptr{Float32},Ptr{Complex{Float32}}), plan, X, Y)

unsafe_execute!(plan::rFFTWPlan{Complex{Float64},BACKWARD},
X::StridedArray{Complex{Float64}}, Y::StridedArray{Float64}) =
ccall((:fftw_execute_dft_c2r,libfftw3[]), Cvoid,
ccall((:fftw_execute_dft_c2r,libfftw3()), Cvoid,
(PlanPtr,Ptr{Complex{Float64}},Ptr{Float64}), plan, X, Y)

unsafe_execute!(plan::rFFTWPlan{Complex{Float32},BACKWARD},
X::StridedArray{Complex{Float32}}, Y::StridedArray{Float32}) =
ccall((:fftwf_execute_dft_c2r,libfftw3f[]), Cvoid,
ccall((:fftwf_execute_dft_c2r,libfftw3f()), Cvoid,
(PlanPtr,Ptr{Complex{Float32}},Ptr{Float32}), plan, X, Y)

unsafe_execute!(plan::r2rFFTWPlan{T},
X::StridedArray{T}, Y::StridedArray{T}) where {T<:fftwDouble} =
ccall((:fftw_execute_r2r,libfftw3[]), Cvoid,
ccall((:fftw_execute_r2r,libfftw3()), Cvoid,
(PlanPtr,Ptr{T},Ptr{T}), plan, X, Y)

unsafe_execute!(plan::r2rFFTWPlan{T},
X::StridedArray{T}, Y::StridedArray{T}) where {T<:fftwSingle} =
ccall((:fftwf_execute_r2r,libfftw3f[]), Cvoid,
ccall((:fftwf_execute_r2r,libfftw3f()), Cvoid,
(PlanPtr,Ptr{T},Ptr{T}), plan, X, Y)

# NOTE ON GC (garbage collection):
Expand Down Expand Up @@ -654,7 +654,7 @@
unsafe_set_timelimit($Tr, timelimit)
R = isa(region, Tuple) ? region : copy(region)
dims, howmany = dims_howmany(X, Y, size(X), R)
plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib[]),
plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib()),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Ptr{$Tc}, Ptr{$Tc}, Int32, UInt32),
Expand All @@ -674,7 +674,7 @@
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, size(X), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib[]),
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib()),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Ptr{$Tr}, Ptr{$Tc}, UInt32),
Expand All @@ -694,7 +694,7 @@
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, size(Y), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib[]),
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib()),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Ptr{$Tc}, Ptr{$Tr}, UInt32),
Expand All @@ -716,7 +716,7 @@
knd = fix_kinds(region, kinds)
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, size(X), region)
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib()),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Ptr{$Tr}, Ptr{$Tr}, Ptr{Int32}, UInt32),
Expand Down Expand Up @@ -744,7 +744,7 @@
howmany[2:3, :] .*= 2
end
howmany = [howmany [2,1,1]] # append loop over real/imag parts
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib()),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Ptr{$Tc}, Ptr{$Tc}, Ptr{Int32}, UInt32),
Expand Down
23 changes: 13 additions & 10 deletions src/providers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ const fftw_provider = get_provider()
# We'll initialize `libfftw3` here (in the conditionals below), and
# it will get overwritten again in `__init__()`. This allows us to
# `ccall` at build time, and also be relocatable for PackageCompiler.
const libfftw3 = Ref{String}()
const libfftw3f = Ref{String}()
const libfftw3_path = Ref{String}()
const libfftw3f_path = Ref{String}()

"""
set_provider!(provider; export_prefs::Bool = false)
Expand All @@ -48,8 +48,8 @@ end
# If we're using fftw_jll, load it in
@static if fftw_provider == "fftw"
import FFTW_jll
libfftw3[] = FFTW_jll.libfftw3_path
libfftw3f[] = FFTW_jll.libfftw3f_path
libfftw3_path[] = FFTW_jll.libfftw3_path
libfftw3f_path[] = FFTW_jll.libfftw3f_path

# callback function that FFTW uses to launch `num` parallel
# tasks (FFTW/fftw3#175):
Expand All @@ -66,24 +66,27 @@ end
# (Previously, we called fftw_cleanup, but this invalidated existing
# plans, causing Base Julia issue #19892.)
function fftw_init_threads()
stat = ccall((:fftw_init_threads, libfftw3[]), Int32, ())
statf = ccall((:fftwf_init_threads, libfftw3f[]), Int32, ())
# We de-reference libfftw3(f)_path directly in this function instead of using
# `libfftw3(f)()` to avoid the circular dependency and thus a stack overflow. This
# function is only called after the path references are initialized.
stat = ccall((:fftw_init_threads, libfftw3_path[]), Int32, ())
statf = ccall((:fftwf_init_threads, libfftw3f_path[]), Int32, ())
if stat == 0 || statf == 0
error("could not initialize FFTW threads")
end

if nthreads() > 1
cspawnloop = @cfunction(spawnloop, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid}))
ccall((:fftw_threads_set_callback, libfftw3[]), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
ccall((:fftwf_threads_set_callback, libfftw3f[]), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
ccall((:fftw_threads_set_callback, libfftw3_path[]), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
ccall((:fftwf_threads_set_callback, libfftw3f_path[]), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
end
end
end

# If we're using MKL, load it in and set library paths appropriately.
@static if fftw_provider == "mkl"
import MKL_jll
libfftw3[] = MKL_jll.libmkl_rt_path
libfftw3f[] = MKL_jll.libmkl_rt_path
libfftw3_path[] = MKL_jll.libmkl_rt_path
libfftw3f_path[] = MKL_jll.libmkl_rt_path
const _last_num_threads = Ref(Cint(1))
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ end # fftw_provider == "fftw"
end

# check whether FFTW on this architecture has nontrivial alignment requirements
nontrivial_alignment = FFTW.fftw_provider == "fftw" && ccall((:fftwf_alignment_of, FFTW.libfftw3f[]), Int32, (Int,), 8) != 0
nontrivial_alignment = FFTW.fftw_provider == "fftw" && ccall((:fftwf_alignment_of, FFTW.libfftw3f()), Int32, (Int,), 8) != 0
if nontrivial_alignment
@test_throws ArgumentError plan_rfft(Array{Float32}(undef, 32)) * view(A, 2:33)
@test_throws ArgumentError plan_fft(Array{Complex{Float32}}(undef, 32)) * view(Ac, 2:33)
Expand Down
Loading