As I just mentioned in the GitHub issue "Autodifferentiation with FFT and Enzyme?" (Issue #597, SciML/NonlinearSolve.jl), this actually works well in Julia 1.10; I am happily using Enzyme with FFTW, and it is efficient. Thanks again, @danielwe.
The issue appears to be specific to Julia 1.11 (I tested this a few weeks ago, so that might have changed).
Having a general set of Enzyme rules for FFTW available somewhere would still be nice, to make this a bit easier, and the above seems to work. For reference, I currently use the following bunch of rules, copied and pasted together (written by @danielwe and @sethaxen with the help of others, and tweaked slightly), which work in both forward mode and reverse mode.
using AbstractFFTs
using AbstractFFTs.LinearAlgebra
using FFTW
using Enzyme
using EnzymeCore
######################
# Forward-mode rules #
######################
# Annotation of an argument that carries either a single shadow (`Duplicated`) or a
# batch of shadows (`BatchDuplicated`), so that one forward-mode method signature can
# accept both.
const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}}
# since FFTs are linear, implement all forward-mode rules generically at a low level
# Forward-mode rule for `mul!(y, p, x)` where `p` is an FFT plan.
#
# Because applying a plan is linear, each shadow of `y` is obtained by applying the very
# same plan to the matching shadow of `x`.
#
# Arguments
# - `config`: Enzyme forward configuration; only its batch width is queried.
# - `func`:   the wrapped `mul!`.
# - `RT`:     return annotation; this method only handles `Const` returns.
# - `y`:      destination array together with its shadow(s); overwritten.
# - `p`:      the plan, treated as constant.
# - `x`:      source array together with its shadow(s).
#
# Returns `nothing`.
function EnzymeRules.forward(
    config,
    func::Const{typeof(mul!)},
    RT::Type{<:Const},
    y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
    p::Const{<:AbstractFFTs.Plan{T}},
    x::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
) where {T}
    # primal computation: y.val <- p.val * x.val
    func.val(y.val, p.val, x.val)
    if x isa Duplicated && y isa Duplicated
        func.val(y.dval, p.val, x.dval)
    elseif x isa BatchDuplicated && y isa BatchDuplicated
        # batch mode: `dval` is a tuple holding one shadow per batch lane
        for i in 1:EnzymeRules.width(config)
            func.val(y.dval[i], p.val, x.dval[i])
        end
    end
    return nothing
end
# Forward-mode rule for `p * x` where `p` is an FFT plan.
#
# Because applying a plan is linear, the shadow of the result is the plan applied to the
# shadow of `x`. Which of primal / shadow is returned depends on `RT`:
# - `Const`:                                  the primal result
# - `DuplicatedNoNeed`/`BatchDuplicatedNoNeed`: the shadow (or tuple of shadows) only
# - `Duplicated`/`BatchDuplicated`:           primal and shadow(s) wrapped in that type
#
# If the plan turns out to operate in place (its result may alias its input), the part
# that the return annotation does not ask for is still transformed, so that `x.val` and
# `x.dval` stay consistent with each other after the call.
#
# Throws an `ArgumentError` if `RT` does not fit the annotation of `x` (for example a
# batched return annotation combined with a non-batched `x`).
function EnzymeRules.forward(
    config,
    func::Const{typeof(*)},
    RT::Type{
        <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}
    },
    p::Const{<:AbstractFFTs.Plan},
    x::DuplicatedOrBatchDuplicated{<:StridedArray},
)
    if RT <: Const
        val = func.val(p.val, x.val)
        # an in-place plan has just overwritten x.val, so the shadows must follow suit
        if Base.mightalias(val, x.val)
            if x isa Duplicated
                func.val(p.val, x.dval)
            else
                for dx in x.dval
                    func.val(p.val, dx)
                end
            end
        end
        return val
    end
    if x isa Duplicated
        dval = func.val(p.val, x.dval)
        if RT <: DuplicatedNoNeed
            # an in-place plan must still mutate the primal even if it is not returned
            Base.mightalias(dval, x.dval) && func.val(p.val, x.val)
            return dval
        end
        val = func.val(p.val, x.val)
        RT <: Duplicated && return Duplicated(val, dval)
    else # x isa BatchDuplicated
        dval = ntuple(Val(EnzymeRules.width(config))) do i
            func.val(p.val, x.dval[i])
        end
        if RT <: BatchDuplicatedNoNeed
            # an in-place plan must still mutate the primal even if it is not returned
            Base.mightalias(first(dval), first(x.dval)) && func.val(p.val, x.val)
            return dval
        end
        val = func.val(p.val, x.val)
        RT <: BatchDuplicated && return BatchDuplicated(val, dval)
    end
    throw(ArgumentError("return annotation $RT does not match the annotation of `x`"))
end
# Declarations registered with Enzyme's `inactive` / `inactive_type` hooks: each method
# simply answers `true` for the FFT-related function or type named in its signature.

# FFTW's applicability check on (plan, array) arguments.
function Enzyme.EnzymeRules.inactive(::typeof(FFTW.assert_applicable), args...)
    return true
end

# Any FFT plan type.
function Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan})
    return true
end

# Construction of a plan via `plan_fft`.
function Enzyme.EnzymeRules.inactive(::typeof(plan_fft), args...)
    return true
end
######################
# Reverse-mode rules #
######################
# Reverse-mode forward pass for `P * x` where `P` is an FFT plan.
#
# Always evaluates `P.val * x.val`, then hands back to Enzyme
# - the primal result, only if the configuration asks for it,
# - a zeroed shadow for the result, only if the configuration asks for it,
# - a tape `(inplace, shadow)` consumed by the matching `reverse` rule.
function EnzymeRules.augmented_primal(
    config,
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # The product is evaluated unconditionally: whether P works in place is not known up
    # front, and an in-place plan has to mutate its argument no matter whether the primal
    # is requested.
    xval = x.val
    yval = P.val * xval
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval # otherwise I don't know what to do
    end

    needs_primal = EnzymeRules.needs_primal(config)
    if !EnzymeRules.needs_shadow(config)
        shadow = nothing
    elseif needs_primal || inplace
        # yval is still needed (as primal or as the mutated input): allocate a fresh shadow
        shadow = make_zero(yval)
    else
        # yval is otherwise unused, so recycle its storage as the shadow
        make_zero!(yval)
        shadow = yval
    end
    primal = needs_primal ? yval : nothing

    # `*` is linear, so it does not matter for the reverse pass whether x was overwritten
    tape = (inplace, shadow)
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end
# Reverse pass for `P * x` where `P` is an FFT plan (batch width 1 only).
#
# Reads the `(inplace, shadow)` tape written by `augmented_primal`, pulls cotangents back
# through `adjoint(P.val)` into `x.dval`, and resets the result shadow to zero.
# Returns `(nothing, nothing)`: one entry per argument (`P`, `x`).
function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    (inplace, shadow) = tape
    dx = x.dval
    Padj = adjoint(P.val)

    if inplace
        # transform the shadow of x in place
        out = Padj * dx
        @assert out === dx # sanity check that adjoint(P) is in-place when P is
    end

    if shadow !== nothing
        # mul! not yet supported for adjoint plans, hence the temporary
        pulled_back = Padj * shadow
        dx .+= pulled_back
        make_zero!(shadow)
    end

    return (nothing, nothing)
end