ForwardDiff and Zygote cannot automatically differentiate (AD) a function from C^n to R that uses FFT

Zygote is able to handle pfft but not pifft. Let’s look at the types:

julia> typeof(pfft)
FFTW.cFFTWPlan{Complex{Float64},-1,false,1,UnitRange{Int64}}

julia> typeof(pifft)
AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}

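The ScaledPlan wrapper comes from the normalization of the inverse transform: the backward plan is multiplied by 1/n, which is 0.01 for this 100-element array: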

julia> pfft
FFTW forward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3fv_10_avx2_128")
  (dft-direct-10-x10 "n2fv_10_avx2_128"))

julia> pifft
0.01 * FFTW backward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3bv_10_avx2_128")
  (dft-direct-10-x10 "n2bv_10_avx2_128"))
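To make the role of the 0.01 factor concrete, here is a minimal sketch that rebuilds the two plans for a 100-element array and compares the scaled inverse plan with the unnormalized backward plan from plan_bfft. The array x is only an illustration, and reading the `scale` field relies on the ScaledPlan struct layout in AbstractFFTs as I understand it, so treat that part as an assumption.

```julia
using FFTW

x = randn(ComplexF64, 100)      # illustrative input, not from the original post

pfft  = plan_fft(x)             # plain forward plan
pifft = plan_ifft(x)            # backward plan wrapped in a ScaledPlan
pbfft = plan_bfft(x)            # unnormalized backward plan

@show nameof(typeof(pifft))     # name of the wrapper type
@show pifft.scale               # normalization factor 1/length(x) (assumed field name)

# The scaled plan should agree with "scale times the unnormalized backward transform".
@show pifft * x ≈ pifft.scale .* (pbfft * x)
```

So the only difference between pifft and a plain FFTW plan is the scalar multiplication, which is the piece Zygote trips over here.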

I believe the error really comes from this line. So Zygote can handle the pfft object but not the pifft. Ideally, a suitable fix would be added to Zygote itself.

I’m not familiar enough with FFTW and Zygote to provide a good solution for that, but we can bypass the ScaledPlan issue with a custom adjoint:

using LinearAlgebra, Random, FFTW
using Zygote
Random.seed!(42)

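# Wrapper around the planned inverse FFT so that a custom adjoint can be attached to it.
pifft_f(x, pifft, pfft) = pifft * x
# ifft is 1/n times the unnormalized backward DFT, so its adjoint is fft(c̄) / n.
# The pullback returns one entry per argument: x̄, then nothing for each of the two plans.
Zygote.@adjoint pifft_f(x, pifft, pfft) =
    pifft_f(x, pifft, pfft), c̄ -> ((pfft * c̄) ./ length(c̄), nothing, nothing)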

function main()
    n = 10000
    x0 = im .* randn(n)

    pfft = plan_fft(x0)
    pifft = plan_ifft(fft(x0))

    # f uses the precomputed plans, f2 the unplanned fft/ifft as a reference
    f(x) = real(norm(pifft_f(fftshift(pfft * x), pifft, pfft)))
    f2(x) = real(norm(ifft(fftshift(fft(x)))))

    @show f(x0)
    @show f2(x0)

    # warm-up calls, so that @time below does not include compilation
    gx = Zygote.gradient(f, x0)
    gx2 = Zygote.gradient(f2, x0)

    @time gx = Zygote.gradient(f, x0)
    @time gx2 = Zygote.gradient(f2, x0)

    @show gx[1] ≈ gx2[1]

    return
end

main()

julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
  0.000790 seconds (43 allocations: 1.375 MiB)
  0.000877 seconds (175 allocations: 1.232 MiB)
gx[1] ≈ gx2[1] = true
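
The pullback used for pifft_f relies on the adjoint of the inverse FFT being the forward FFT divided by n. Here is a small sketch that checks this without Zygote through the inner-product identity ⟨c, ifft(x)⟩ = ⟨fft(c)/n, x⟩. The size n = 64 and the random vectors x and c are arbitrary choices for illustration.

```julia
using FFTW, LinearAlgebra, Random

Random.seed!(0)
n = 64
x = randn(ComplexF64, n)        # arbitrary input vector
c = randn(ComplexF64, n)        # arbitrary cotangent vector

pfft  = plan_fft(x)
pifft = plan_ifft(x)

# dot(a, b) conjugates its first argument, so this compares
# <c, A x> with <A' c, x>, where A = ifft and A' = fft / n.
lhs = dot(c, pifft * x)
rhs = dot((pfft * c) ./ n, x)

@show lhs ≈ rhs
```

If the two inner products agree, (pfft * c̄) ./ length(c̄) is a valid pullback for pifft * x.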

The problem is that Zygote.@adjoint has to be defined at global scope. If pfft and pifft are not global, both have to be passed as arguments to pifft_f.
I know that this is annoying, but it works and it is still faster than the non-planned version. Maybe someone else knows a more elegant solution.
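
One possible way to cut this down to a single extra argument is to bundle both plans in a small struct. This is only a sketch: PlanPair and apply_ifft are hypothetical names introduced here, not part of FFTW or Zygote, and the pullback assumes the incoming cotangent is a complex vector of the same size the plans were built for.

```julia
using FFTW, Zygote

# Hypothetical container holding the forward and inverse plans together.
struct PlanPair{F,I}
    pfft::F
    pifft::I
end

# Convenience constructor: build both plans from a sample array.
PlanPair(x) = PlanPair(plan_fft(x), plan_ifft(x))

# Planned inverse FFT taking the plan pair as its only extra argument.
apply_ifft(x, p::PlanPair) = p.pifft * x

# Same adjoint as for pifft_f: forward FFT of the cotangent divided by n,
# and nothing for the plan pair.
Zygote.@adjoint apply_ifft(x, p::PlanPair) =
    apply_ifft(x, p), c̄ -> ((p.pfft * c̄) ./ length(c̄), nothing)

x0 = randn(ComplexF64, 128)     # illustrative input
plans = PlanPair(x0)

loss(x) = sum(abs2, apply_ifft(x, plans))
g = Zygote.gradient(loss, x0)[1]

# For this loss the gradient should equal 2x/n, since fft(ifft(x)) = x.
@show g ≈ 2 .* x0 ./ length(x0)
```

The adjoint definition still lives at global scope, but the plans themselves can be created wherever they are needed and passed around as one object.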
