ForwardDiff and Zygote cannot automatically differentiate (AD) a function from C^n to R that uses FFT

I didn’t look in detail into your code, but at least Zygote supports complex numbers and FFTs.

julia> using FFTW, Zygote, ForwardDiff

julia> f(x) = sum(real(fft(x)))
f (generic function with 1 method)

julia> x0 = randn((10, 10));

julia> Zygote.gradient(f, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)

Also, using a pre-planned FFT, the gradient works:

julia> fft_p = plan_fft(x0)
FFTW forward plan for 10×10 array of Complex{Float64}
  (dft-direct-10-x10 "n2fv_10_avx2_128")
  (dft-direct-10-x10 "n1fv_10_avx2_128"))

julia> g(x) = sum(real(fft_p * x))
g (generic function with 1 method)

julia> Zygote.gradient(g, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)

julia> Zygote.gradient(g, x0) == Zygote.gradient(f, x0)

You shouldn’t include plan_fft in code that needs to be differentiated, because taking the derivative of plan_fft is not really meaningful. Either define a function that uses an existing fft_p, or pass the fft_p object in as an argument.
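A minimal sketch of the second option, passing the plan in as an argument. The helper name `h` and the complex-valued input are assumptions made for illustration; the only point is that the plan is built once, outside the differentiated code.

```julia
using FFTW, Zygote

x0 = randn(ComplexF64, 10, 10)
fft_p = plan_fft(x0)            # built once, outside the code being differentiated

# `h` is a hypothetical helper: the plan is just an ordinary argument
h(x, p) = sum(real(p * x))

# Differentiate with respect to x only; the plan is captured by the closure
grad = Zygote.gradient(x -> h(x, fft_p), x0)[1]
```

The result should agree with the gradient of the unplanned `sum(real(fft(x)))` shown earlier.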

Hope this helps!

I think the issue is that ForwardDiff does not work with FFTW, and Zygote’s hessian function is just hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2], i.e. it calculates the gradient with Zygote and then uses ForwardDiff to differentiate again.

julia> gradient(x->sum(abs2, fftshift(fft(x))), rand(3) .+ im .* randn.())
(ComplexF64[1.495784118230028 - 13.219849172255007im, 4.629297683031974 + 5.3532713394798055im, 2.672875552797116 - 0.3633324783983167im],)

julia> Zygote.hessian(x->sum(abs2, fft(x)), rand(10))
ERROR: type ForwardDiff.Dual{Nothing, Float64, 10} not supported
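To make the "forward over reverse" structure concrete, here is a rough sketch of the same idea written by hand for a function that does not involve an FFT. The name `my_hessian` is hypothetical, and this is not Zygote's actual implementation, only the same pattern: ForwardDiff pushes dual numbers through Zygote's gradient, which is why every function called inside `f` has to accept duals.

```julia
using ForwardDiff, Zygote

# Hypothetical helper: reverse mode (Zygote) for the gradient,
# forward mode (ForwardDiff) for the derivative of that gradient.
my_hessian(f, x) = ForwardDiff.jacobian(z -> Zygote.gradient(f, z)[1], x)

f_quad(x) = sum(abs2, x)        # no FFT, so duals pass straight through
H = my_hessian(f_quad, rand(3)) # Hessian of sum(abs2, x) is 2I
```

With `fft` inside `f`, the dual-valued array `z` reaches FFTW, which is where the "not supported" error above comes from.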

You might be able to cook something up using say this jacobian definition, so that both derivatives are taken using Zygote.

There might also be a more generic FFT somewhere which does accept dual numbers, i.e. not FFTW. Or as @Tamas_Papp suggests it might be possible to add a rule for this somewhere — Zygote already has rules which write the reverse gradient in terms of other ffts. ForwardDiff does not use ChainRules, but I don’t see why you couldn’t write an explicit method AbstractFFTs.fft(::AbstractArray{Dual{... to handle this.


Thank you. I actually tried computing the plan outside the function, as the code below, but I got the error
type ScaledPlan has no field region and I have no idea what it means. I simplified the code more so perhaps it’s easier to read. I’m not sure why it worked for your code but not mine.

With pre-plan

function light_prop(e_in, prop_kernel, pfft, pifft)
    if ndims(e_in) == 3
        prop_kernel = reshape(prop_kernel, (1, size(prop_kernel)...))
    end
    ek_in  = pfft * ifftshift(e_in)
    ek_out = ek_in .* prop_kernel
    e_out  = fftshift(pifft * ek_out)
    return e_out
end

nx = 16
ny = 16
npl = 5
nmod = 10
mat(x) = reshape(x,(npl,nx,ny))
e_in = randn(nmod,nx,ny)
e_target = randn(nmod,nx,ny)
pfft = plan_fft(e_in) 
pifft = plan_ifft(e_in)
airKernels   = randn(npl+1,nx,ny)+im.*randn(npl+1,nx,ny)
airPropagate(e_in, jj,pfft,pifft)  = light_prop(e_in, airKernels[jj,:,:],pfft,pifft)

f(x) = real(norm(propagatedOutput(e_in,airKernels,mat(x),pfft,pifft).*conj(e_target))^2/nmod -1);

I tried to execute your code, but it is not working (missing methods, and copying above failed for some reason…).

You should really try to narrow the problem down (1D arrays maybe, no operations like airpropagate etc.).

Obviously, it seems to be a problem with Zygote and FFTW, so try to find a minimal example.
Also post the full stack trace please.

We can try then to fix the issue using custom adjoints in Zygote.

Sorry I tested on an old kernel so my broken code passed, you can see I’m a bit new to this. I got the same error using the simplified code:

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)

f(x) = real(norm(pifft*fftshift(pfft*x0)));
type ScaledPlan has no field region

 [1] getproperty(::AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}, ::Symbol) at .\Base.jl:33
 [2] (::Zygote.var"#911#912"{AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64},Array{Complex{Float64},1}})(::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\array.jl:826
 [3] (::Zygote.var"#3417#back#913"{Zygote.var"#911#912"{AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64},Array{Complex{Float64},1}}})(::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] f at .\In[3]:11 [inlined]
 [5] (::typeof(∂(f)))(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(f))})(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface.jl:40
 [7] gradient(::Function, ::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface.jl:49
 [8] top-level scope at In[3]:12
 [9] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091

Zygote is able to handle pfft but not pifft. Let’s look at the types

julia> typeof(pfft)

julia> typeof(pifft)

This is where ScaledPlan comes from: the inverse plan wraps a backward plan together with a scale factor, here 0.01 (that is, 1/n for n = 100):

julia> pfft
FFTW forward plan for 100-element array of Complex{Float64}
  (dftw-direct-10/6 "t3fv_10_avx2_128")
  (dft-direct-10-x10 "n2fv_10_avx2_128"))

julia> pifft
0.01 * FFTW backward plan for 100-element array of Complex{Float64}
  (dftw-direct-10/6 "t3bv_10_avx2_128")
  (dft-direct-10-x10 "n2bv_10_avx2_128"))

I believe the error really comes from this line. So Zygote can handle the pfft object but not pifft. Ideally, a suitable fix would be added to Zygote itself.

I’m not familiar enough with FFTW and Zygote to provide a good solution for that, but we can bypass the ScaledPlan issue with a custom adjoint.

using LinearAlgebra, Random, FFTW
using Zygote

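pifft_f(x, pifft, pfft) = pifft * x
# Pullback of ifft: the adjoint of (1/n) * F' is (1/n) * F, so apply the forward plan and scale.
# One gradient per argument: x, pifft, pfft (the plans get `nothing`).
Zygote.@adjoint pifft_f(x, pifft, pfft) =
    pifft_f(x, pifft, pfft), c̄ -> (1 ./ length(c̄) .* (pfft * c̄), nothing, nothing)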

function main()
    n = 10000
    x0 = im .* randn(n)
    pfft = plan_fft(x0)
    pifft = plan_ifft(fft(x0))

    f(x) = real(norm(pifft_f(fftshift(pfft * x), pifft, pfft)))
    f2(x) = real(norm(ifft(fftshift(fft(x)))))
    @show f(x0)
    @show f2(x0)
    # First calls compile; the timed calls below measure the run time only
    gx = Zygote.gradient(f, x0)
    gx2 = Zygote.gradient(f2, x0)
    @time gx = Zygote.gradient(f, x0)
    @time gx2 = Zygote.gradient(f2, x0)
    @show gx[1] ≈ gx2[1]
end

main()


julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
  0.000790 seconds (43 allocations: 1.375 MiB)
  0.000877 seconds (175 allocations: 1.232 MiB)
gx[1] ≈ gx2[1] = true

The catch is that Zygote.@adjoint has to be defined at global scope. Since pfft and pifft are not globals here, both have to be passed as arguments to pifft_f.
I know that’s annoying, but it works and it is still faster than the non-planned version. Maybe someone else knows a more elegant solution.


I think Zygote’s definition here assumes that all subtypes of Plan have this field, which is not true, although they do all contain this information:

julia> pfft |> typeof |> supertype |> supertype

julia> pifft |> typeof |> supertype

julia> pfft.region

julia> pifft.p.region

I can’t see one, but ideally there would be some function like region(::Plan) which extracts this information.
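As a rough sketch of what such an accessor could look like: `plan_region` below is a hypothetical name, not an existing AbstractFFTs function, and it assumes (as the REPL lines above suggest) that FFTW plans store a `region` field and that ScaledPlan wraps the underlying plan in its `p` field.

```julia
using FFTW

# Hypothetical accessor: unwrap ScaledPlan until a plan with a `region` field is reached
plan_region(p::FFTW.AbstractFFTs.Plan) = p.region
plan_region(p::FFTW.AbstractFFTs.ScaledPlan) = plan_region(p.p)

x0 = randn(ComplexF64, 100)
pfft = plan_fft(x0)
pifft = plan_ifft(x0)

same_region = plan_region(pfft) == plan_region(pifft)
```

Unlike overloading `getproperty`, this leaves field access on `ScaledPlan` untouched.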

The quick and dirty hack is to define Base.getproperty(p::AbstractFFTs.ScaledPlan, s::Symbol) = s === :region ? getfield(p, :p).region : getfield(p, s) and see if that works.


Ok, we can avoid this by extracting the scaling and the FFT operator ourselves:

f3(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
f3(x0) = 100.96395788714146
  0.000428 seconds (43 allocations: 1.375 MiB)
  0.000589 seconds (175 allocations: 1.232 MiB)
  0.000525 seconds (84 allocations: 1.835 MiB)
gx[1] ≈ gx3[1] = true
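The decomposition used in f3 can be checked on its own, without Zygote. This is only a sketch of the identity being relied on: the `p` and `scale` fields are the ones already used in f3, and the variable names are chosen for illustration.

```julia
using FFTW

n = 100
y = randn(ComplexF64, n)
pifft = plan_ifft(y)

unscaled = pifft.p * y               # backward FFT without the 1/n factor
manual   = pifft.scale .* unscaled   # apply the scale by hand, as in f3

@show pifft.scale ≈ 1 / n
@show manual ≈ pifft * y
@show unscaled ≈ bfft(y)
```

Because `pifft.p` is a plain FFTW plan rather than a ScaledPlan, Zygote's plan rule can read its `region` field directly.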

Ah! That did the trick. Thank you so much. The only problem left is that fft does not know how to handle ForwardDiff’s dual type. As @mcabbott suggested, I would need either to write my own Jacobian function that calls Zygote or to write my own fft method that handles duals. Both options are perhaps out of my reach right now, but I appreciate all of your help in uncovering the root of the problem.

Just do an overload to fft so you don’t send the duals around it but just multiply the result. Differentiating fft is too trivial to actually need to differentiate through the algorithm.

From what I understand, the gradient of FFT should just be the DFT matrix which is constant. So I think I understand what you mean by differentiating fft is too trivial. However, I have trouble wrapping my head around how to feed this gradient to Zygote.hessian function in practice. Would you mind showing me an example of your solution using this function? I’m not sure I understand what you mean by “overload”. Thank you!

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
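The statement that the Jacobian of the FFT is the constant DFT matrix can be checked numerically. The sketch below builds the matrix explicitly for a small n (the names `F`, `x`, `dx` are just for illustration) and confirms that a perturbation dx is mapped to F * dx regardless of x.

```julia
using FFTW

n = 8
# Explicit DFT matrix with the same sign convention as fft
F = [exp(-2π * im * j * k / n) for j in 0:n-1, k in 0:n-1]

x  = randn(ComplexF64, n)
dx = randn(ComplexF64, n)

@show F * x ≈ fft(x)                      # fft is multiplication by F
@show fft(x .+ dx) ≈ fft(x) .+ F * dx     # linear: the perturbation maps to F * dx
@show F * dx ≈ fft(dx)                    # so the derivative part is just another fft
```

The last line is why no explicit matrix is needed in practice: the derivative part can be pushed through the same fast transform as the value.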

Exactly, so you need to use dispatch to directly put that solution into the dual portion. For example, this is what it looks like on a quadrature:

Instead of differentiating through a quadrature, you can define a quadrature on the differentiated function and then stick the result in the partials:

For more information on how ForwardDiff and dual numbers work, check out these lecture notes:

The discussions of the FFTW adjoints might also be helpful:

Thank you, I’m going to give this a try and dive deeper into the technical details.

I think I understand how dual numbers work. The first dimension keeps track of the function value and the second dimension has algebra defined that works exactly like dx and the chain rule, so it keeps track of the value of the gradient. However, I’m still having trouble seeing how to implement it in practice. I think the example above might be a bit advanced for me to understand what exactly is going on as I’m still a beginner in Julia. I would really appreciate if you could show me a toy example, perhaps using this function

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 10
x0 = randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));

Here we need the plan for fft pfft, fftshift, and pifft.p to recognize dual number x (I precompute the plan because I don’t want to differentiate this part). Right now I don’t think they have a method of dealing with dual numbers, so I either have to write methods for dual numbers myself or compute the gradient for these myself and feed to the dual part. I don’t really know how to do either. Something I’m attempting right now is starting with fft:

AbstractFFTs.fft(f::AbstractArray{Dual{TODO}})=ForwardDiff.Dual{fft(f.val), DFT* f.der}

where DFT stands for the DFT matrix. First, I don’t know what types to put as I always have trouble with Julia types. I want types to be general so I put “any” but Julia doesn’t like that. Second, if the gradient is DFT * f.der then I assume it will be very slow due to matrix multiplication. This most likely won’t scale very well for my problem. Third, I have no idea what the gradient is for fftshift.

As for feeding the DFT matrix to the partials myself, I don’t really know how to ask Zygote.hessian/ ForwardDiff to skip finding the gradient of fft and use the one I provide instead. I feel a bit overwhelmed by the intricacies here, and any help is greatly appreciated!

I’m quite busy so I set a reminder for 1 week from now. But others should chime in if you haven’t figured it out. You’re close and on the right track.

I was fiddling a bit, so I’ll paste this here before I accidentally close the window again.

FFT is linear, so what we want is f * (x + dx) = f*x + f*dx where f is what plan_fft gives you, and might be worth re-using. Then the rough idea is something like this:

julia> using FFTW, ForwardDiff

julia> x = rand(2)
2-element Vector{Float64}:

julia> f = plan_fft(x)
FFTW forward plan for 2-element array of ComplexF64
(dft-direct-2 "n2fv_2_avx2_128")

julia> xtil = f * x
2-element Vector{ComplexF64}:
   1.11577311528008 + 0.0im
 0.4414749108530729 + 0.0im

julia> x_plus_dx = [ForwardDiff.Dual(x[i], (i,i^2)) for i in 1:2]  # junk data + some duals
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 2}}:

julia> x == ForwardDiff.value.(x_plus_dx)

julia> dx1 = ForwardDiff.partials.(x_plus_dx, 1) # extract the dual part
2-element Vector{Float64}:

julia> dx1til = f * dx1  # apply the same FFT
2-element Vector{ComplexF64}:
  3.0 + 0.0im
 -1.0 + 0.0im

julia> dx2til = f * @view reinterpret(Float64, x_plus_dx)[3:3:end] # another way?
2-element Vector{ComplexF64}:
  5.0 + 0.0im
 -3.0 + 0.0im

julia> xtil_plus = [Complex(   # re-assemble
        ForwardDiff.Dual(real(xtil[i]), (real(dx1til[i]), real(dx2til[i]))),
        ForwardDiff.Dual(imag(xtil[i]), (imag(dx1til[i]), imag(dx2til[i])))
        ) for i in 1:2]
2-element Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}:
     Dual{Nothing}(1.11577311528008,3.0,5.0) + Dual{Nothing}(0.0,0.0,0.0)*im
 Dual{Nothing}(0.4414749108530729,-1.0,-3.0) + Dual{Nothing}(0.0,0.0,0.0)*im

So there’s going to be a lot of messing with arrays-of-structs. Really x is going to be complex like xtil_plus here, so that’s one more layer. I think this is the right way around, Complex{Dual{...}}.

When you call fft(x_plus_dx), first it converts the matrix to be complex, then it makes a plan, then it applies it. I think these are the steps to make that work here:

# @edit fft(x_plus_dx, 1:1) points me here:
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im)

# @edit fft(x_plus_dx .+ 0im, 1:1) # now this makes a plan, we need:
AbstractFFTs.plan_fft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_fft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x), region)

# Where I want value() to work on complex duals too:
ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = Complex(ForwardDiff.value(x.re), ForwardDiff.value(x.im))
ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n))
ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual}) = ForwardDiff.npartials(x.re)

# Now fft(x_plus_dx) fails at *(p::FFTW.cFFTWPlan{ComplexF64, -1, false, 1, UnitRange{Int64}}, x::Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}), great! 
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}})
    xtil = p * ForwardDiff.value.(x)            # transform the values
    ndxs = ForwardDiff.npartials(first(x))
    dxtils = ntuple(ndxs) do n
        p * ForwardDiff.partials.(x, n)         # transform each partial with the same plan
    end
    # dxtils = (dx1til, dx2til)
    ndxs == 2 || error("this won't yet work for npartials(x) != 2, sorry")
    dx1til, dx2til = dxtils
    @. Complex(
        ForwardDiff.Dual(real(xtil), tuple(real(dx1til), real(dx2til))),
        ForwardDiff.Dual(imag(xtil), tuple(imag(dx1til), imag(dx2til))),
    )
end

fft(x_plus_dx) # works! 
xtil_plus == fft(x_plus_dx)

This does quite a lot of copying. It might be neater to treat the different components by slicing views out of the array, like dx2til above. It seems that FFTW is happy to handle this, if you are consistent:

p1k = plan_fft(rand(ComplexF64, 1000))
r1 = rand(ComplexF64, 1000);
@time p1k * r1; @time p1k * r1;
r2 = @view rand(ComplexF64, 2000)[1:2:end];
p1k * r2; # ArgumentError: FFTW plan applied to wrong-strides array
p2k = plan_fft(r2)
@time p2k * r2; @time p2k * r2; # This does work without copying, compare:
@time copy(r2); @time copy(r2);

Perhaps you can similarly handle the output by switching things to fft! on (views of) a copy of the data, rather than making separate slices & re-assembling them. But anyway, that’s a start!
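One way the npartials(x) == 2 restriction above might be lifted is to build the output partials with ntuple instead of writing them out by hand. This is only a sketch: `apply_plan_dual` is a hypothetical standalone helper (so it does not overload `*`), it handles real dual vectors only, and it still copies each component rather than using views.

```julia
using FFTW, ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical helper: apply plan `p` to a vector of duals with any number N of partials
function apply_plan_dual(p, x::AbstractVector{<:Dual{T,V,N}}) where {T,V,N}
    xtil = p * complex.(value.(x))                            # FFT of the values
    dxtils = ntuple(n -> p * complex.(partials.(x, n)), N)    # one FFT per partial
    map(eachindex(xtil)) do i
        re_part = Dual{T}(real(xtil[i]), ntuple(n -> real(dxtils[n][i]), N))
        im_part = Dual{T}(imag(xtil[i]), ntuple(n -> imag(dxtils[n][i]), N))
        Complex(re_part, im_part)                             # re-assemble as Complex{Dual}
    end
end

xv = rand(4)
x_plus_dx = [Dual(xv[i], (1.0 * i, 1.0 * i^2, 0.5)) for i in 1:4]  # three partials
p = plan_fft(complex.(xv))
y = apply_plan_dual(p, x_plus_dx)
```

Each partial of the result should equal the FFT of the corresponding input partial, which is just the linearity argument applied N times.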

To make my simplest Hessian example work:

AbstractFFTs.plan_bfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_bfft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x), region)

Zygote.extract(x_plus_dx) # ([0.7786240130665765, 0.3371491022135036], [1.0 2.0; 1.0 4.0])
# @edit Zygote.extract(x_plus_dx)
function Zygote.extract(xs::AbstractArray{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N}
  J = similar(xs, complex(V), N, length(xs))
  for i = 1:length(xs), j = 1:N
    # combine the j-th partial of the real and imaginary parts into one complex entry
    J[j, i] = xs[i].re.partials.values[j] + im * xs[i].im.partials.values[j]
  end
  x0 = ForwardDiff.value.(xs)
  return x0, J
end

Zygote.hessian(x->sum(abs2, fft(x)), rand(2)) # ok

Thank you so much Michael, this is incredibly helpful. I’m working through your example and encountered an error at fft(x_plus_dx):

type ForwardDiff.Dual{Nothing,Float64,2} not supported

    fft(::Array{ForwardDiff.Dual{Nothing,Float64,2},1}, ::UnitRange{Int64})@definitions.jl:198
    top-level scope@Local: 1[inlined]

After changing to

AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(ForwardDiff.value.(x) .+ 0im)

it worked, hopefully this is the correct fix!

Moreover, the last line of code Zygote.hessian threw this error

MethodError: no method matching extract(::Array{Complex{Float64},1})

Closest candidates are:

extract(!Matched::ForwardDiff.Dual) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:12

extract(!Matched::AbstractArray{ForwardDiff.Dual{T,V,N},N1} where N1) where {T, V, N} at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:14

extract(!Matched::AbstractArray{var"#s389",N} where N where var"#s389"<:(Complex{var"#s390"} where var"#s390"<:ForwardDiff.Dual{T,V,N})) where {T, V, N} at C:\Users\bryan\.julia\pluto_notebooks\Remarkable discovery.jl#==#8a61f360-5046-11eb-344f-77b29aed4bdf:2

1. **forward_jacobian** (::Zygote.var"#1216#1217"{var"#23#24"{typeof(abs2),typeof(AbstractFFTs.fft),typeof(sum)}}, ::Array{Float64,1}, ::Val{2})@ *forward.jl:23*
2. **forward_jacobian** (::Function, ::Array{Float64,1})@ *forward.jl:38*
3. **hessian** (::Function, ::Array{Float64,1})@ *utils.jl:113*
4. **top-level scope** @ *[Local: 1](http://localhost:1234/edit?id=264f8330-4ebe-11eb-38f3-23c1e3d1fc03#)* [inlined]

is extract missing a method for complex numbers? I’m not quite sure how to fix this.

Sorry about the complexfloat (copy-paste?) mistake.

extract(::Array{Complex{Float64},1}) doesn’t sound good, I think it should get an array with dual numbers. If these have gone missing along the way then something is wrong!

No worries at all! I’m trying to debug this and realized that right now fft(x_plus_dx) actually gives me plain complex numbers instead of complex duals, whereas xtil_plus is an array of complex duals, yet xtil_plus == fft(x_plus_dx) still returns true. That shouldn’t happen, right? Is that not the case for you? I see that this code defines how to apply a plan to complex duals, but I’m not sure why it’s not returning complex duals for me.

function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}})
    xtil = p * ForwardDiff.value.(x)
    ndxs = ForwardDiff.npartials(first(x))
    dxtils = ntuple(ndxs) do n
        p * ForwardDiff.partials.(x, n)
    end
    # dxtils = (dx1til, dx2til)
    ndxs == 2 || error("this won't yet work for npartials(x) != 2, sorry")
    dx1til, dx2til = dxtils
    @. Complex(
        ForwardDiff.Dual(real(xtil), tuple(real(dx1til), real(dx2til))),
        ForwardDiff.Dual(imag(xtil), tuple(imag(dx1til), imag(dx2til))),
    )
end