It would be great to have some Enzyme rules in AbstractFFTs.jl. There's a draft PR, "Add EnzymeRules" by sethaxen (JuliaMath/AbstractFFTs.jl Pull Request #103), but someone has to adopt the project and get it over the finish line.
In the meantime, maybe you'll find the following Enzyme reverse rule useful. I extracted it from a private project I'm working on. It's not a complete solution (no forward rules, no batched rules, no rule for `mul!`), but I think it's correct for the cases it covers, namely `P * x` for both in-place and out-of-place `P`.
```julia
using AbstractFFTs: AbstractFFTs
using Enzyme: Enzyme, EnzymeRules, Const, Duplicated, make_zero, make_zero!

Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan}) = true

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # we can never skip the forward pass because we don't know a priori whether P is an
    # in-place plan, and in-place mutation for non-NoNeed arguments must be performed
    # regardless of needs_primal
    xval = x.val
    yval = P.val * xval
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval  # otherwise I don't know what to do
    end
    needs_primal = EnzymeRules.needs_primal(config)
    primal = needs_primal ? yval : nothing
    shadow = if EnzymeRules.needs_shadow(config)
        if needs_primal || inplace
            make_zero(yval)
        else  # might as well reuse yval as shadow then
            make_zero!(yval)
            yval
        end
    else
        nothing
    end
    tape = (inplace, shadow)  # since * is linear, we don't care whether x is overwritten
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    inplace, shadow = tape
    Padj = adjoint(P.val)
    dx = x.dval
    if inplace
        out = Padj * dx
        @assert out === dx  # sanity check that adjoint(P) is in-place when P is
    end
    if !isnothing(shadow)
        dx .+= Padj * shadow  # mul! not yet supported for adjoint plans
        make_zero!(shadow)
    end
    return (nothing, nothing)
end
```
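For anyone checking the math, the reverse pass is just the pullback of a linear map. The notation below is mine, not names from the code. I'm also assuming Enzyme's convention that, for a real loss $L$, the shadow of a complex array holds $\partial L/\partial \operatorname{Re} + i\,\partial L/\partial \operatorname{Im}$. Here $\bar y$ is the returned `shadow` and $\bar x_{\text{after}}$ is whatever `x.dval` holds when `reverse` runs:

$$
\delta L = \operatorname{Re}\langle \bar y, \delta y\rangle = \operatorname{Re}\langle \bar y, P\,\delta x\rangle = \operatorname{Re}\langle P^{\mathsf H}\bar y, \delta x\rangle
\quad\Longrightarrow\quad
\bar x \mathrel{+}= P^{\mathsf H}\bar y \quad\text{(out-of-place)}
$$

$$
\bar x_{\text{before}} = P^{\mathsf H}\left(\bar x_{\text{after}} + \bar y\right) = P^{\mathsf H}\bar x_{\text{after}} + P^{\mathsf H}\bar y \quad\text{(in-place)}
$$

In the in-place case, nothing has accumulated an adjoint for the pre-call contents of the buffer yet when `reverse` runs, because uses of `x` before the call are processed later in the reverse sweep. That's why the first term can overwrite `x.dval` (`Padj * dx`) instead of adding to it. The second term is the `dx .+= Padj * shadow` line.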
This only works for plan types that already implement `AbstractFFTs.AdjointStyle` and `AbstractFFTs.adjoint_mul`, i.e., types that already have working Zygote/ChainRules rules. That notably excludes the r2r and DCT transforms from FFTW.jl, as well as most (maybe all) of FastTransforms.jl, I think. I have implementations lying around for the one-dimensional FFTW REDFT10 r2r transform and the Chebyshev transform from FastTransforms.jl if you're interested. (I should really get around to making some PRs; it's just so much work to flesh them out from the parts I need into a complete and tested implementation.)
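Here is a small dense-matrix sketch of why the r2r transforms don't get adjoints for free. It uses plain Julia plus the LinearAlgebra stdlib, with no FFTW. All function names are hypothetical helpers, and the matrices follow the transform definitions in the FFTW manual.

For the unnormalized DFT, the adjoint is just `n` times the normalized inverse. As far as I understand, that is the shortcut AbstractFFTs' `FFTAdjointStyle` relies on. For REDFT10, the inverse is REDFT01 up to a factor `2n`. The adjoint, however, also doubles the first entry before applying REDFT01, so a scaled inverse alone is not enough:

```julia
using LinearAlgebra

# Dense reference matrices (O(n^2), for checking only). Row = output index k,
# column = input index j; 1-based here, 0-based in the formulas.

# Unnormalized DFT, same convention as fft: Y[k] = Σ_j X[j] exp(-2πi jk/n)
dft_matrix(n) = [cispi(-2 * (k - 1) * (j - 1) / n) for k in 1:n, j in 1:n]

# Normalized inverse DFT, same convention as ifft
idft_matrix(n) = [cispi(2 * (k - 1) * (j - 1) / n) / n for k in 1:n, j in 1:n]

# FFTW's REDFT10 (DCT-II): Y[k] = 2 Σ_j X[j] cos(π (j + 1/2) k / n)
redft10_matrix(n) = [2 * cospi((j - 0.5) * (k - 1) / n) for k in 1:n, j in 1:n]

# FFTW's REDFT01 (DCT-III): Y[k] = X[0] + 2 Σ_{j≥1} X[j] cos(π j (k + 1/2) / n)
redft01_matrix(n) = [j == 1 ? 1.0 : 2 * cospi((j - 1) * (k - 0.5) / n) for k in 1:n, j in 1:n]

# Hypothetical adjoint_mul-style helpers.
# DFT: the adjoint is the inverse scaled by n.
dft_adjoint_mul(y) = length(y) * (idft_matrix(length(y)) * y)

# REDFT10: REDFT01 * REDFT10 == 2n * I, but the adjoint additionally
# needs the first input entry doubled before applying REDFT01.
function redft10_adjoint_mul(y)
    z = copy(y)
    z[1] *= 2
    return redft01_matrix(length(y)) * z
end

n = 5
x = [1.0, -2.0, 0.5, 3.0, -1.5]
y = [0.25, 1.0, -0.75, 2.0, 0.5]
xc, yc = complex.(x, reverse(y)), complex.(y, x)

# Dot-product test: <y, A x> == <A^H y, x>  (dot conjugates its first argument)
@show dot(yc, dft_matrix(n) * xc) ≈ dot(dft_adjoint_mul(yc), xc)        # true
@show dot(y, redft10_matrix(n) * x) ≈ dot(redft10_adjoint_mul(y), x)    # true
# The plain scaled inverse is not the adjoint of REDFT10 (off by y[1] * sum(x)):
@show dot(y, redft10_matrix(n) * x) ≈ dot(redft01_matrix(n) * y, x)     # false
```

A dot-product test like this one is also a cheap way to sanity-check an `adjoint_mul` implementation before trusting any reverse rule built on top of it.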