Autodifferentiation with FFT, with Enzyme?

Hi,
I’m trying to use autodifferentiation (for gradients, jacobians), working with a complex code that uses a lot of array mutation, custom structs, and FFTs (using FFTW). I got a small part of my code working by making my code generic enough and using FastTransformsForwardDiff.jl but it would be a lot of work to redesign my code to be fully generic (lots of mutated buffers etc.).

I was hoping I could just throw Enzyme at it, but it seems Enzyme cannot work with AbstractFFTs.jl.

Has anyone been able to get Enzyme to work with an FFT interface?

Thanks,
John

It would be great to have some Enzyme rules in AbstractFFTs.jl (there’s a draft PR, “Add EnzymeRules” by sethaxen, Pull Request #103 in JuliaMath/AbstractFFTs.jl, but someone has to adopt the project and get it over the finish line).

In the meantime, maybe you’ll find the following Enzyme reverse rule useful, extracted from a private project I’m working on. It’s not a complete solution (no forward rules, no batched rules, no rule for mul!), but I think it’s correct for the cases it covers, namely P * x for both in-place and out-of-place P.

using AbstractFFTs: AbstractFFTs
using Enzyme: Enzyme, EnzymeRules, Const, Duplicated, make_zero, make_zero!

Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan}) = true

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # we can never skip the forward pass because we don't know a priori whether P is an
    # in-place plan, and in-place mutation for non-NoNeed arguments must be performed
    # regardless of needs_primal
    xval = x.val
    yval = P.val * xval
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval  # otherwise I don't know what to do
    end
    needs_primal = EnzymeRules.needs_primal(config)
    primal = needs_primal ? yval : nothing
    shadow = if EnzymeRules.needs_shadow(config)
        if needs_primal || inplace
            make_zero(yval)
        else  # might as well reuse yval as shadow then
            make_zero!(yval)
            yval
        end
    else
        nothing
    end
    tape = (inplace, shadow)  # since * is linear, we don't care whether x is overwritten
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    inplace, shadow = tape
    Padj = adjoint(P.val)
    dx = x.dval
    if inplace
        out = Padj * dx
        @assert out === dx  # sanity check that adjoint(P) is in-place when P is
    end
    if !isnothing(shadow)
        dx .+= Padj * shadow  # mul! not yet supported for adjoint plans
        make_zero!(shadow)
    end
    return (nothing, nothing)
end

This only works for plan types that already implement AbstractFFTs.AdjointStyle and AbstractFFTs.adjoint_mul, i.e., types that already have working Zygote/ChainRules rules. That notably excludes r2r and dct transforms from FFTW.jl, as well as most/all of FastTransforms.jl, I think. I have implementations lying around for the one-dimensional FFTW REDFT10 r2r transform and the Chebyshev transform from FastTransforms.jl if you’re interested. (I should really get around to making some PRs, it’s just so much work to flesh it out from just the parts I need to a complete and tested implementation.)
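
A quick way to see whether a given plan type meets that requirement is to try the adjoint directly. The following is a minimal sketch, assuming a recent AbstractFFTs.jl/FFTW.jl where adjoint is defined for plans (which the rule above relies on); the variable names are just for illustration. It checks the defining inner-product identity of the adjoint and, for the unnormalized FFT, that the adjoint matches bfft.

```julia
using AbstractFFTs, FFTW
using LinearAlgebra: dot

x = randn(ComplexF64, 16)
y = randn(ComplexF64, 16)
P = plan_fft(x)

# Lazy adjoint plan; multiplying by it requires AbstractFFTs.AdjointStyle for P's type
Padj = adjoint(P)

# Defining property of the adjoint: <y, P x> == <P' y, x>
lhs = dot(y, P * x)
rhs = dot(Padj * y, x)
@show lhs ≈ rhs

# For the unnormalized FFT, the adjoint is the unnormalized inverse transform
@show Padj * y ≈ bfft(y)
```

If Padj * y throws for a plan type, the reverse rule above will not work for it either.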

Thanks a lot for this!

I tried it with the following test code:

using AbstractFFTs, FFTW
using Enzyme: jacobian, Reverse  # needed if the snippet is run on its own

function testme(x)
    p = plan_fft(x)
    p * x
end

function test()
    x = exp.(-collect(range(-100.0, 100.0, length=32)).^2)
    jacobian(Reverse, testme, x)
end

test()

And I get the following error:

ERROR: 
No augmented forward pass found for ijl_lazy_load_and_lookup
 at context:   %107 = call void ()* @ijl_lazy_load_and_lookup({} addrspace(10)* nonnull %106, i8* noundef nonnull getelementptr inbounds ([19 x i8], [19 x i8]* @_j_str_fftw_set_timelimit_4, i32 0, i32 0)) #70, !dbg !198

Stacktrace:
 [1] unsafe_set_timelimit
   @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:213
 [2] macro expansion
   @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:654
 [3] cFFTWPlan
   @ ~/.julia/packages/FFTW/mFpUb/src/FFTW.jl:49


Stacktrace:
  [1] unsafe_set_timelimit
    @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:213 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:654 [inlined]
  [3] cFFTWPlan
    @ ~/.julia/packages/FFTW/mFpUb/src/FFTW.jl:49
  [4] #plan_fft#10
    @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:787 [inlined]
  [5] plan_fft
    @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:777 [inlined]
  [6] plan_fft
    @ ~/.julia/packages/FFTW/mFpUb/src/fft.jl:803 [inlined]
  [7] testme
    @ ~/testenzyme.jl:63 [inlined]
  [8] augmented_julia_testme_43311wrap
    @ ~/testenzyme.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/iOOVd/src/compiler.jl:5448 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/iOOVd/src/compiler.jl:4986 [inlined]
 [11] (::Enzyme.Compiler.AugmentedForwardThunk{…})(fn::Const{…}, args::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/iOOVd/src/compiler.jl:4922
 [12] #128
    @ ~/.julia/packages/Enzyme/iOOVd/src/sugar.jl:933 [inlined]
 [13] macro expansion
    @ ./ntuple.jl:72 [inlined]
 [14] ntuple
    @ ./ntuple.jl:69 [inlined]
 [15] jacobian(mode::EnzymeCore.ReverseMode{…}, f::typeof(testme), x::Vector{…}; n_outs::Val{…}, chunk::Nothing)
    @ Enzyme ~/.julia/packages/Enzyme/iOOVd/src/sugar.jl:929
 [16] jacobian
    @ ~/.julia/packages/Enzyme/iOOVd/src/sugar.jl:846 [inlined]
 [17] #jacobian#127
    @ ~/.julia/packages/Enzyme/iOOVd/src/sugar.jl:861 [inlined]
 [18] jacobian
    @ ~/.julia/packages/Enzyme/iOOVd/src/sugar.jl:846 [inlined]
 [19] test()
    @ Main ~/testenzyme.jl:69
 [20] top-level scope
    @ ~/testenzyme.jl:72
Some type information was truncated. Use `show(err)` to see complete types.

Does that mean anything to you?

Looks like that’s coming from Enzyme trying to differentiate the plan instantiation. That’s not necessary. Try adding the following:

EnzymeRules.inactive(::typeof(plan_fft), args...) = true

Now I get

ERROR: Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff

I tried the type stability suggestion, but I think it should be stable, and it had no effect.

I tried it too, and I get a method error because the rule is limited to plans and vectors having the same eltype, and the FFTW plan is complex while the vector is real. So I tried making the vector complex (just changing x = exp.(...) to x = complex.(exp.(...))).

Then I got the same error as you’re reporting. I think the error message is a little misleading, and the problem is actually that testme(x) returns a complex vector, for which the Jacobian is a tricky concept. There are multiple ways to define it, and the recommendation is usually to just work with functions with real returns, separating real and imaginary parts as needed before returning.

So I changed p * x to real(p * x). Now it works.

julia> test()
(ComplexF64[1.0 + 0.0im 1.0 + 0.0im … 1.0 + 0.0im 1.0 + 0.0im; 1.0 + 0.0im 0.9807852804032304 + 0.19509032201612828im … 0.9238795325112867 - 0.3826834323650898im 0.9807852804032304 - 0.19509032201612828im; … ; 1.0 + 0.0im 0.9238795325112867 - 0.3826834323650898im … 0.7071067811865476 + 0.7071067811865476im 0.9238795325112867 + 0.3826834323650898im; 1.0 + 0.0im 0.9807852804032304 - 0.19509032201612828im … 0.9238795325112867 + 0.3826834323650898im 0.9807852804032304 + 0.19509032201612828im],)
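
If both the real and imaginary parts of the output are needed, one way to keep the return real-valued is to stack them into a single real vector. This is only a sketch of the idea: testme_split is a hypothetical name, and I have not checked that Enzyme differentiates through the vcat here without further rules.

```julia
using FFTW

# Hypothetical variant of testme: real and imaginary parts stacked in one real vector
function testme_split(x)
    p = plan_fft(x)
    y = p * x
    return vcat(real(y), imag(y))
end

x = complex.(exp.(-collect(range(-1.0, 1.0, length=8)) .^ 2))
out = testme_split(x)

# The output is real, and the complex result can be reassembled from the two halves
n = length(x)
@show eltype(out)
@show complex.(out[1:n], out[n+1:end]) ≈ fft(x)
```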

The rule could probably be generalized to accept the combination of real vectors and complex plans, but I’m not going to look into that right now.

I’d recommend double-checking against some finite differences before relying on this rule when using complex plans, though. I wrote it for r2r/Chebyshev plans, and haven’t thought through whether the complex case would need some extra conjugations or something like that.
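
A finite-difference check along those lines could look like the sketch below. fd_jacobian is a hypothetical helper, not part of any package. It perturbs the real and imaginary parts of each input separately and combines them as ∂/∂Re + im * ∂/∂Im, which is the convention the Jacobian printed above appears to follow (an assumption read off that output, not something I have confirmed in the Enzyme docs). For real(fft(x)) the reference under that convention is the conjugate of the DFT matrix.

```julia
using FFTW

# Hypothetical helper: central-difference Jacobian of a real-valued f at a complex vector x.
# Entry [k, n] is ∂f_k/∂Re(x_n) + im * ∂f_k/∂Im(x_n).
function fd_jacobian(f, x::AbstractVector{<:Complex}; h=1e-6)
    y = f(x)
    J = zeros(ComplexF64, length(y), length(x))
    for n in eachindex(x)
        e = zeros(eltype(x), length(x))
        e[n] = 1  # unit perturbation in component n
        dre = (f(x .+ h .* e) .- f(x .- h .* e)) ./ (2h)
        dim = (f(x .+ (h * im) .* e) .- f(x .- (h * im) .* e)) ./ (2h)
        J[:, n] .= dre .+ im .* dim
    end
    return J
end

testme(x) = real(plan_fft(x) * x)

x = complex.(exp.(-collect(range(-1.0, 1.0, length=8)) .^ 2))
J_fd = fd_jacobian(testme, x)

# Analytic reference for this convention: cis(θ) = cos(θ) + im * sin(θ), the conjugate DFT matrix
N = length(x)
J_ref = [cis(2π * (k - 1) * (n - 1) / N) for k in 1:N, n in 1:N]
err = maximum(abs, J_fd .- J_ref)
@show err
```

With the rule and the inactive declaration from this thread loaded, J_fd can be compared the same way against the first element of the tuple returned by jacobian(Reverse, testme, x).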