# ForwardDiff and Zygote cannot automatically differentiate (AD) a function from C^n to R that uses FFT

I didn’t look in detail into your code, but at least Zygote supports complex numbers and FFTs.

``````
julia> using FFTW, Zygote, ForwardDiff

julia> f(x) = sum(real(fft(x)))
f (generic function with 1 method)

julia> x0 = randn((10, 10));

julia> Zygote.gradient(f, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)
``````

Also, using a pre-planned FFT, the gradient works:

``````
julia> fft_p = plan_fft(x0)
FFTW forward plan for 10×10 array of Complex{Float64}
(dft-rank>=2/1
(dft-direct-10-x10 "n2fv_10_avx2_128")
(dft-direct-10-x10 "n1fv_10_avx2_128"))

julia> g(x) = sum(real(fft_p * x))
g (generic function with 1 method)

julia> Zygote.gradient(g, x0)
(Complex{Float64}[100.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; … ; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im; 0.0 + 0.0im 0.0 + 0.0im … 0.0 + 0.0im 0.0 + 0.0im],)

true

``````

You shouldn’t include `plan_fft` in code that needs to be differentiated, because taking the derivative of `plan_fft` is not really meaningful. Instead, define a function that uses an existing `fft_p`, or pass the `fft_p` object in as an argument.

Hope this helps!


I think the issue is that ForwardDiff does not work with FFTW, and Zygote’s hessian function is just `hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2]`, i.e. it calculates the gradient with Zygote and then uses ForwardDiff to differentiate again.

``````
julia> gradient(x->sum(abs2, fftshift(fft(x))), rand(3) .+ im .* randn.())
(ComplexF64[1.495784118230028 - 13.219849172255007im, 4.629297683031974 + 5.3532713394798055im, 2.672875552797116 - 0.3633324783983167im],)

julia> Zygote.hessian(x->sum(abs2, fft(x)), rand(10))
ERROR: type ForwardDiff.Dual{Nothing, Float64, 10} not supported
``````

You might be able to cook something up using say this jacobian definition, so that both derivatives are taken using Zygote.

There might also be a more generic FFT somewhere which does accept dual numbers, i.e. not FFTW. Or as @Tamas_Papp suggests it might be possible to add a rule for this somewhere — Zygote already has rules which write the reverse gradient in terms of other `fft`s. ForwardDiff does not use ChainRules, but I don’t see why you couldn’t write an explicit method `AbstractFFTs.fft(::AbstractArray{Dual{...` to handle this.


Thank you. I actually tried computing the plan outside the function, as in the code below, but I got the error
`type ScaledPlan has no field region`, and I have no idea what it means. I simplified the code further, so perhaps it’s easier to read. I’m not sure why it worked for your code but not mine.

With pre-plan

``````
function light_prop(e_in, prop_kernel, pfft, pifft)
    if ndims(e_in) == 3
        prop_kernel = reshape(prop_kernel, (1, size(prop_kernel)...))
    end
    ek_in  = pfft * ifftshift(e_in)
    ek_out = ek_in .* prop_kernel
    e_out  = fftshift(pifft * ek_out)
    return e_out
end

nx = 16
ny = 16
npl = 5
nmod = 10
mat(x) = reshape(x,(npl,nx,ny))
e_in = randn(nmod,nx,ny)
e_target = randn(nmod,nx,ny)
pfft = plan_fft(e_in)
pifft = plan_ifft(e_in)
airKernels   = randn(npl+1,nx,ny)+im.*randn(npl+1,nx,ny)
airPropagate(e_in, jj,pfft,pifft)  = light_prop(e_in, airKernels[jj,:,:],pfft,pifft)

f(x) = real(norm(propagatedOutput(e_in,airKernels,mat(x),pfft,pifft).*conj(e_target))^2/nmod -1);
``````

I tried to execute your code, but it is not working (missing methods, and copying above failed for some reason…).

You should really try to narrow the problem down (1D arrays maybe, no operations like airpropagate etc.).

Obviously, it seems to be a problem with Zygote and FFTW, so try to find a minimal example.
Also post the full stack trace please.

We can try then to fix the issue using custom adjoints in Zygote.


Sorry I tested on an old kernel so my broken code passed, you can see I’m a bit new to this. I got the same error using the simplified code:

``````
using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0)
pifft = plan_ifft(x0)

f(x) = real(norm(pifft*fftshift(pfft*x)));
``````
``````
type ScaledPlan has no field region

Stacktrace:
[1] getproperty(::AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}, ::Symbol) at .\Base.jl:33
[2] (::Zygote.var"#911#912"{AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64},Array{Complex{Float64},1}})(::Array{Complex{Float64},1}) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\array.jl:826
[4] f at .\In[3]:11 [inlined]
[5] (::typeof(∂(f)))(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface2.jl:0
[6] (::Zygote.var"#41#42"{typeof(∂(f))})(::Float64) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\compiler\interface.jl:40
[8] top-level scope at In[3]:12
``````

Zygote is able to handle `pfft` but not `pifft`. Let’s look at the types

``````
julia> typeof(pfft)
FFTW.cFFTWPlan{Complex{Float64},-1,false,1,UnitRange{Int64}}

julia> typeof(pifft)
AbstractFFTs.ScaledPlan{Complex{Float64},FFTW.cFFTWPlan{Complex{Float64},1,false,1,UnitRange{Int64}},Float64}
``````

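This is where `ScaledPlan` comes from: the inverse plan is an unnormalized backward plan wrapped together with a scale factor, here `0.01` (that is, `1/n` with `n = 100`):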

``````
julia> pfft
FFTW forward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
(dftw-direct-10/6 "t3fv_10_avx2_128")
(dft-direct-10-x10 "n2fv_10_avx2_128"))

julia> pifft
0.01 * FFTW backward plan for 100-element array of Complex{Float64}
(dft-ct-dit/10
(dftw-direct-10/6 "t3bv_10_avx2_128")
(dft-direct-10-x10 "n2bv_10_avx2_128"))
``````
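To make the role of the scale factor concrete, here is a small sketch (assuming FFTW.jl is installed; the variable names are only for illustration). It checks that the plan returned by `plan_ifft` behaves like the unnormalized backward plan from `plan_bfft` multiplied by `1/n`, using the same `p` and `scale` fields that appear elsewhere in this thread.

```julia
using FFTW

n = 100
x = randn(ComplexF64, n)

pifft = plan_ifft(x)   # a ScaledPlan: backward plan plus normalization
pbfft = plan_bfft(x)   # the unnormalized backward transform

# The stored scale factor is 1/n ...
scale_ok = pifft.scale ≈ 1 / n
# ... applying the scaled plan equals scaling the backward transform ...
apply_ok = pifft * x ≈ pifft.scale .* (pbfft * x)
# ... and the wrapped plan is reachable through the `p` field.
inner_ok = pifft.p * x ≈ pbfft * x

@show scale_ok apply_ok inner_ok
```

So any code that expects a bare FFTW plan (for example, one that reads a `region` field) has to look inside the wrapper first.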

I believe the error really comes from this line. So Zygote can handle the `pfft` object but not the `pifft`. Ideally, a suitable fix would be added to Zygote itself.

I’m not familiar enough with FFTW and Zygote to provide a good solution for that, but we can bypass the `ScaledPlan` issue with a custom adjoint.

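``````
using LinearAlgebra, Random, FFTW
using Zygote
Random.seed!(42)

# Wrap the planned inverse FFT so that a custom adjoint can be attached to it.
pifft_f(x, pifft, pfft) = pifft * x
# The adjoint of ifft is fft scaled by 1/n; the two plans get no gradient.
Zygote.@adjoint pifft_f(x, pifft, pfft) =
    pifft_f(x, pifft, pfft), c̄ -> (1 ./ length(c̄) .* (pfft * c̄), nothing, nothing)

function main()
    n = 10000
    x0 = im .* randn(n)

    pfft = plan_fft(x0)
    pifft = plan_ifft(fft(x0))

    f(x) = real(norm(pifft_f(fftshift(pfft * x), pifft, pfft)))
    f2(x) = real(norm(ifft(fftshift(fft(x)))))

    @show f(x0)
    @show f2(x0)

    # Time both gradients (these two lines are inferred from the printed timings).
    gx = @time Zygote.gradient(f, x0)
    gx2 = @time Zygote.gradient(f2, x0)
    @show gx[1] ≈ gx2[1]

    return
end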

main()
``````
``````
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
0.000790 seconds (43 allocations: 1.375 MiB)
0.000877 seconds (175 allocations: 1.232 MiB)
gx[1] ≈ gx2[1] = true
``````

The catch is that `Zygote.@adjoint` has to be defined at global scope. Since we don’t have global `pfft` and `pifft` objects here, we need to pass both plans as arguments to `pifft_f`.
I know that’s annoying, but it works and it is still faster than the non-planned version. Maybe someone else knows a more elegant solution.
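As a sanity check on the pullback in that custom adjoint (a sketch, not part of the original code): for a linear map `A`, the pullback is the adjoint `Aᴴ`, which satisfies `dot(c, A*x) == dot(Aᴴ*c, x)`. For `A = ifft` the claim is `Aᴴ*c = fft(c) / n`, and this can be verified numerically. The names `n`, `x`, and `c` below are arbitrary.

```julia
using FFTW, LinearAlgebra

n = 16
x = randn(ComplexF64, n)
c = randn(ComplexF64, n)   # plays the role of the incoming cotangent c̄

lhs = dot(c, ifft(x))        # ⟨c, A x⟩ with A = ifft
rhs = dot(fft(c) ./ n, x)    # ⟨Aᴴ c, x⟩ with Aᴴ c = fft(c) / n

adjoint_ok = lhs ≈ rhs
@show adjoint_ok
```

This is only a numerical spot check on random data, but it matches the `1 ./ length(c̄) .* (pfft * c̄)` expression used for the gradient with respect to `x`.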


I think Zygote’s definition here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L820 assumes that all subtypes of `Plan` have this field, and that’s not true. Although they do all contain this information:

``````
julia> pfft |> typeof |> supertype |> supertype
AbstractFFTs.Plan{ComplexF64}

julia> pifft |> typeof |> supertype
AbstractFFTs.Plan{ComplexF64}

julia> pfft.region
1:1

julia> pifft.p.region
1:1
``````

I can’t see one, but ideally there would be some function like `region(::Plan)` which extracts this information.

The quick and dirty hack is to define `Base.getproperty(p::AbstractFFTs.ScaledPlan, s::Symbol) = s === :region ? getfield(p, :p).region : getfield(p, s)` and see if that works.
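Written out as a small script, the hack might look like the sketch below. Note that this is type piracy (it changes property access on a type owned by AbstractFFTs), so treat it as a temporary experiment only. It assumes the wrapped FFTW plan has a `region` field, as the REPL output above shows for the versions used in this thread.

```julia
using FFTW

# Forward `.region` to the wrapped plan; every other property is untouched.
function Base.getproperty(p::FFTW.AbstractFFTs.ScaledPlan, s::Symbol)
    return s === :region ? getfield(p, :p).region : getfield(p, s)
end

pifft = plan_ifft(randn(ComplexF64, 100))

@show pifft.region   # now forwarded to the inner plan
@show pifft.scale    # ordinary fields still work
```

With this in place, code that reads `p.region` on any plan should no longer hit the `type ScaledPlan has no field region` error, though whether the rest of the adjoint is then correct for a scaled plan would still need checking.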


Ok, we can avoid this by extracting the scaling and the FFT operator ourselves:

``````
f3(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
``````
``````
julia> include("/tmp/discourse.jl")
f(x0) = 100.96395788714146
f2(x0) = 100.96395788714146
f3(x0) = 100.96395788714146
0.000428 seconds (43 allocations: 1.375 MiB)
0.000589 seconds (175 allocations: 1.232 MiB)
0.000525 seconds (84 allocations: 1.835 MiB)
gx[1] ≈ gx3[1] = true
``````

Ah! That did the trick. Thank you so much. The only problem left is that `fft` does not know how to handle ForwardDiff’s dual type. As @mcabbott suggested, I would either need to write my own Jacobian function that calls Zygote or write my own `fft` method to handle the dual type. Both options are perhaps out of my reach right now, but I appreciate all of your help in uncovering the root of the problem.


Just add an overload of `fft` for dual numbers, so that the duals are never pushed through the FFT algorithm itself; apply the transform to the values and to the partials separately. Differentiating `fft` is too trivial to actually need to differentiate through the algorithm.
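A quick numerical illustration of why this works (a sketch assuming FFTW.jl; `W`, `x`, `dx`, and `h` are names made up for this example): the FFT is multiplication by a constant DFT matrix, so a perturbation `dx` of the input is transformed by the very same `fft`.

```julia
using FFTW

n = 8
# Explicit DFT matrix with the forward sign convention exp(-2πi jk/n).
W = [cis(-2π * j * k / n) for j in 0:n-1, k in 0:n-1]

x  = randn(n)
dx = randn(n)
h  = 1e-3

# fft is the same as multiplying by the constant matrix W ...
matrix_ok = fft(x) ≈ W * x
# ... so it is linear: the perturbation is transformed by the same fft.
linear_ok = fft(x .+ h .* dx) ≈ fft(x) .+ h .* fft(dx)

@show matrix_ok linear_ok
```

In dual-number terms, the value part becomes `fft(x)` and each partial becomes `fft(dx)`, at `O(n log n)` cost rather than the `O(n^2)` cost of multiplying by `W`.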

From what I understand, the gradient of FFT should just be the DFT matrix which is constant. So I think I understand what you mean by differentiating fft is too trivial. However, I have trouble wrapping my head around how to feed this gradient to Zygote.hessian function in practice. Would you mind showing me an example of your solution using this function? I’m not sure I understand what you mean by “overload”. Thank you!

``````
using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 100
x0 = im.*randn(n)

pfft = plan_fft(x0)
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
``````

Exactly, so you need to use dispatch to directly put that solution into the dual portion. For example, this is what it looks like on a quadrature:

Instead of differentiating through a quadrature, you can define a quadrature on the differentiated function and then stick the result in the partials:

For more information on how ForwardDiff and dual numbers work, check out these lecture notes:

https://mitmath.github.io/18337/lecture8/automatic_differentiation.html

https://github.com/FluxML/Zygote.jl/pull/215

Thank you, I’m going to give this a try and dive deeper into the technical details.

I think I understand how dual numbers work. The first dimension keeps track of the function value and the second dimension has algebra defined that works exactly like dx and the chain rule, so it keeps track of the value of the gradient. However, I’m still having trouble seeing how to implement it in practice. I think the example above might be a bit advanced for me to understand what exactly is going on as I’m still a beginner in Julia. I would really appreciate if you could show me a toy example, perhaps using this function

``````
using LinearAlgebra,Random,ForwardDiff,FFTW,Zygote
n = 10
x0 = randn(n)

pfft = plan_fft(x0)
pifft = plan_ifft(x0)
f(x) = real(norm(pifft.scale .* (pifft.p * fftshift(pfft*x))));
``````

Here we need the plan for fft `pfft`, `fftshift`, and `pifft.p` to recognize dual number `x` (I precompute the plan because I don’t want to differentiate this part). Right now I don’t think they have a method of dealing with dual numbers, so I either have to write methods for dual numbers myself or compute the gradient for these myself and feed to the dual part. I don’t really know how to do either. Something I’m attempting right now is starting with fft:

``````
AbstractFFTs.fft(f::AbstractArray{Dual{TODO}})=ForwardDiff.Dual{fft(f.val), DFT* f.der}
``````

where DFT stands for the DFT matrix. First, I don’t know what types to put as I always have trouble with Julia types. I want types to be general so I put “any” but Julia doesn’t like that. Second, if the gradient is `DFT * f.der` then I assume it will be very slow due to matrix multiplication. This most likely won’t scale very well for my problem. Third, I have no idea what the gradient is for `fftshift`.

As for feeding the DFT matrix to the partials myself, I don’t really know how to ask Zygote.hessian/ ForwardDiff to skip finding the gradient of `fft` and use the one I provide instead. I feel a bit overwhelmed by the intricacies here, and any help is greatly appreciated!
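On the `fftshift` question specifically, a small sketch may help (assuming FFTW.jl; the variable names are illustrative). `fftshift` only reorders elements: in one dimension it is a circular shift by `n ÷ 2`. It is therefore linear, and its adjoint is the inverse permutation `ifftshift`.

```julia
using FFTW, LinearAlgebra

x = collect(1.0:7.0)
c = randn(7)

# fftshift is a circular shift, undone by ifftshift.
shift_ok   = fftshift(x) == circshift(x, length(x) ÷ 2)
inverse_ok = ifftshift(fftshift(x)) == x

# Adjoint check for a permutation P: ⟨c, P x⟩ == ⟨P⁻¹ c, x⟩.
adjoint_ok = dot(c, fftshift(x)) ≈ dot(ifftshift(c), x)

@show shift_ok inverse_ok adjoint_ok
```

Because nothing is computed on the elements themselves, no special derivative rule should be needed for `fftshift`; the open problem is the FFT plans, not the shift.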

I’m quite busy so I set a reminder for 1 week from now. But others should chime in if you haven’t figured it out. You’re close and on the right track.


I was fiddling a bit, and perhaps I paste here before I accidentally close the window again.

FFT is linear, so what we want is `f * (x + dx) = f*x + f*dx` where `f` is what `plan_fft` gives you, and might be worth re-using. Then the rough idea is something like this:

``````
julia> using FFTW, ForwardDiff

julia> x = rand(2)
2-element Vector{Float64}:
0.7786240130665765
0.3371491022135036

julia> f = plan_fft(x)
FFTW forward plan for 2-element array of ComplexF64
(dft-direct-2 "n2fv_2_avx2_128")

julia> xtil = f * x
2-element Vector{ComplexF64}:
1.11577311528008 + 0.0im
0.4414749108530729 + 0.0im

julia> x_plus_dx = [ForwardDiff.Dual(x[i], (i,i^2)) for i in 1:2]  # junk data + some duals
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 2}}:
Dual{Nothing}(0.7786240130665765,1.0,1.0)
Dual{Nothing}(0.3371491022135036,2.0,4.0)

julia> x == ForwardDiff.value.(x_plus_dx)
true

julia> dx1 = ForwardDiff.partials.(x_plus_dx, 1) # extract the dual part
2-element Vector{Float64}:
1.0
2.0

julia> dx1til = f * dx1  # apply the same FFT
2-element Vector{ComplexF64}:
3.0 + 0.0im
-1.0 + 0.0im

julia> dx2til = f * @view reinterpret(Float64, x_plus_dx)[3:3:end] # another way?
2-element Vector{ComplexF64}:
5.0 + 0.0im
-3.0 + 0.0im

julia> xtil_plus = [Complex(   # re-assemble
           ForwardDiff.Dual(real(xtil[i]), (real(dx1til[i]), real(dx2til[i]))),
           ForwardDiff.Dual(imag(xtil[i]), (imag(dx1til[i]), imag(dx2til[i])))
       ) for i in 1:2]
2-element Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}:
Dual{Nothing}(1.11577311528008,3.0,5.0) + Dual{Nothing}(0.0,0.0,0.0)*im
Dual{Nothing}(0.4414749108530729,-1.0,-3.0) + Dual{Nothing}(0.0,0.0,0.0)*im
``````

So there’s going to be a lot of messing with arrays-of-structs. Really `x` is going to be complex like `xtil_plus` here, so that’s one more layer. I think this is the right way around, `Complex{Dual{...}}`.

When you call `fft(x_plus_dx)`, first it converts the matrix to be complex, then it makes a plan, then it applies it. I think these are the steps to make that work here:

``````
# @edit fft(x_plus_dx, 1:1) points me here:
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im)

# @edit fft(x_plus_dx .+ 0im, 1:1) # now this makes a plan, we need:
AbstractFFTs.plan_fft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_fft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x), region)

# Where I want value() to work on complex duals too:
ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = Complex(x.re.value, x.im.value)
ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n))
ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual}) = ForwardDiff.npartials(x.re)

# Now fft(x_plus_dx) fails at *(p::FFTW.cFFTWPlan{ComplexF64, -1, false, 1, UnitRange{Int64}}, x::Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}), great!
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}})
    xtil = p * ForwardDiff.value.(x)
    ndxs = ForwardDiff.npartials(first(x))
    dxtils = ntuple(ndxs) do n
        p * ForwardDiff.partials.(x, n)
    end
    # dxtils = (dx1til, dx2til)
    ndxs == 2 || error("this won't yet work for npartials(x) != 2, sorry")
    dx1til, dx2til = dxtils
    @. Complex(
        ForwardDiff.Dual(real(xtil), tuple(real(dx1til), real(dx2til))),
        ForwardDiff.Dual(imag(xtil), tuple(imag(dx1til), imag(dx2til))),
    )
end

fft(x_plus_dx) # works!
xtil_plus == fft(x_plus_dx)

``````

This does quite a lot of copying. It might be neater to treat the different components by slicing views out of the array, like `dx2til` above. It seems that FFTW is happy to handle this, if you are consistent:

``````
p1k = plan_fft(rand(ComplexF64, 1000))
r1 = rand(ComplexF64, 1000);
@time p1k * r1; @time p1k * r1;
r2 = @view rand(ComplexF64, 2000)[1:2:end];
p1k * r2; # ArgumentError: FFTW plan applied to wrong-strides array
p2k = plan_fft(r2)
@time p2k * r2; @time p2k * r2; # This does work without copying, compare:
@time copy(r2); @time copy(r2);
``````

Perhaps you can similarly handle the output by switching things to `fft!` on (views of) a copy of the data, rather than making separate slices & re-assembling them. But anyway, that’s a start!
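One possible way to lift the `npartials(x) != 2` restriction is sketched below. This is only a sketch under assumptions: `apply_plan_to_duals` is a hypothetical standalone helper (not part of FFTW or ForwardDiff), it assumes `Float64` duals, and it copies data freely rather than using views.

```julia
using FFTW, ForwardDiff
using ForwardDiff: Dual, Partials, value, partials

# Hypothetical helper: apply a complex FFT plan to an array of
# Complex{Dual} carrying any number N of partials.
function apply_plan_to_duals(p, x::AbstractArray{Complex{Dual{T,V,N}}}) where {T,V,N}
    # Transform the value part once.
    vals = p * map(z -> complex(value(real(z)), value(imag(z))), x)
    # Transform each of the N partial directions with the same plan.
    ders = ntuple(Val(N)) do k
        p * map(z -> complex(partials(real(z), k), partials(imag(z), k)), x)
    end
    # Re-assemble Complex{Dual} elements from the transformed pieces.
    map(CartesianIndices(vals)) do I
        re_part = Dual{T}(real(vals[I]), Partials(ntuple(k -> real(ders[k][I]), Val(N))))
        im_part = Dual{T}(imag(vals[I]), Partials(ntuple(k -> imag(ders[k][I]), Val(N))))
        complex(re_part, im_part)
    end
end

# Usage with three partials per element (junk data, as in the example above).
x = rand(4)
d = [Dual(x[i], (1.0 * i, 2.0 * i, 3.0 * i)) for i in 1:4]
p = plan_fft(complex.(x))
out = apply_plan_to_duals(p, complex.(d))
```

Keeping this as a separately named function avoids adding methods to `Base.:*` for types owned by other packages; the same body could be moved into the `*` method if that behaviour is wanted.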

To make my simplest Hessian example work:

``````
AbstractFFTs.plan_bfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_bfft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x), region)

Zygote.extract(x_plus_dx) # ([0.7786240130665765, 0.3371491022135036], [1.0 2.0; 1.0 4.0])
# @edit Zygote.extract(x_plus_dx)
function Zygote.extract(xs::AbstractArray{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N}
    J = similar(xs, complex(V), N, length(xs))
    for i = 1:length(xs), j = 1:N
        J[j, i] = xs[i].re.partials.values[j] + im * xs[i].im.partials.values[j]
    end
    x0 = ForwardDiff.value.(xs)
    return x0, J
end

Zygote.hessian(x->sum(abs2, fft(x)), rand(2)) # ok
``````

Thank you so much Michael, this is incredibly helpful. I’m working through your example and encountered an error at `fft(x_plus_dx)`:

``````
type ForwardDiff.Dual{Nothing,Float64,2} not supported

error(::String)@error.jl:33
_fftfloat(::Type{ForwardDiff.Dual{Nothing,Float64,2}})@definitions.jl:22
_fftfloat(::ForwardDiff.Dual{Nothing,Float64,2})@definitions.jl:23
fftfloat(::ForwardDiff.Dual{Nothing,Float64,2})@definitions.jl:18
complexfloat(::Array{ForwardDiff.Dual{Nothing,Float64,2},1})@definitions.jl:31
fft(::Array{ForwardDiff.Dual{Nothing,Float64,2},1}, ::UnitRange{Int64})@definitions.jl:198
top-level scope@Local: 1[inlined]
``````

After changing to

``````
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(ForwardDiff.value.(x) .+ 0im)
``````

it worked, hopefully this is the correct fix!

Moreover, the last line of code, `Zygote.hessian`, threw this error:

``````
MethodError: no method matching extract(::Array{Complex{Float64},1})

Closest candidates are:

extract(!Matched::ForwardDiff.Dual) at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:12

extract(!Matched::AbstractArray{ForwardDiff.Dual{T,V,N},N1} where N1) where {T, V, N} at C:\Users\bryan\.julia\packages\Zygote\bRa8J\src\lib\forward.jl:14

extract(!Matched::AbstractArray{var"#s389",N} where N where var"#s389"<:(Complex{var"#s390"} where var"#s390"<:ForwardDiff.Dual{T,V,N})) where {T, V, N} at C:\Users\bryan\.julia\pluto_notebooks\Remarkable discovery.jl#==#8a61f360-5046-11eb-344f-77b29aed4bdf:2

1. forward_jacobian(::Zygote.var"#1216#1217"{var"#23#24"{typeof(abs2),typeof(AbstractFFTs.fft),typeof(sum)}}, ::Array{Float64,1}, ::Val{2}) @ forward.jl:23
2. forward_jacobian(::Function, ::Array{Float64,1}) @ forward.jl:38
3. hessian(::Function, ::Array{Float64,1}) @ utils.jl:113
4. top-level scope @ Local: 1 [inlined]
``````

Is `extract` missing a method for complex numbers? I’m not quite sure how to fix this.

Sorry about the `complexfloat` (copy-paste?) mistake.

`extract(::Array{Complex{Float64},1})` doesn’t sound good, I think it should get an array with dual numbers. If these have gone missing along the way then something is wrong!

No worries at all! I’m trying to debug this and realized that right now `fft(x_plus_dx)` actually gives me plain complex numbers instead of complex duals, whereas `xtil_plus` is an array of complex duals, yet `xtil_plus == fft(x_plus_dx)` still returns `true`. That shouldn’t happen, right? Is that not the case for you? I see that this code defines how to apply a plan to complex duals, but I’m not sure why it’s not returning complex duals for me.

``````
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}})
    xtil = p * ForwardDiff.value.(x)
    ndxs = ForwardDiff.npartials(first(x))
    dxtils = ntuple(ndxs) do n
        p * ForwardDiff.partials.(x, n)
    end
    # dxtils = (dx1til, dx2til)
    ndxs == 2 || error("this won't yet work for npartials(x) != 2, sorry")
    dx1til, dx2til = dxtils
    @. Complex(
        ForwardDiff.Dual(real(xtil), tuple(real(dx1til), real(dx2til))),
        ForwardDiff.Dual(imag(xtil), tuple(imag(dx1til), imag(dx2til))),
    )
end
``````
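A possible explanation, offered tentatively: the `complexfloat` variant that calls `ForwardDiff.value.(x)` returns a plain complex array, so the partials are discarded before any plan is applied and the `*` method for complex duals is never reached. The `==` test does not reveal this if, as I believe was the case for ForwardDiff releases from around the time of this thread (treat this as an assumption; later releases may differ), `==` on duals compares only the value parts. Checking the element type is a more direct test, as in this sketch (the helper `has_duals` is a made-up name):

```julia
using ForwardDiff

# Hypothetical helper: does this array still carry dual numbers?
has_duals(A) = eltype(A) <: Union{ForwardDiff.Dual, Complex{<:ForwardDiff.Dual}}

a = ForwardDiff.Dual(1.0, (2.0, 3.0))
A = [complex(a)]      # complex dual: partials preserved
B = [complex(1.0)]    # plain complex: partials lost

@show has_duals(A)
@show has_duals(B)
# The value parts agree even though B has no derivative information.
@show ForwardDiff.value.(real.(A)) == real.(B)
```

Applied to this thread, `has_duals(fft(x_plus_dx))` should be `true` once the duals survive the whole pipeline, which is also what `Zygote.extract` needs.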