Zygote is able to handle `pfft` but not `pifft`. Let’s look at the types:
```julia
julia> typeof(pfft)
FFTW.cFFTWPlan{Complex{Float64},-1,false,1,UnitRange{Int64}}

julia> typeof(pifft)
AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}
```
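This is where the `ScaledPlan` comes from: the inverse plan is the backward plan multiplied by the normalization factor `0.01` (1/n for the 100-element array):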
```julia
julia> pfft
FFTW forward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3fv_10_avx2_128")
  (dft-direct-10-x10 "n2fv_10_avx2_128"))

julia> pifft
0.01 * FFTW backward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3bv_10_avx2_128")
  (dft-direct-10-x10 "n2bv_10_avx2_128"))
```
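A quick way to see that scale factor: the normalized inverse plan from `plan_ifft` should agree with the unnormalized backward plan from `plan_bfft` divided by `n`. This is a minimal sketch that assumes only FFTW.jl; the variable names are illustrative.

```julia
using FFTW

n = 100
y = randn(ComplexF64, n)

pifft = plan_ifft(y)   # normalized inverse plan ("0.01 * FFTW backward plan")
pbfft = plan_bfft(y)   # unnormalized backward plan

# Name of the wrapper type returned by plan_ifft
@show nameof(typeof(pifft))

# The inverse plan equals the backward plan scaled by 1/n
@show pifft * y ≈ (pbfft * y) ./ n
@show pifft * y ≈ ifft(y)
```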
I believe the error really comes from this line. So Zygote can handle the `pfft` object but not `pifft`. Ideally, a suitable fix would be added to Zygote itself.
I’m not familiar enough with FFTW and Zygote to provide a proper solution, but we can bypass the `ScaledPlan` issue with a custom adjoint:
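```julia
using LinearAlgebra, Random, FFTW
using Zygote

Random.seed!(42)

# Applies the (scaled) inverse plan. Both plans are passed as arguments
# because the custom adjoint below has to be defined at global scope.
pifft_f(x, pifft, pfft) = pifft * x

# ifft(x) == F' * x / n, so the pullback is F * c̄ / n: apply the forward
# plan to the cotangent and divide by the length. The plans get no
# gradient, hence one `nothing` per plan argument.
Zygote.@adjoint pifft_f(x, pifft, pfft) =
    pifft_f(x, pifft, pfft), c̄ -> ((pfft * c̄) ./ length(c̄), nothing, nothing)

function main()
    n = 10000
    x0 = im .* randn(n)
    pfft = plan_fft(x0)
    pifft = plan_ifft(fft(x0))

    # Planned version (custom adjoint) vs. non-planned version
    f(x) = real(norm(pifft_f(fftshift(pfft * x), pifft, pfft)))
    f2(x) = real(norm(ifft(fftshift(fft(x)))))

    @show f(x0)
    @show f2(x0)

    # The first calls include compilation; time the second calls
    gx = Zygote.gradient(f, x0)
    gx2 = Zygote.gradient(f2, x0)
    @time gx = Zygote.gradient(f, x0)
    @time gx2 = Zygote.gradient(f2, x0)
    @show gx[1] ≈ gx2[1]
    return
end

main()
```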
```julia
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
  0.000790 seconds (43 allocations: 1.375 MiB)
  0.000877 seconds (175 allocations: 1.232 MiB)
gx[1] ≈ gx2[1] = true
```
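The pullback in `pifft_f` relies on `ifft` being linear: with F the DFT matrix, `ifft(x) = F' * x / n`, so its adjoint maps a cotangent `ȳ` to `F * ȳ / n`, i.e. `fft(ȳ) / n`. A dot-product test is a cheap way to check that identity numerically. This sketch assumes only FFTW.jl and LinearAlgebra; the names are illustrative.

```julia
using FFTW, LinearAlgebra

n = 64
x = randn(ComplexF64, n)
ȳ = randn(ComplexF64, n)   # arbitrary cotangent for the output of ifft

# For a linear map A with adjoint A', <ȳ, A x> == <A' ȳ, x>.
# Here A = ifft and the claimed adjoint is ȳ -> fft(ȳ) / n,
# which is what the custom pullback computes with the planned pfft.
lhs = dot(ȳ, ifft(x))
rhs = dot(fft(ȳ) ./ n, x)
@show lhs ≈ rhs
```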
The catch is that `Zygote.@adjoint` has to be defined at global scope. Since `pfft` and `pifft` are not globals here, both plans have to be passed as arguments to the `pifft_f` function.
I know that’s annoying, but it works, and it is still faster than the non-planned version (0.79 ms vs. 0.88 ms in the timing above). Maybe someone else knows a more elegant solution.
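One way to reduce the argument passing is to bundle both plans into a single callable object and attach the custom adjoint to that object. The sketch below uses a hypothetical `PlannedIFFT` type (not part of FFTW or Zygote) and assumes Zygote's `@adjoint` support for callable structs, where the pullback returns the gradient for the object itself first. The adjoint is still defined at global scope; only the plumbing changes, and behavior may vary across Zygote/FFTW versions.

```julia
using FFTW, LinearAlgebra, Zygote

# Hypothetical wrapper: holds the normalized inverse plan and a forward plan
# of the same size, so only one object has to be passed around.
struct PlannedIFFT{P,Q}
    pinv::P   # applied in the forward pass
    pfwd::Q   # used in the pullback to apply the adjoint
end

PlannedIFFT(x::AbstractArray) = PlannedIFFT(plan_ifft(x), plan_fft(x))

(p::PlannedIFFT)(x) = p.pinv * x

# Adjoint of ifft is fft / n. The first tuple entry is the gradient with
# respect to the callable object (its plans are treated as constants).
Zygote.@adjoint (p::PlannedIFFT)(x) =
    p(x), ȳ -> (nothing, (p.pfwd * ȳ) ./ length(ȳ))

# Compare against the non-planned version, mirroring the post's test
function compare_gradients(n = 1000)
    x0 = im .* randn(n)
    P = PlannedIFFT(x0)
    pfft = plan_fft(x0)
    f(x) = norm(P(fftshift(pfft * x)))
    f2(x) = norm(ifft(fftshift(fft(x))))
    g = Zygote.gradient(f, x0)[1]
    g2 = Zygote.gradient(f2, x0)[1]
    return g, g2
end

g, g2 = compare_gradients()
@show g ≈ g2
```

Constructing `PlannedIFFT(x0)` once and reusing it keeps the planning cost out of the gradient calls, as in the original version.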