Autodifferentiation with FFT, with Enzyme?

As I just mentioned in SciML/NonlinearSolve.jl issue #597 on GitHub ("Autodifferentiation with FFT and Enzyme?"), this actually works well on Julia 1.10: I am happily using Enzyme with FFTW, and it is efficient. Thanks again, @danielwe!

The issue appears to be with Julia 1.11 (I tested this a few weeks ago, so that may have changed since).

It would still be nice to have a general set of Enzyme rules for FFTW somewhere to make this easier, and the approach above seems to work. For reference, I currently use the following rules, pieced together by copy and paste from several sources. They were written by @danielwe and @sethaxen with help from others, and I tweaked them slightly. They work in both forward mode and reverse mode.

using AbstractFFTs
using AbstractFFTs.LinearAlgebra
using FFTW
using Enzyme
using EnzymeCore

######################
# Forward-mode rules #
######################

# Activity annotation for a differentiated argument carrying either a single shadow
# (`Duplicated`) or a batch of shadows (`BatchDuplicated`); `T` is the wrapped primal type.
const DuplicatedOrBatchDuplicated{T} = Union{BatchDuplicated{T},Duplicated{T}}

# Since FFTs are linear, all forward-mode rules are implemented generically at a low level.

# Forward-mode rule for `mul!(y, p, x)` with an FFT plan `p`.
# The plan is linear, so each tangent obeys the same relation as the primal:
# `dy = p * dx`. The primal is written into `y.val`, and every shadow of `x` is mapped
# into the corresponding shadow of `y`. Returns `nothing` (the return activity is `Const`).
function EnzymeRules.forward(
    config,
    func::Const{typeof(mul!)},
    RT::Type{<:Const},
    y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
    p::Const{<:AbstractFFTs.Plan{T}},
    x::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
) where {T}
    func.val(y.val, p.val, x.val)
    if x isa Duplicated && y isa Duplicated
        func.val(y.dval, p.val, x.dval)
    elseif x isa BatchDuplicated && y isa BatchDuplicated
        # previously this branch repeated the `Duplicated` test and was unreachable,
        # so batched tangents were silently never propagated
        for i in 1:EnzymeRules.width(config)
            func.val(y.dval[i], p.val, x.dval[i])
        end
    else
        throw(ArgumentError(
            "mul! forward rule: y and x must both be Duplicated or both BatchDuplicated",
        ))
    end
    return nothing
end

# Tangent(s) of `P * x` for a linear plan `P`: apply `P` to each shadow of `x`.
_plan_tangent(f, P, x::Duplicated, config) = f(P, x.dval)
function _plan_tangent(f, P, x::BatchDuplicated, config)
    return ntuple(i -> f(P, x.dval[i]), Val(EnzymeRules.width(config)))
end

# Whether applying `P` to the shadow(s) of `x` returned its input, i.e. ran in place.
_tangent_inplace(dval, x::Duplicated) = Base.mightalias(dval, x.dval)
_tangent_inplace(dval, x::BatchDuplicated) = Base.mightalias(first(dval), first(x.dval))

# Forward-mode rule for applying an FFT plan, `p.val * x.val`.
# Since the plan is linear, the tangent of the result is `p.val * dx` for each shadow `dx`.
# The plan may be in-place (the result aliases its input). In that case the mutated
# argument's primal and shadow must both be transformed, even when only one of them is
# returned. Returns the primal for `Const`, the tangent(s) for `*NoNeed`, and both wrapped
# for `Duplicated`/`BatchDuplicated`.
function EnzymeRules.forward(
    config,
    func::Const{typeof(*)},
    RT::Type{
        <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}
    },
    p::Const{<:AbstractFFTs.Plan},
    x::DuplicatedOrBatchDuplicated{<:StridedArray},
)
    f = func.val
    P = p.val
    if RT <: Const
        val = f(P, x.val)
        # an in-place plan overwrote x.val, so the shadow(s) of x must follow
        if Base.mightalias(val, x.val)
            _plan_tangent(f, P, x, config)
        end
        return val
    end
    dval = _plan_tangent(f, P, x, config)
    if RT <: Union{DuplicatedNoNeed,BatchDuplicatedNoNeed}
        # an in-place plan overwrote the shadow(s), so x.val must follow
        if _tangent_inplace(dval, x)
            f(P, x.val)
        end
        return dval
    end
    val = f(P, x.val)
    return RT <: Duplicated ? Duplicated(val, dval) : BatchDuplicated(val, dval)
end

# Mark FFTW's applicability check and `plan_fft` calls as inactive (not differentiated)
# for Enzyme, and treat every AbstractFFTs plan type as an inactive (constant) type.
function Enzyme.EnzymeRules.inactive(::typeof(FFTW.assert_applicable), args...)
    return true
end

function Enzyme.EnzymeRules.inactive(::typeof(plan_fft), args...)
    return true
end

function Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan})
    return true
end

######################
# Reverse-mode rules #
######################

# Reverse-mode augmented primal for applying an FFT plan, `P.val * x.val`.
# Always runs the primal, records whether the plan ran in place (result aliases `x.val`),
# and, if Enzyme requests a shadow for the result, provides a zeroed one. The tape
# `(inplace, shadow)` is unpacked by the matching `EnzymeRules.reverse` method.
# Returns an `AugmentedReturn(primal, shadow, tape)`; `primal` is `nothing` unless needed.
function EnzymeRules.augmented_primal(
    config,
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # we can never skip the forward pass because we don't know a priori whether P is an
    # in-place plan, and in-place mutation for non-NoNeed arguments must be performed
    # regardless of needs_primal
    xval = x.val
    yval = P.val * xval
    # aliasing between output and input is how an in-place plan is detected here
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval  # otherwise I don't know what to do
    end
    needs_primal = EnzymeRules.needs_primal(config)
    primal = needs_primal ? yval : nothing
    shadow = if EnzymeRules.needs_shadow(config)
        # yval cannot double as the shadow if it is returned as the primal, or if it is
        # x.val itself (in-place plan); use a separate zeroed array (make_zero) then
        if needs_primal || inplace
            make_zero(yval)
        else  # might as well reuse yval as shadow then
            make_zero!(yval)
            yval
        end
    else
        nothing
    end
    tape = (inplace, shadow)  # since * is linear, we don't care whether x is overwritten
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

# Reverse-mode pullback for `P * x` (batch width 1), consuming the tape
# `(inplace, shadow)` produced by the augmented primal. Cotangents are mapped back
# through `adjoint(P.val)` and accumulated into `x.dval`; the result shadow, if any, is
# zeroed afterwards. Neither `P` nor `x` receives a returned (active) gradient.
function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    (ran_inplace, ybar) = tape
    adjplan = adjoint(P.val)
    xbar = x.dval
    # If the forward pass ran in place, transform x.dval in place first; this has to
    # happen before the result shadow is accumulated into it below.
    if ran_inplace
        transformed = adjplan * xbar  # kept outside @assert: the side effect is required
        @assert transformed === xbar  # sanity check that adjoint(P) is in-place when P is
    end
    ybar === nothing && return (nothing, nothing)
    xbar .+= adjplan * ybar  # mul! not yet supported for adjoint plans
    make_zero!(ybar)
    return (nothing, nothing)
end
2 Likes