ForwardDiff and Zygote cannot automatically differentiate (AD) a function from ℂⁿ to ℝ that uses the FFT

I’ve run into many errors using ForwardDiff and Zygote to automatically differentiate my objective function, which takes a complex vector as input and outputs a real number. I don’t think I have the expertise to debug this myself, so I’m posting my problem here. ForwardDiff just doesn’t work here, and a quick search suggested that it doesn’t support complex numbers as well as Zygote does. For Zygote, the most common error I got says that plan_fft has no method for ForwardDiff’s Dual number type:
MethodError: no method matching plan_fft(::Array{Complex{ForwardDiff.Dual{Nothing,Float64,12}},3}, ::UnitRange{Int64})
I then tried to compute the plans ahead of time, pfft = plan_fft(e_in) and pifft = plan_ifft(e_in), and apply them inside the functions with *, but I got the error
type ScaledPlan has no field region
A Google search turned up no answers.

Moreover, my objective function is not supposed to return complex values, but I got an error saying the output is complex, so I had to wrap my objective function in real().

At this point I’m not sure whether I’m just using these packages incorrectly, or whether this type of objective function, with complex inputs and a fast Fourier transform, is doomed for AD packages like ForwardDiff and Zygote. I’ve attached an abbreviated version of my code below with dummy values. Thank you so much for your help.

using LinearAlgebra,Polynomials,Printf,Random,ForwardDiff,PyPlot,PyCall,OffsetArrays,FFTW,Zygote
np=pyimport("numpy")

function make_prop_kernel( sz, z; dl=2,lmbd=0.5)
    nx = sz[1]
    ny = sz[2]
    k=2*pi/lmbd # wavenumber
    dkx=2*pi/((nx-1)*dl)
    dky=2*pi/((ny-1)*dl)
    kx=(LinRange(0,nx-1,nx).-nx/2)*dkx
    ky=(LinRange(0,ny-1,ny).-ny/2)*dky

    inflate(f, kx, ky) = [f(x,y) for x in kx, y in ky]
    f(kx,ky)=exp(1im*sqrt(k^2-kx^2-ky^2)*z)

    prop_kernel=inflate(f,kx,ky)

    prop_kernel = ifftshift(prop_kernel)
    
    return prop_kernel
end

function light_prop(e_in, prop_kernel)
    if ndims(e_in) == 3
        prop_kernel = reshape(prop_kernel,(1,size(prop_kernel)...))
    end
    ek_in  = fft(ifftshift(e_in))
    ek_out = ek_in.*prop_kernel
    e_out  = fftshift(ifft(ek_out))
    return e_out
end

function phase_mod(e_in, theta; samp_ratio=1)
    #=
    e_in is the input field
    theta is the phase mask
    samp_ratio is the pixel size ratio between the phase mask and the e field 
    =#
    if ndims(e_in) == 2
        if samp_ratio == 1
            e_out = e_in.*exp.(1im*theta)
        else
            e_out = e_in.*kron(exp.(1im*theta),ones((samp_ratio,samp_ratio)))
        end
    elseif ndims(e_in) == 3
        if samp_ratio == 1
            M = exp.(1im*theta)
            e_out = e_in.*reshape(M,(1,size(M)...))
        else
            M = kron(exp.(1im*theta),ones((samp_ratio,samp_ratio)))
            e_out = e_in.*reshape(M,(1,size(M)...))
        end
    end
    return e_out
end

function propagatedOutput(e_in,airKernels,theta)
    e_out = e_in
    for jj = 1:size(theta,1)  # loop over plate
        e_out = airPropagate(e_out, jj)
        e_out = phase_mod(e_out, theta[jj,:,:], samp_ratio=1)
    end
    return e_out
end

function symmetricADHessian(f,x)
    Hx = Zygote.hessian(f,x)
    return LinearAlgebra.tril(Hx,-1)+LinearAlgebra.tril(Hx)'
end

nx = 16
ny = 16
npl = 5
nmod = 10
mat(x) = reshape(x,(npl,nx,ny))
e_in = randn(nmod,nx,ny)
e_target = randn(nmod,nx,ny)

d = [2e4,2.5e4,2.5e4,2.5e4,2.5e4,2e4]
airKernels   = zeros(ComplexF64,(npl+1,nx,ny))
for jj = 1:npl+1
  airKernels[jj,:,:] = make_prop_kernel( (nx,ny), d[jj])
end
airPropagate(e_in, jj)  = light_prop(e_in, airKernels[jj,:,:])

#objective function
f(x) = real(norm(propagatedOutput(e_in,airKernels,mat(x)).*conj(e_target))^2/nmod -1);

x0 = zeros(npl,nx,ny)
gx=Zygote.gradient(f,x0)
Hx=symmetricADHessian(f,x0)

You may have to define a rule for the derivative. See e.g. the excellent docs of


I didn’t look in detail into your code, but at least Zygote supports complex numbers and FFTs.

julia> using FFTW, Zygote, ForwardDiff

julia> f(x) = sum(real(fft(x)))
f (generic function with 1 method)

julia> x0 = randn((10, 10));

julia> Zygote.gradient(f, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)

Also, using a pre-planned FFT, the gradient works:

julia> fft_p = plan_fft(x0)
FFTW forward plan for 10×10 array of Complex{Float64}
(dft-rank>=2/1
  (dft-direct-10-x10 "n2fv_10_avx2_128")
  (dft-direct-10-x10 "n1fv_10_avx2_128"))

julia> g(x) = sum(real(fft_p * x))
g (generic function with 1 method)

julia> Zygote.gradient(g, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)

julia> Zygote.gradient(g, x0) == Zygote.gradient(f, x0)
true

You shouldn’t include plan_fft in code that needs to be differentiated, because taking the derivative of plan_fft is not really meaningful. Instead, define a function that closes over fft_p, or pass the fft_p object in as an argument.

Hope this helps!


I think the issue is that ForwardDiff does not work with FFTW, and Zygote’s hessian function is just hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2], i.e. it calculates the gradient with Zygote and then uses ForwardDiff to differentiate again.

julia> gradient(x->sum(abs2, fftshift(fft(x))), rand(3) .+ im .* randn.())
(ComplexF64[1.495784118230028 - 13.219849172255007im, 4.629297683031974 + 5.3532713394798055im, 2.672875552797116 - 0.3633324783983167im],)

julia> Zygote.hessian(x->sum(abs2, fft(x)), rand(10))
ERROR: type ForwardDiff.Dual{Nothing, Float64, 10} not supported

You might be able to cook something up using, say, this jacobian definition, so that both derivatives are taken using Zygote.

There might also be a more generic FFT somewhere which does accept dual numbers, i.e. not FFTW. Or as @Tamas_Papp suggests it might be possible to add a rule for this somewhere — Zygote already has rules which write the reverse gradient in terms of other ffts. ForwardDiff does not use ChainRules, but I don’t see why you couldn’t write an explicit method AbstractFFTs.fft(::AbstractArray{Dual{... to handle this.
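A closed-form target is handy for checking any of these workarounds. FFTW’s fft is unnormalized, so the n×n DFT matrix F satisfies F^H F = nI (Parseval’s theorem). For real x, the test function used above therefore has a Hessian that is a multiple of the identity:

```latex
f(x) = \sum_{k=1}^{n} \bigl\lvert (Fx)_k \bigr\rvert^2 = x^{\mathsf T} F^{\mathsf H} F\, x = n\,\lVert x \rVert_2^2,
\qquad \nabla f(x) = 2n\,x,
\qquad \nabla^2 f(x) = 2n\,I_n .
```

So, assuming FFTW’s default unnormalized convention, a working Zygote.hessian(x->sum(abs2, fft(x)), rand(10)) should return 20 times the 10×10 identity, up to rounding.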


Thank you. I actually tried computing the plan outside the function, as in the code below, but I got the error
type ScaledPlan has no field region, and I have no idea what it means. I simplified the code further, so hopefully it’s easier to read. I’m not sure why it worked for your code but not mine.

With pre-plan

function light_prop(e_in, prop_kernel,pfft,pifft)
    if ndims(e_in) == 3
        prop_kernel = reshape(prop_kernel,(1,size(prop_kernel)...))
    end
    ek_in  = pfft*ifftshift(e_in)
    ek_out = ek_in.*prop_kernel
    e_out  = fftshift(pifft*ek_out)
    return e_out
end

nx = 16
ny = 16
npl = 5
nmod = 10
mat(x) = reshape(x,(npl,nx,ny))
e_in = randn(nmod,nx,ny)
e_target = randn(nmod,nx,ny)
pfft = plan_fft(e_in) 
pifft = plan_ifft(e_in)
airKernels   = randn(npl+1,nx,ny)+im.*randn(npl+1,nx,ny)
airPropagate(e_in, jj,pfft,pifft)  = light_prop(e_in, airKernels[jj,:,:],pfft,pifft)

f(x) = real(norm(propagatedOutput(e_in,airKernels,mat(x),pfft,pifft).*conj(e_target))^2/nmod -1);
gx=Zygote.gradient(f,x0)

I tried to execute your code, but it is not working (missing methods, and copying above failed for some reason…).

You should really try to narrow the problem down (1D arrays maybe, no operations like airpropagate etc.).

Obviously, it seems to be a problem with Zygote and FFTW, so try to find a minimal example.
Also post the full stack trace please.

We can try then to fix the issue using custom adjoints in Zygote.


Sorry, I tested on an old kernel, so my broken code passed; as you can see, I’m a bit new to this. I got the same error with this simplified code:

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)

f(x) = real(norm(pifft*fftshift(pfft*x)));  # differentiate w.r.t. the argument x (not the global x0)
gx=Zygote.gradient(f,x0)
type ScaledPlan has no field region

Stacktrace:
 [1] getproperty(::AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}, ::Symbol) at .\Base.jl:33
 [2] (::Zygote.var"#911#912"{AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64},Array{Complex{Float64},1}})(::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\array.jl:826
 [3] (::Zygote.var"#3417#back#913"{Zygote.var"#911#912"{AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64},Array{Complex{Float64},1}}})(::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] f at .\In[3]:11 [inlined]
 [5] (::typeof(∂(f)))(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(f))})(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface.jl:40
 [7] gradient(::Function, ::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface.jl:49
 [8] top-level scope at In[3]:12
 [9] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091

Zygote is able to handle pfft but not pifft. Let’s look at the types

julia> typeof(pfft)
FFTW.cFFTWPlan{Complex{Float64},-1,false,1,UnitRange{Int64}}

julia> typeof(pifft)
AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}

This is where ScaledPlan comes from: plan_ifft returns the unnormalized backward plan wrapped together with the scale factor 0.01 = 1/n (here n = 100):

julia> pfft
FFTW forward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3fv_10_avx2_128")
  (dft-direct-10-x10 "n2fv_10_avx2_128"))

julia> pifft
0.01 * FFTW backward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
  (dftw-direct-10/6 "t3bv_10_avx2_128")
  (dft-direct-10-x10 "n2bv_10_avx2_128"))

I believe the error really comes from this line. So Zygote can handle the pfft object but not pifft. Ideally, a suitable fix would be added to Zygote itself.
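To see why plan_ifft produces a ScaledPlan, here is a self-contained sketch of FFTW’s normalization conventions. It uses a hypothetical naive_dft helper (an O(n²) DFT, not part of FFTW) in place of the real plans: the backward transform alone returns n times the input, and ifft is that backward transform multiplied by 1/n, which is the 0.01 shown in the pifft output for n = 100.

```julia
# Hypothetical O(n^2) DFT, only to illustrate FFTW's normalization conventions.
# direction = -1 mimics fft (forward); direction = +1 mimics bfft (unnormalized backward).
function naive_dft(x::AbstractVector; direction::Int = -1)
    n = length(x)
    return [sum(x[j+1] * cis(direction * 2π * j * k / n) for j in 0:n-1) for k in 0:n-1]
end

naive_bfft(x) = naive_dft(x; direction = +1)
naive_ifft(x) = naive_bfft(x) ./ length(x)   # the 1/n scale that ScaledPlan stores

n = 100
x = randn(ComplexF64, n)
xhat = naive_dft(x)

roundtrip_bfft = naive_bfft(xhat)   # backward transform alone: off by a factor n
roundtrip_ifft = naive_ifft(xhat)   # scaled by 1/n = 0.01: recovers x

println(roundtrip_bfft ≈ n .* x)    # true
println(roundtrip_ifft ≈ x)         # true
```

In FFTW terms, pifft.p plays the role of naive_bfft and pifft.scale is the 1/n factor.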

I’m not familiar enough with FFTW and Zygote to provide a proper fix, but we can bypass the ScaledPlan issue with a custom adjoint.

using LinearAlgebra, Random, FFTW
using Zygote
Random.seed!(42)

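# Wrapper around the ScaledPlan multiplication, so that Zygote uses the custom adjoint below
pifft_f(x, pifft, pfft) = pifft * x
# ifft is (1/n) * conj(DFT); its adjoint is (1/n) * DFT, i.e. a forward FFT scaled by 1/n.
# The pullback returns one gradient per argument: x, then nothing for pifft and pfft.
Zygote.@adjoint pifft_f(x, pifft, pfft) =
    pifft_f(x, pifft, pfft), c̄ -> (1 ./ length(c̄) .* (pfft * c̄), nothing, nothing)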

function main()
    n = 10000 
    x0 = im.*randn(n)
        
    pfft = plan_fft(x0) 
    pifft = plan_ifft(fft(x0))
        

    f(x) = real(norm(pifft_f(fftshift(pfft*x), pifft, pfft)));
    f2(x) = real(norm(ifft(fftshift(fft(x)))));
        
    @show f(x0)
    @show f2(x0)
        
    gx =Zygote.gradient(f,x0)
    gx2 =Zygote.gradient(f2,x0)
        
    @time gx =Zygote.gradient(f,x0)
    @time gx2 =Zygote.gradient(f2,x0)
        
    @show gx[1] ≈ gx2[1]

    return 
end

main()
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
  0.000790 seconds (43 allocations: 1.375 MiB)
  0.000877 seconds (175 allocations: 1.232 MiB)
gx[1] ≈ gx2[1] = true

The catch is that Zygote.@adjoint has to be used at global scope. Since pfft and pifft are not globals here, we need to pass both of them as arguments to pifft_f.
I know that’s annoying, but it works, and it is still faster than the unplanned version. Maybe someone else knows a more elegant solution.
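The custom adjoint above can be checked with dense matrices. This sketch uses only LinearAlgebra; the dft_matrix helper is hypothetical, not an FFTW function. It builds the matrix of ifft, confirms that its conjugate transpose equals F/n (which is what the pullback computes as pfft * c̄ scaled by 1/length(c̄)), and checks the defining adjoint identity:

```julia
using LinearAlgebra

# Dense DFT matrix: F[k, j] = exp(-2πi (j-1)(k-1) / n), the unnormalized map that fft applies.
dft_matrix(n) = [cis(-2π * (j - 1) * (k - 1) / n) for k in 1:n, j in 1:n]

n = 8
F = dft_matrix(n)
A = conj(F) ./ n   # ifft as a matrix: the backward (conjugated) DFT scaled by 1/n

x = randn(ComplexF64, n)
println(A * (F * x) ≈ x)   # true: A really is the inverse of F

# For a linear map y = A*x, the pullback is ybar -> A' * ybar (conjugate transpose).
# F is symmetric, so A' equals F / n: a forward FFT of ybar, scaled by 1/n.
println(A' ≈ F ./ n)       # true

# Defining property of the adjoint: <ybar, A*x> == <A'*ybar, x>
ybar = randn(ComplexF64, n)
println(dot(ybar, A * x) ≈ dot(A' * ybar, x))   # true
```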


I think Zygote’s definition here (Zygote.jl/array.jl at master, FluxML/Zygote.jl on GitHub) assumes that all subtypes of Plan have a region field, and that’s not true, although they do all contain this information:

julia> pfft |> typeof |> supertype |> supertype
AbstractFFTs.Plan{ComplexF64}

julia> pifft |> typeof |> supertype
AbstractFFTs.Plan{ComplexF64}

julia> pfft.region
1:1

julia> pifft.p.region
1:1

I can’t see one, but ideally there would be some function like region(::Plan) which extracts this information.

The quick and dirty hack is to define Base.getproperty(p::AbstractFFTs.ScaledPlan, s::Symbol) = s === :region ? getfield(p, :p).region : getfield(p, s) and see if that works.


OK, we can also avoid this by extracting the scale factor and the unscaled FFT plan ourselves:

f3(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
f3(x0) = 100.96395788714146
  0.000428 seconds (43 allocations: 1.375 MiB)
  0.000589 seconds (175 allocations: 1.232 MiB)
  0.000525 seconds (84 allocations: 1.835 MiB)
gx[1] ≈ gx3[1] = true

Ah! That did the trick. Thank you so much. The only problem left is that fft doesn’t know how to handle ForwardDiff’s Dual type. As @mcabbott suggested, I would need either to write my own Jacobian function that calls Zygote, or to write my own fft method for the Dual type. Both options are perhaps beyond my reach right now, but I appreciate all of your help in uncovering the root of the problem.


Just add an overload of fft so that you don’t push the duals through the FFT algorithm, but instead apply the (linear) transform to the value and to the partials directly. Differentiating fft is too trivial to actually need to differentiate through the algorithm.
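To make the "transform the result directly" idea concrete without any packages, here is a sketch with a hypothetical ToyDual type (not ForwardDiff.Dual) and a hypothetical O(n²) toy_dft standing in for fft. Because the transform is linear, it is applied once to the values and once per partial, never to the dual numbers themselves; seeding each input with a unit tangent then recovers the DFT matrix as the Jacobian.

```julia
# Hypothetical forward-mode dual number with N partials, for illustration only.
# ForwardDiff stores complex data as Complex{Dual}; a single ComplexF64 value with
# ComplexF64 partials is used here to keep the sketch short.
struct ToyDual{N}
    value::ComplexF64
    partials::NTuple{N,ComplexF64}
end

# Hypothetical O(n^2) DFT standing in for fft (any linear map works the same way).
toy_dft(x::AbstractVector) =
    [sum(x[j] * cis(-2π * (j - 1) * (k - 1) / length(x)) for j in eachindex(x)) for k in eachindex(x)]

# L linear => L(a + Σ ε_i b_i) = L(a) + Σ ε_i L(b_i): apply L once to the values and
# once per partial, instead of pushing the duals through L's internals.
function apply_linear(L, xs::Vector{ToyDual{N}}) where {N}
    vals = L([x.value for x in xs])
    parts = ntuple(i -> L([x.partials[i] for x in xs]), N)
    return [ToyDual{N}(vals[k], ntuple(i -> parts[i][k], N)) for k in eachindex(vals)]
end

# Seed each input with a unit tangent so that the partials recover the full Jacobian.
n = 4
x = randn(ComplexF64, n)
seed(j) = ntuple(i -> i == j ? one(ComplexF64) : zero(ComplexF64), n)
xs = [ToyDual{n}(x[j], seed(j)) for j in 1:n]
ys = apply_linear(toy_dft, xs)

J = [ys[k].partials[i] for k in 1:n, i in 1:n]          # Jacobian read off the partials
F = [cis(-2π * (j - 1) * (k - 1) / n) for k in 1:n, j in 1:n]
println(J ≈ F)                                           # true: the Jacobian is the DFT matrix
println([y.value for y in ys] ≈ toy_dft(x))              # true: values are the plain transform
```

A real ForwardDiff overload does the same thing with p * ForwardDiff.value.(x) and p * ForwardDiff.partials.(x, n), which is the pattern of the Base.:* method later in this thread.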

From what I understand, the derivative (Jacobian) of the FFT should just be the DFT matrix, which is constant. So I think I understand what you mean by differentiating fft being too trivial. However, I have trouble wrapping my head around how to feed this derivative to the Zygote.hessian function in practice. Would you mind showing me an example of your solution using the function below? I’m not sure I understand what you mean by “overload”. Thank you!

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));

Exactly, so you need to use dispatch to directly put that solution into the dual portion. For example, this is what it looks like on a quadrature:

Instead of differentiating through a quadrature, you can define a quadrature on the differentiated function and then stick the result in the partials:

For more information on how ForwardDiff and dual numbers work, check out these lecture notes:

https://mitmath.github.io/18337/lecture8/automatic_differentiation.html

The discussions of the FFTW adjoints might also be helpful:

Thank you, I’m going to give this a try and dive deeper into the technical details.

I think I understand how dual numbers work. The first component keeps track of the function value, and the second component has an algebra defined on it that behaves exactly like dx under the chain rule, so it keeps track of the derivative. However, I’m still having trouble seeing how to implement this in practice. I think the example above might be a bit too advanced for me to understand exactly what is going on, as I’m still a beginner in Julia. I would really appreciate it if you could show me a toy example, perhaps using this function:

using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 10
x0 = randn(n)

pfft = plan_fft(x0) 
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));

Here we need the fft plan pfft, fftshift, and pifft.p to accept the dual number x (I precompute the plans because I don’t want to differentiate that part). Right now I don’t think they have methods for dual numbers, so I either have to write those methods myself, or compute the derivatives of these operations myself and feed them into the dual part. I don’t really know how to do either. Something I’m attempting right now, starting with fft:

AbstractFFTs.fft(f::AbstractArray{Dual{TODO}})=ForwardDiff.Dual{fft(f.val), DFT* f.der}

where DFT stands for the DFT matrix. First, I don’t know which types to use, as I always have trouble with Julia types. I want the types to be general, so I put “any”, but Julia doesn’t like that. Second, if the derivative is DFT * f.der, I assume it will be very slow due to the dense matrix multiplication, which most likely won’t scale well for my problem. Third, I have no idea what the derivative of fftshift is.

As for feeding the DFT matrix into the partials myself, I don’t really know how to tell Zygote.hessian/ForwardDiff to skip computing the derivative of fft and use the one I provide instead. I feel a bit overwhelmed by the intricacies here, and any help is greatly appreciated!
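On the fftshift question: fftshift only reorders entries, so it is a linear map whose Jacobian is a permutation matrix, and its derivative is applied by shifting the partials exactly like the values. The sketch below uses a hypothetical toy_fftshift (a 1-D circshift by n ÷ 2, intended to mirror AbstractFFTs.fftshift for vectors) so that it needs only the LinearAlgebra standard library. The same reasoning addresses the performance worry: for fft the tangent is just fft(dx), computed in O(n log n), so the dense DFT matrix never has to be formed.

```julia
using LinearAlgebra

# Hypothetical 1-D stand-in for fftshift: a circular shift by n ÷ 2.
toy_fftshift(x::AbstractVector) = circshift(x, length(x) ÷ 2)

n = 5
Id = Matrix{Float64}(I, n, n)

# Column i of the Jacobian is the shift applied to the i-th unit vector.
P = reduce(hcat, [toy_fftshift(Id[:, i]) for i in 1:n])

# P is a permutation matrix: each row and each column contains exactly one 1.
is_permutation = all(sum(P; dims = 1) .== 1) && all(sum(P; dims = 2) .== 1)
println(is_permutation)   # true

# Hence the derivative of fftshift is fftshift itself: shift the value and the
# tangent dx by the same amount; no Jacobian matrix has to be formed.
x, dx, t = randn(n), randn(n), 0.3
println(toy_fftshift(x .+ t .* dx) ≈ toy_fftshift(x) .+ t .* toy_fftshift(dx))  # true
println(P * dx ≈ toy_fftshift(dx))                                               # true
```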

I’m quite busy so I set a reminder for 1 week from now. But others should chime in if you haven’t figured it out. You’re close and on the right track.


I was fiddling with this a bit, so I’ll paste it here before I accidentally close the window again.

FFT is linear, so what we want is f * (x + dx) = f*x + f*dx where f is what plan_fft gives you, and might be worth re-using. Then the rough idea is something like this:

julia> using FFTW, ForwardDiff

julia> x = rand(2)
2-element Vector{Float64}:
 0.7786240130665765
 0.3371491022135036

julia> f = plan_fft(x)
FFTW forward plan for 2-element array of ComplexF64
(dft-direct-2 "n2fv_2_avx2_128")

julia> xtil = f * x
2-element Vector{ComplexF64}:
   1.11577311528008 + 0.0im
 0.4414749108530729 + 0.0im

julia> x_plus_dx = [ForwardDiff.Dual(x[i], (i,i^2)) for i in 1:2]  # junk data + some duals
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 2}}:
 Dual{Nothing}(0.7786240130665765,1.0,1.0)
 Dual{Nothing}(0.3371491022135036,2.0,4.0)

julia> x == ForwardDiff.value.(x_plus_dx)
true

julia> dx1 = ForwardDiff.partials.(x_plus_dx, 1) # extract the dual part
2-element Vector{Float64}:
 1.0
 2.0

julia> dx1til = f * dx1  # apply the same FFT
2-element Vector{ComplexF64}:
  3.0 + 0.0im
 -1.0 + 0.0im

julia> dx2til = f * @view reinterpret(Float64, x_plus_dx)[3:3:end] # another way?
2-element Vector{ComplexF64}:
  5.0 + 0.0im
 -3.0 + 0.0im

julia> xtil_plus = [Complex(   # re-assemble
        ForwardDiff.Dual(real(xtil[i]), (real(dx1til[i]), real(dx2til[i]))),
        ForwardDiff.Dual(imag(xtil[i]), (imag(dx1til[i]), imag(dx2til[i])))
        ) for i in 1:2]
2-element Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}:
     Dual{Nothing}(1.11577311528008,3.0,5.0) + Dual{Nothing}(0.0,0.0,0.0)*im
 Dual{Nothing}(0.4414749108530729,-1.0,-3.0) + Dual{Nothing}(0.0,0.0,0.0)*im

So there’s going to be a lot of messing with arrays-of-structs. Really x is going to be complex like xtil_plus here, so that’s one more layer. I think this is the right way around, Complex{Dual{...}}.
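The bookkeeping in the REPL session above follows from linearity. Writing a dual input as a value a plus infinitesimal tangents b_1, …, b_N (with ε_j ε_l = 0), the plan F acts on each part separately. For complex data stored as Complex{Dual}, the real and imaginary parts of F b_j become the j-th partials of the real and imaginary Duals, respectively:

```latex
F\Bigl(a + \sum_{j=1}^{N} \varepsilon_j b_j\Bigr) = F a + \sum_{j=1}^{N} \varepsilon_j\, F b_j,
\qquad
\operatorname{Re}(F x) = \operatorname{Re}(F a) + \sum_{j=1}^{N} \varepsilon_j \operatorname{Re}(F b_j),
\qquad
\operatorname{Im}(F x) = \operatorname{Im}(F a) + \sum_{j=1}^{N} \varepsilon_j \operatorname{Im}(F b_j).
```

With b_1 = (1, 2) and b_2 = (1, 4) from x_plus_dx, F b_1 = (3, -1) and F b_2 = (5, -3), which are exactly the partials (3.0, 5.0) and (-1.0, -3.0) in xtil_plus.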

When you call fft(x_plus_dx), first it converts the matrix to be complex, then it makes a plan, then it applies it. I think these are the steps to make that work here:

# @edit fft(x_plus_dx, 1:1) points me here:
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im)

# @edit fft(x_plus_dx .+ 0im, 1:1) # now this makes a plan, we need:
AbstractFFTs.plan_fft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_fft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x), region)

# Where I want value() to work on complex duals too:
ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = Complex(x.re.value, x.im.value)
ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n))
ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual}) = ForwardDiff.npartials(x.re)

# Now fft(x_plus_dx) fails at *(p::FFTW.cFFTWPlan{ComplexF64, -1, false, 1, UnitRange{Int64}}, x::Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}), great! 
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N}
    xtil = p * ForwardDiff.value.(x)                          # transform the values
    dxtils = ntuple(n -> p * ForwardDiff.partials.(x, n), N)  # transform each partial separately
    # Re-assemble Complex{Dual} elementwise. This works for any number of partials N
    # (Zygote.hessian uses one partial per input, so N is usually not 2).
    return map(CartesianIndices(xtil)) do i
        Complex(
            ForwardDiff.Dual{T}(real(xtil[i]), ntuple(n -> real(dxtils[n][i]), N)),
            ForwardDiff.Dual{T}(imag(xtil[i]), ntuple(n -> imag(dxtils[n][i]), N)),
        )
    end
end

fft(x_plus_dx) # works! 
xtil_plus == fft(x_plus_dx)

This does quite a lot of copying. It might be neater to treat the different components by slicing views out of the array, like dx2til above. It seems that FFTW is happy to handle this, if you are consistent:

p1k = plan_fft(rand(ComplexF64, 1000))
r1 = rand(ComplexF64, 1000);
@time p1k * r1; @time p1k * r1;
r2 = @view rand(ComplexF64, 2000)[1:2:end];
p1k * r2; # ArgumentError: FFTW plan applied to wrong-strides array
p2k = plan_fft(r2)
@time p2k * r2; @time p2k * r2; # This does work without copying, compare:
@time copy(r2); @time copy(r2);

Perhaps you can similarly handle the output by switching things to fft! on (views of) a copy of the data, rather than making separate slices and re-assembling them. But anyway, that’s a start!
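The reinterpret(Float64, x_plus_dx)[3:3:end] trick relies on each Dual{Nothing,Float64,2} being stored as three consecutive Float64s (value, partial 1, partial 2). A package-free sketch with a hypothetical ToyDual2 struct of the same field layout shows the resulting stride-3 view into the original memory, which is the kind of component view suggested above instead of separate copies:

```julia
# Hypothetical isbits struct with the same field layout as a real-valued
# ForwardDiff.Dual{Nothing,Float64,2}: value, partial 1, partial 2.
struct ToyDual2
    value::Float64
    p1::Float64
    p2::Float64
end

xs = [ToyDual2(0.78, 1.0, 1.0), ToyDual2(0.34, 2.0, 4.0)]

# View the same memory as a flat Float64 vector: [v1, p1_1, p2_1, v2, p1_2, p2_2].
flat = reinterpret(Float64, xs)

# Every third entry, starting at index 3, is the second partial: a strided view, no copy.
p2_view = @view flat[3:3:end]

println(collect(p2_view))     # [1.0, 4.0]
println(parent(flat) === xs)  # true: reinterpret wraps xs rather than copying it
```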

To make my simplest Hessian example work:

AbstractFFTs.plan_bfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_bfft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x), region)

Zygote.extract(x_plus_dx) # ([0.7786240130665765, 0.3371491022135036], [1.0 2.0; 1.0 4.0])
# @edit Zygote.extract(x_plus_dx)
function Zygote.extract(xs::AbstractArray{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N}
  J = similar(xs, complex(V), N, length(xs))
  for i = 1:length(xs), j = 1:N
    J[j, i] = xs[i].re.partials.values[j] + im * xs[i].im.partials.values[j]
  end
  x0 = ForwardDiff.value.(xs)
  return x0, J
end

Zygote.hessian(x->sum(abs2, fft(x)), rand(2)) # ok

Thank you so much, Michael; this is incredibly helpful. I’m working through your example and encountered an error at fft(x_plus_dx):

type ForwardDiff.Dual{Nothing,Float64,2} not supported

    error(::String)@error.jl:33
    _fftfloat(::Type{ForwardDiff.Dual{Nothing,Float64,2}})@definitions.jl:22
    _fftfloat(::ForwardDiff.Dual{Nothing,Float64,2})@definitions.jl:23
    fftfloat(::ForwardDiff.Dual{Nothing,Float64,2})@definitions.jl:18
    complexfloat(::Array{ForwardDiff.Dual{Nothing,Float64,2},1})@definitions.jl:31
    fft(::Array{ForwardDiff.Dual{Nothing,Float64,2},1}, ::UnitRange{Int64})@definitions.jl:198
    top-level scope@Local: 1[inlined]

After changing to

AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(ForwardDiff.value.(x) .+ 0im)

it worked, hopefully this is the correct fix!

Moreover, the last line of code, the Zygote.hessian call, threw this error:

MethodError: no method matching extract(::Array{Complex{Float64},1})

Closest candidates are:

extract(!Matched::ForwardDiff.Dual) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:12

extract(!Matched::AbstractArray{ForwardDiff.Dual{T,V,N},N1} where N1) where {T, V, N} at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:14

extract(!Matched::AbstractArray{var"#s389",N} where N where var"#s389"<:(Complex{var"#s390"} where var"#s390"<:ForwardDiff.Dual{T,V,N})) where {T, V, N} at C:\Users\bryan\.julia\pluto_notebooks\Remarkable discovery.jl#==#8a61f360-5046-11eb-344f-77b29aed4bdf:2

    forward_jacobian(::Zygote.var"#1216#1217"{var"#23#24"{typeof(abs2),typeof(AbstractFFTs.fft),typeof(sum)}}, ::Array{Float64,1}, ::Val{2})@forward.jl:23
    forward_jacobian(::Function, ::Array{Float64,1})@forward.jl:38
    hessian(::Function, ::Array{Float64,1})@utils.jl:113
    top-level scope@Local: 1[inlined]

Is extract missing a method for complex numbers? I’m not quite sure how to fix this.