205205 y = randn (size (x))
206206 for dims in unique ((1 , 1 : N, N))
207207 P = plan_fft (x, dims)
208- @test AbstractFFTs. output_size (P) == size (P * x)
208+ @test AbstractFFTs. output_size (P) == size (x)
209+ @test AbstractFFTs. output_size (P' ) == size (x)
209210 Pinv = plan_ifft (x)
210- @test AbstractFFTs. output_size (Pinv) == size (Pinv * x)
211+ @test AbstractFFTs. output_size (Pinv) == size (x)
212+ @test AbstractFFTs. output_size (Pinv' ) == size (x)
211213 end
212214 end
213215 end
218220 P = plan_rfft (x, dims)
219221 Px_sz = size (P * x)
220222 @test AbstractFFTs. output_size (P) == Px_sz
223+ @test AbstractFFTs. output_size (P' ) == size (x)
221224 y = randn (Px_sz) .+ randn (Px_sz) * im
222225 Pinv = plan_irfft (y, size (x)[first (dims)], dims)
223226 @test AbstractFFTs. output_size (Pinv) == size (Pinv * y)
227+ @test AbstractFFTs. output_size (Pinv' ) == size (y)
224228 end
225229 end
226230 end
233237 y = randn (size (x))
234238 for dims in unique ((1 , 1 : N, N))
235239 P = plan_fft (x, dims)
240+ @test (P' )' * x == P * x # test adjoint of adjoint
236241 @test dot (y, P * x) ≈ dot (P' * y, x)
237242 @test_broken dot (y, P \ x) ≈ dot (P' \ y, x)
238- Pinv = plan_ifft (x)
243+ Pinv = plan_ifft (y)
244+ @test (Pinv' )' * y == Pinv * y # test adjoint of adjoint
239245 @test dot (x, Pinv * y) ≈ dot (Pinv' * x, y)
240246 @test_broken dot (x, Pinv \ y) ≈ dot (Pinv' \ x, y)
241247 end
@@ -246,12 +252,14 @@ end
246252 N = ndims (x)
247253 for dims in unique ((1 , 1 : N, N))
248254 P = plan_rfft (x, dims)
255+ @test (P' )' * x == P * x
249256 y_real = randn (size (P * x))
250257 y_imag = randn (size (P * x))
251258 y = y_real .+ y_imag .* im
252259 @test dot (y_real, real .(P * x)) + dot (y_imag, imag .(P * x)) ≈ dot (P' * y, x)
253260 @test_broken dot (y_real, real .(P \ x)) + dot (y_imag, imag .(P \ x)) ≈ dot (P' * y, x)
254261 Pinv = plan_irfft (y, size (x)[first (dims)], dims)
262+ @test (Pinv' )' * y == Pinv * y
255263 @test dot (x, Pinv * y) ≈ dot (y_real, real .(Pinv' * x)) + dot (y_imag, imag .(Pinv' * x))
256264 @test_broken dot (x, Pinv \ y) ≈ dot (y_real, real .(Pinv' \ x)) + dot (y_imag, imag .(Pinv' \ x))
257265 end
@@ -284,20 +292,27 @@ end
284292 N = ndims (x)
285293 complex_x = complex .(x)
286294 for dims in unique ((1 , 1 : N, N))
295+ # fft, ifft, bfft
287296 for f in (fft, ifft, bfft)
288297 test_frule (f, x, dims)
289298 test_rrule (f, x, dims)
290299 test_frule (f, complex_x, dims)
291300 test_rrule (f, complex_x, dims)
292301 end
293-
294302 for pf in (plan_fft, plan_ifft, plan_bfft)
295303 test_frule (* , pf (x, dims) ⊢ NoTangent (), x)
296304 test_rrule (* , pf (x, dims) ⊢ NoTangent (), x)
297305 test_frule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
298306 test_rrule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
299307 end
300308
309+ # rfft
310+ test_frule (rfft, x, dims)
311+ test_rrule (rfft, x, dims)
312+ test_frule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
313+ test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
314+
315+ # irfft, brfft
301316 for f in (irfft, brfft)
302317 for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
303318 test_frule (f, x, d, dims)
@@ -306,14 +321,12 @@ end
306321 test_rrule (f, complex_x, d, dims)
307322 end
308323 end
309-
310324 for pf in (plan_irfft, plan_brfft)
311325 for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
312326 test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
313327 test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
314328 end
315329 end
316-
317330 end
318331 end
319332 end
0 commit comments