@@ -6,6 +6,7 @@ using AbstractFFTs
66using AbstractFFTs: TestUtils
77using AbstractFFTs. LinearAlgebra
88using Test
9+ import Random
910
1011# Ground truth x_fft computed using FFTW library
1112const TEST_CASES = (
@@ -52,15 +53,18 @@ const TEST_CASES = (
5253 )
5354
5455
55- function TestUtils. test_plan (P:: AbstractFFTs.Plan , x:: AbstractArray , x_transformed:: AbstractArray ; inplace_plan= false , copy_input= false )
56+ function TestUtils. test_plan (P:: AbstractFFTs.Plan , x:: AbstractArray , x_transformed:: AbstractArray ;
57+ inplace_plan= false , copy_input= false , test_wrappers= true )
5658 _copy = copy_input ? copy : identity
5759 if ! inplace_plan
5860 @test P * _copy (x) ≈ x_transformed
5961 @test P \ (P * _copy (x)) ≈ x
6062 _x_out = similar (P * _copy (x))
6163 @test mul! (_x_out, P, _copy (x)) ≈ x_transformed
6264 @test _x_out ≈ x_transformed
63- @test P * view (_copy (x), axes (x)... ) ≈ x_transformed # test view input
65+ if test_wrappers
66+ @test P * view (_copy (x), axes (x)... ) ≈ x_transformed # test view input
67+ end
6468 else
6569 _x = copy (x)
6670 @test P * _copy (_x) ≈ x_transformed
@@ -70,9 +74,10 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
7074 end
7175end
7276
73- function TestUtils. test_plan_adjoint (P:: AbstractFFTs.Plan , x:: AbstractArray ; real_plan= false , copy_input= false )
77+ function TestUtils. test_plan_adjoint (P:: AbstractFFTs.Plan , x:: AbstractArray ;
78+ real_plan= false , copy_input= false , test_wrappers= true )
7479 _copy = copy_input ? copy : identity
75- y = rand ( eltype ( P * _copy (x)), size (P * _copy (x) ))
80+ y = Random . rand! ( P * _copy (x))
7681 # test basic properties
7782 @test_skip eltype (P' ) === typeof (y) # (AbstractFFTs.jl#110)
7883 @test (P' )' === P # test adjoint of adjoint
@@ -86,11 +91,13 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
8691 @test _component_dot (y, P * _copy (x)) ≈ _component_dot (P' * _copy (y), x)
8792 @test _component_dot (x, P \ _copy (y)) ≈ _component_dot (P' \ _copy (x), y)
8893 end
89- @test P' * view (_copy (y), axes (y)... ) ≈ P' * _copy (y) # test view input (AbstractFFTs.jl#112)
94+ if test_wrappers
95+ @test P' * view (_copy (y), axes (y)... ) ≈ P' * _copy (y) # test view input (AbstractFFTs.jl#112)
96+ end
9097 @test_throws MethodError mul! (x, P' , y)
9198end
9299
93- function TestUtils. test_complex_ffts (ArrayType= Array; test_inplace= true , test_adjoint= true )
100+ function TestUtils. test_complex_ffts (ArrayType= Array; test_inplace= true , test_adjoint= true , test_wrappers = true )
94101 @testset " correctness of fft, bfft, ifft" begin
95102 for test_case in TEST_CASES
96103 _x, dims, _x_fft = copy (test_case. x), test_case. dims, copy (test_case. x_fft)
@@ -110,18 +117,18 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
110117 for P in (plan_fft (similar (x_complexf), dims),
111118 (_inv (plan_ifft (similar (x_complexf), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
112119 @test eltype (P) <: Complex
113- @test fftdims (P) == dims
114- TestUtils. test_plan (P, x_complexf, x_fft)
120+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
121+ TestUtils. test_plan (P, x_complexf, x_fft; test_wrappers = test_wrappers )
115122 if test_adjoint
116123 @test fftdims (P' ) == fftdims (P)
117- TestUtils. test_plan_adjoint (P, x_complexf)
124+ TestUtils. test_plan_adjoint (P, x_complexf, test_wrappers = test_wrappers )
118125 end
119126 end
120127 if test_inplace
121128 # test IIP plans
122129 for P in (plan_fft! (similar (x_complexf), dims),
123130 (_inv (plan_ifft! (similar (x_complexf), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
124- TestUtils. test_plan (P, x_complexf, x_fft; inplace_plan= true )
131+ TestUtils. test_plan (P, x_complexf, x_fft; inplace_plan= true , test_wrappers = test_wrappers )
125132 end
126133 end
127134
@@ -136,17 +143,17 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
136143 # test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
137144 for P in (plan_bfft (similar (x_fft), dims),)
138145 @test eltype (P) <: Complex
139- @test fftdims (P) == dims
140- TestUtils. test_plan (P, x_fft, x_scaled)
146+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
147+ TestUtils. test_plan (P, x_fft, x_scaled; test_wrappers = test_wrappers )
141148 if test_adjoint
142- TestUtils. test_plan_adjoint (P, x_fft)
149+ TestUtils. test_plan_adjoint (P, x_fft, test_wrappers = test_wrappers )
143150 end
144151 end
145152 # test IIP plans
146153 for P in (plan_bfft! (similar (x_fft), dims),)
147154 @test eltype (P) <: Complex
148- @test fftdims (P) == dims
149- TestUtils. test_plan (P, x_fft, x_scaled; inplace_plan= true )
155+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
156+ TestUtils. test_plan (P, x_fft, x_scaled; inplace_plan= true , test_wrappers = test_wrappers )
150157 end
151158
152159 # IFFT
@@ -160,33 +167,33 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
160167 for P in (plan_ifft (similar (x_complexf), dims),
161168 (_inv (plan_fft (similar (x_complexf), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
162169 @test eltype (P) <: Complex
163- @test fftdims (P) == dims
164- TestUtils. test_plan (P, x_fft, x)
170+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
171+ TestUtils. test_plan (P, x_fft, x; test_wrappers = test_wrappers )
165172 if test_adjoint
166- TestUtils. test_plan_adjoint (P, x_fft)
173+ TestUtils. test_plan_adjoint (P, x_fft; test_wrappers = test_wrappers )
167174 end
168175 end
169176 # test IIP plans
170177 if test_inplace
171178 for P in (plan_ifft! (similar (x_complexf), dims),
172179 (_inv (plan_fft! (similar (x_complexf), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
173180 @test eltype (P) <: Complex
174- @test fftdims (P) == dims
175- TestUtils. test_plan (P, x_fft, x; inplace_plan= true )
181+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
182+ TestUtils. test_plan (P, x_fft, x; inplace_plan= true , test_wrappers = test_wrappers )
176183 end
177184 end
178185 end
179186 end
180187end
181188
182- function TestUtils. test_real_ffts (ArrayType= Array; test_adjoint= true , copy_input= false )
189+ function TestUtils. test_real_ffts (ArrayType= Array; test_adjoint= true , copy_input= false , test_wrappers = true )
183190 @testset " correctness of rfft, brfft, irfft" begin
184191 for test_case in TEST_CASES
185192 _x, dims, _x_fft = copy (test_case. x), test_case. dims, copy (test_case. x_fft)
186193 x = convert (ArrayType, _x) # dummy array that will be passed to plans
187194 x_real = float .(x) # for testing mutating real FFTs
188195 x_fft = convert (ArrayType, _x_fft)
189- x_rfft = collect (selectdim (x_fft, first (dims), 1 : (size (x_fft, first (dims)) ÷ 2 + 1 )))
196+ x_rfft = convert (ArrayType, collect (selectdim (x_fft, first (dims), 1 : (size (x_fft, first (dims)) ÷ 2 + 1 ) )))
190197
191198 if ! (eltype (x) <: Real )
192199 continue
@@ -197,10 +204,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
197204 for P in (plan_rfft (similar (x_real), dims),
198205 (_inv (plan_irfft (similar (x_rfft), size (x, first (dims)), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
199206 @test eltype (P) <: Real
200- @test fftdims (P) == dims
201- TestUtils. test_plan (P, x_real, x_rfft; copy_input= copy_input)
207+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
208+ TestUtils. test_plan (P, x_real, x_rfft; copy_input= copy_input, test_wrappers = test_wrappers )
202209 if test_adjoint
203- TestUtils. test_plan_adjoint (P, x_real; real_plan= true , copy_input= copy_input)
210+ TestUtils. test_plan_adjoint (P, x_real; real_plan= true , copy_input= copy_input, test_wrappers = test_wrappers )
204211 end
205212 end
206213
@@ -209,10 +216,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
209216 @test brfft (x_rfft, size (x, first (dims)), dims) ≈ x_scaled
210217 for P in (plan_brfft (similar (x_rfft), size (x, first (dims)), dims),)
211218 @test eltype (P) <: Complex
212- @test fftdims (P) == dims
213- TestUtils. test_plan (P, x_rfft, x_scaled; copy_input= copy_input)
219+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
220+ TestUtils. test_plan (P, x_rfft, x_scaled; copy_input= copy_input, test_wrappers = test_wrappers )
214221 if test_adjoint
215- TestUtils. test_plan_adjoint (P, x_rfft; real_plan= true , copy_input= copy_input)
222+ TestUtils. test_plan_adjoint (P, x_rfft; real_plan= true , copy_input= copy_input, test_wrappers = test_wrappers )
216223 end
217224 end
218225
@@ -221,10 +228,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
221228 for P in (plan_irfft (similar (x_rfft), size (x, first (dims)), dims),
222229 (_inv (plan_rfft (similar (x_real), dims)) for _inv in (inv, AbstractFFTs. plan_inv)). .. )
223230 @test eltype (P) <: Complex
224- @test fftdims (P) == dims
225- TestUtils. test_plan (P, x_rfft, x; copy_input= copy_input)
231+ @test collect ( fftdims (P))[:] == collect ( dims)[:] # compare as iterables
232+ TestUtils. test_plan (P, x_rfft, x; copy_input= copy_input, test_wrappers = test_wrappers )
226233 if test_adjoint
227- TestUtils. test_plan_adjoint (P, x_rfft; real_plan= true , copy_input= copy_input)
234+ TestUtils. test_plan_adjoint (P, x_rfft; real_plan= true , copy_input= copy_input, test_wrappers = test_wrappers )
228235 end
229236 end
230237 end
0 commit comments