Autodifferentiation with FFT, with Enzyme?

It would be great to have some Enzyme rules in AbstractFFTs.jl. There’s a draft PR on GitHub, JuliaMath/AbstractFFTs.jl#103 ("Add EnzymeRules" by sethaxen), but someone has to adopt the project and get it over the finish line.

In the meantime, maybe you’ll find the following Enzyme reverse rule useful, extracted from a private project I’m working on. It’s not a complete solution (no forward rules, no batched rules, no rule for mul!), but I think it’s correct for the cases it covers, namely P * x for both in-place and out-of-place P.

using AbstractFFTs: AbstractFFTs
using Enzyme: Enzyme, EnzymeRules, Const, Duplicated, make_zero, make_zero!

# plans hold no differentiable data, so Enzyme can always treat them as constants
Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan}) = true

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # we can never skip the forward pass because we don't know a priori whether P is an
    # in-place plan, and in-place mutation for non-NoNeed arguments must be performed
    # regardless of needs_primal
    xval = x.val
    yval = P.val * xval
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval  # otherwise I don't know what to do
    end
    needs_primal = EnzymeRules.needs_primal(config)
    primal = needs_primal ? yval : nothing
    shadow = if EnzymeRules.needs_shadow(config)
        if needs_primal || inplace
            make_zero(yval)
        else  # might as well reuse yval as shadow then
            make_zero!(yval)
            yval
        end
    else
        nothing
    end
    tape = (inplace, shadow)  # since * is linear, we don't care whether x is overwritten
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    inplace, shadow = tape
    Padj = adjoint(P.val)
    dx = x.dval
    if inplace
        out = Padj * dx
        @assert out === dx  # sanity check that adjoint(P) is in-place when P is
    end
    if !isnothing(shadow)
        dx .+= Padj * shadow  # mul! not yet supported for adjoint plans
        make_zero!(shadow)
    end
    return (nothing, nothing)
end
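A minimal usage sketch, assuming FFTW.jl as the backend and that the two rule definitions above have already been evaluated in the same session. The function `loss` is a hypothetical example, not part of any package. Because the unnormalized DFT satisfies P'P = nI, the gradient of `sum(abs2, P * x)` should come out as 2n·x under Enzyme’s convention for complex inputs (∂/∂re + i ∂/∂im), for both the out-of-place and the in-place plan.

```julia
# Assumes the EnzymeRules definitions above are already loaded in this session.
using FFTW: plan_fft, plan_fft!
using Enzyme: Enzyme, Reverse, Active, Const, Duplicated

# hypothetical example loss: real-valued function of a complex input
loss(P, x) = sum(abs2, P * x)

n = 16
x = randn(ComplexF64, n)

# out-of-place plan: x itself is left untouched
P = plan_fft(x)
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Const(P), Duplicated(x, dx))

# in-place plan: P_inplace * x2 overwrites x2, so differentiate on a copy of x
x2 = copy(x)
P_inplace = plan_fft!(x2)  # default ESTIMATE planning does not touch x2's data
dx2 = zero(x2)
Enzyme.autodiff(Reverse, loss, Active, Const(P_inplace), Duplicated(x2, dx2))

# both gradients should match the analytic result 2n * x
@show maximum(abs, dx .- 2n .* x)
@show maximum(abs, dx2 .- 2n .* x)
```

Note that after the in-place call, `x2` holds `fft(x)` rather than its original values, which is why the comparison uses the untouched `x`.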

This only works for plan types that already implement AbstractFFTs.AdjointStyle and AbstractFFTs.adjoint_mul, i.e., types that already have working Zygote/ChainRules rules. That notably excludes the r2r and DCT transforms from FFTW.jl, as well as most, if not all, of FastTransforms.jl, I think. I have implementations lying around for the one-dimensional FFTW REDFT10 r2r transform and the Chebyshev transform from FastTransforms.jl if you’re interested. (I should really get around to making some PRs; it’s just so much work to flesh out the parts I need into a complete and tested implementation.)
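One way to check whether a plan’s adjoint is usable before relying on a reverse rule like this is the dot-product test: a correct adjoint A' of a linear map A satisfies ⟨Ax, y⟩ = ⟨x, A'y⟩ for all x and y. The sketch below is my own illustration, not the implementations mentioned above; `adjoint_ok` and `redft10_adjoint` are hypothetical helper names, and it assumes a recent FFTW.jl whose complex plans implement AbstractFFTs.AdjointStyle. For REDFT10, the FFTW manual defines (0-based) REDFT10 as Y[k] = 2 Σ_j X[j] cos(π(j+1/2)k/n) and REDFT01 as Y[k] = X[0] + 2 Σ_{j≥1} X[j] cos(πj(k+1/2)/n), so the transpose of REDFT10 is REDFT01 applied after doubling the first input entry.

```julia
using FFTW: plan_fft, r2r, REDFT10, REDFT01
using LinearAlgebra: dot

# hypothetical helper: dot-product test for a linear map A and a candidate adjoint Aadj
# (dot conjugates its first argument, so this also covers complex maps)
adjoint_ok(A, Aadj, x, y) = dot(A(x), y) ≈ dot(x, Aadj(y))

# 1) an FFTW plan whose adjoint comes from AbstractFFTs (AdjointStyle / adjoint_mul)
xc, yc = randn(ComplexF64, 32), randn(ComplexF64, 32)
P = plan_fft(xc)
fft_ok = adjoint_ok(v -> P * v, w -> P' * w, xc, yc)

# 2) hypothetical hand-written adjoint for the 1-D REDFT10 (DCT-II) r2r transform:
#    transpose(REDFT10) * y == REDFT01 * (y with its first entry doubled)
function redft10_adjoint(y::AbstractVector{<:Real})
    z = copy(y)
    z[begin] *= 2
    return r2r(z, REDFT01)
end

xr, yr = randn(32), randn(32)
dct_ok = adjoint_ok(v -> r2r(v, REDFT10), redft10_adjoint, xr, yr)

@show fft_ok dct_ok
```

A hand-written adjoint like `redft10_adjoint` is the kind of piece a plan type would need to supply (e.g. via AbstractFFTs.adjoint_mul) before a generic reverse rule could cover it.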
