ForwardDiff and Zygote cannot automatically differentiate (AD) a function from C^n to R that uses FFT

I was fiddling with this a bit, and perhaps I should paste it here before I accidentally close the window again.

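The FFT is linear, so what we want is f * (x + dx) = f*x + f*dx, where f is the plan that plan_fft gives you (which is probably worth re-using). In other words, apply the same plan to the values and to each set of partials separately. The rough idea is something like this: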

julia> using FFTW, ForwardDiff

julia> x = rand(2)
2-element Vector{Float64}:
 0.7786240130665765
 0.3371491022135036

julia> f = plan_fft(x)
FFTW forward plan for 2-element array of ComplexF64
(dft-direct-2 "n2fv_2_avx2_128")

julia> xtil = f * x
2-element Vector{ComplexF64}:
   1.11577311528008 + 0.0im
 0.4414749108530729 + 0.0im

julia> x_plus_dx = [ForwardDiff.Dual(x[i], (i,i^2)) for i in 1:2]  # junk data + some duals
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 2}}:
 Dual{Nothing}(0.7786240130665765,1.0,1.0)
 Dual{Nothing}(0.3371491022135036,2.0,4.0)

julia> x == ForwardDiff.value.(x_plus_dx)
true

julia> dx1 = ForwardDiff.partials.(x_plus_dx, 1) # extract the dual part
2-element Vector{Float64}:
 1.0
 2.0

julia> dx1til = f * dx1  # apply the same FFT
2-element Vector{ComplexF64}:
  3.0 + 0.0im
 -1.0 + 0.0im

julia> dx2til = f * @view reinterpret(Float64, x_plus_dx)[3:3:end] # another way?
2-element Vector{ComplexF64}:
  5.0 + 0.0im
 -3.0 + 0.0im

julia> xtil_plus = [Complex(   # re-assemble
        ForwardDiff.Dual(real(xtil[i]), (real(dx1til[i]), real(dx2til[i]))),
        ForwardDiff.Dual(imag(xtil[i]), (imag(dx1til[i]), imag(dx2til[i])))
        ) for i in 1:2]
2-element Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}:
     Dual{Nothing}(1.11577311528008,3.0,5.0) + Dual{Nothing}(0.0,0.0,0.0)*im
 Dual{Nothing}(0.4414749108530729,-1.0,-3.0) + Dual{Nothing}(0.0,0.0,0.0)*im

So there’s going to be a lot of messing with arrays-of-structs. In general the input x will itself be complex, like xtil_plus here, so that’s one more layer. I think Complex{Dual{...}} is the right way around to nest them.
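To make the arrays-of-structs point concrete, here is a minimal sketch of how a Vector of Complex{Dual} is laid out in memory, assuming 2 partials and Float64 values. The numbers are made up for illustration; only reinterpret and the ForwardDiff.Dual constructor already used above are involved.

```julia
using ForwardDiff
using ForwardDiff: Dual

# Two complex numbers whose real and imaginary parts each carry 2 partials
# (illustrative values only)
z = [Complex(Dual(1.0, (2.0, 3.0)), Dual(4.0, (5.0, 6.0))),
     Complex(Dual(7.0, (8.0, 9.0)), Dual(10.0, (11.0, 12.0)))]

# Flat view of the same memory: 6 Float64 per element, in the order
# re.value, re.partial1, re.partial2, im.value, im.partial1, im.partial2
flat = reinterpret(Float64, z)

re_values = flat[1:6:end]  # real parts of the values
im_values = flat[4:6:end]  # imaginary parts of the values
re_dx1    = flat[2:6:end]  # first partial of the real parts
im_dx1    = flat[5:6:end]  # first partial of the imaginary parts
```

Note that the real and imaginary parts of any one component sit 3 Float64 apart rather than next to each other, so a single strided view of ComplexF64 over this memory does not seem possible without some copying.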

When you call fft(x_plus_dx), it first converts the array to complex, then makes a plan, then applies the plan. I think these are the steps to make that work here:

# @edit fft(x_plus_dx, 1:1) points me here:
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im)

# @edit fft(x_plus_dx .+ 0im, 1:1) # now this makes a plan, we need:
AbstractFFTs.plan_fft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_fft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_fft(ForwardDiff.value.(x), region)

# Where I want value() to work on complex duals too:
ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = Complex(x.re.value, x.im.value)
ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n))
ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual}) = ForwardDiff.npartials(x.re)

# Now fft(x_plus_dx) fails at *(p::FFTW.cFFTWPlan{ComplexF64, -1, false, 1, UnitRange{Int64}}, x::Vector{Complex{ForwardDiff.Dual{Nothing, Float64, 2}}}), great! 
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}})
    xtil = p * ForwardDiff.value.(x)
    ndxs = ForwardDiff.npartials(first(x))
    dxtils = ntuple(ndxs) do n
        p * ForwardDiff.partials.(x, n)
    end
    # dxtils = (dx1til, dx2til)
    ndxs == 2 || error("this won't yet work for npartials(x) != 2, sorry")
    dx1til, dx2til = dxtils
    @. Complex(
        ForwardDiff.Dual(real(xtil), tuple(real(dx1til), real(dx2til))),
        ForwardDiff.Dual(imag(xtil), tuple(imag(dx1til), imag(dx2til))),
        )
end

fft(x_plus_dx) # works! 
xtil_plus == fft(x_plus_dx)
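
The method above is hard-coded to 2 partials. A possible generalisation to any number of partials might look like the sketch below. The name apply_plan_to_duals is hypothetical (it is not part of FFTW, AbstractFFTs or ForwardDiff), it is written as a plain function rather than a Base.:* method, and it assumes the element type is exactly Complex{Dual{T,Float64,N}} with a ComplexF64 plan.

```julia
using FFTW, ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical helper: apply the plan `p` to the values and to each of the
# N partials separately, then rebuild an array of Complex{Dual}.
function apply_plan_to_duals(p, x::AbstractArray{Complex{Dual{T,V,N}}}) where {T,V,N}
    xtil = p * map(z -> Complex(value(real(z)), value(imag(z))), x)
    dxtils = ntuple(N) do n
        p * map(z -> Complex(partials(real(z), n), partials(imag(z), n)), x)
    end
    out = similar(x, size(xtil))
    for i in eachindex(xtil)
        re_part = Dual{T}(real(xtil[i]), ntuple(n -> real(dxtils[n][i]), N))
        im_part = Dual{T}(imag(xtil[i]), ntuple(n -> imag(dxtils[n][i]), N))
        out[i] = Complex(re_part, im_part)
    end
    return out
end

# Made-up input with 3 partials per component
x3 = [Complex(Dual(1.0, (1.0, 0.0, 2.0)), Dual(0.5, (0.0, 1.0, 0.0))),
      Complex(Dual(2.0, (3.0, 1.0, 0.0)), Dual(-1.0, (0.0, 0.0, 1.0)))]
p = plan_fft(zeros(ComplexF64, length(x3)))
y3 = apply_plan_to_duals(p, x3)
```

This still makes N + 1 temporary arrays, so it does nothing about the copying; it only removes the restriction on the number of partials.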

This does quite a lot of copying. It might be neater to treat the different components by slicing views out of the array, like dx2til above. It seems that FFTW is happy to handle this, if you are consistent:

p1k = plan_fft(rand(ComplexF64, 1000))
r1 = rand(ComplexF64, 1000);
@time p1k * r1; @time p1k * r1;
r2 = @view rand(ComplexF64, 2000)[1:2:end];
p1k * r2; # ArgumentError: FFTW plan applied to wrong-strides array
p2k = plan_fft(r2)
@time p2k * r2; @time p2k * r2; # This does work without copying, compare:
@time copy(r2); @time copy(r2);
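
The "wrong-strides" error above comes from the plan being tied to the memory layout it was created for. A small sketch (the variable names here are mine) showing the strides involved, and that a plan made for the view gives the same answer as an FFT of a dense copy:

```julia
using FFTW

dense = rand(ComplexF64, 1000)
strided = @view rand(ComplexF64, 2000)[1:2:end]

# A dense vector steps through memory one element at a time,
# the view of every second element steps two at a time.
dense_strides = strides(dense)
strided_strides = strides(strided)

# A plan created from the view matches its strides, so it applies directly
p_strided = plan_fft(strided)
out = p_strided * strided
```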

Perhaps you can similarly handle the output by switching things to fft! on (views of) a copy of the data, rather than making separate slices and re-assembling them. But anyway, that’s a start!

To make my simplest Hessian example work:

using Zygote

AbstractFFTs.plan_bfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x) .+ 0im, region)
AbstractFFTs.plan_bfft(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = plan_bfft(ForwardDiff.value.(x), region)

Zygote.extract(x_plus_dx) # ([0.7786240130665765, 0.3371491022135036], [1.0 2.0; 1.0 4.0])
# @edit Zygote.extract(x_plus_dx)
function Zygote.extract(xs::AbstractArray{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N}
  J = similar(xs, complex(V), N, length(xs))
  for i = 1:length(xs), j = 1:N
    J[j, i] = xs[i].re.partials.values[j] + im * xs[i].im.partials.values[j]
  end
  x0 = ForwardDiff.value.(xs)
  return x0, J
end

Zygote.hessian(x->sum(abs2, fft(x)), rand(2)) # ok
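
As a sanity check on what that Hessian should be: for the unnormalised DFT, Parseval's theorem gives sum(abs2, fft(x)) == n * sum(abs2, x), so for real x the exact Hessian is 2n times the identity. A quick sketch that checks the identity numerically and builds the expected matrix for n = 2 (this does not call Zygote, it only states what the answer ought to be):

```julia
using FFTW, LinearAlgebra

x = rand(2)
n = length(x)

# Parseval for the unnormalised DFT
lhs = sum(abs2, fft(x))
rhs = n * sum(abs2, x)

# Hessian of x -> n * sum(abs2, x) for real x
H_expected = 2n * Matrix(1.0I, n, n)
```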