Hi all,
I’m trying to automatically differentiate through code that uses FFTs and IFFTs. I know there are a few topics on this (I’ve referenced some) and I’ve tried to get the most out of them, but they haven’t solved my problem.
I’m trying to differentiate through a Poisson solver that uses the FFT (the process is summed up here, but basically it’s: do a Fourier transform, do some operations in Fourier space, then do an inverse Fourier transform).
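For context, here is roughly what those three steps (FFT, divide by the Laplacian's symbol in Fourier space, inverse FFT) look like on a periodic grid. This is a minimal sketch under stated assumptions, not necessarily identical to the solver referenced here. It assumes the periodic square [0, 2π)², a zero-mean source `v`, and uses a hypothetical function name, `poisson_solve`:

```julia
using FFTW

# Hypothetical helper (a sketch, not a library function): solve Δw = v on the
# periodic square [0, 2π)² sampled on an n1×n2 grid. Assumes v has zero mean,
# since otherwise no periodic solution exists; the k = 0 mode of w is set to 0.
function poisson_solve(v::AbstractMatrix{<:Real})
    n1, n2 = size(v)
    k1 = fftfreq(n1, n1)    # integer wavenumbers 0, 1, ..., -1 along dimension 1
    k2 = fftfreq(n2, n2)    # same along dimension 2
    v_hat = fft(v)          # step 1: forward transform
    w_hat = similar(v_hat)
    for j in 1:n2, i in 1:n1
        ksq = k1[i]^2 + k2[j]^2
        # step 2: the Laplacian acts as multiplication by -|k|² in Fourier space
        w_hat[i, j] = ksq == 0 ? zero(eltype(w_hat)) : -v_hat[i, j] / ksq
    end
    return real(ifft(w_hat))   # step 3: inverse transform (imaginary part is round-off)
end

# Usage: manufactured solution w = sin(x) cos(2y), for which Δw = -5w.
n = 32
x = 2π .* (0:n-1) ./ n
w_exact = [sin(a) * cos(2b) for a in x, b in x]
w = poisson_solve(-5 .* w_exact)
maximum(abs.(w .- w_exact))   # round-off level
```

Because every Fourier mode is treated independently, the manufactured solution is recovered to round-off. Note that the whole map `v -> w` is real-to-real and linear, even though the intermediate arrays are complex.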
Simply put, I have a working function $f : v \mapsto w$ that solves the Poisson problem $\Delta w = v$, where $v$ and $w$ are functions over $\mathbb{R}^2$ (discretized as matrices). I want to get $\partial f / \partial v$ with AD. I’m using FFTW for the Fourier transform. I realized that the problem was in the transform part, so I reduced it to a trivial MWE:
using FFTW
N = (8, 8)
v0 = rand(N...)
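pfft = plan_fft(v0)           # forward plan (real input is promoted to complex)
pifft = plan_ifft(fft(v0))    # inverse plan, built on a complex array
apply_fft(x, pfft, pifft) = pfft * x
apply_ifft(x, pfft, pifft) = pifft * x
# real ∘ ifft ∘ fft: mathematically the identity on real input
f(v) = real.(apply_ifft(
    apply_fft(v, pfft, pifft),
    pfft, pifft
))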
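Since `real ∘ ifft ∘ fft` is the identity on real input, the Jacobian any AD backend should return for this MWE is the 64×64 identity matrix. Because `f` is linear, that Jacobian can also be assembled exactly, without AD or finite-difference step sizes, by applying `f` to each unit basis matrix. Below is a minimal sketch that repeats the MWE so it runs on its own. `linear_jacobian` is a hypothetical helper, not part of any package:

```julia
using FFTW, LinearAlgebra

N = (8, 8)
v0 = rand(N...)
pfft = plan_fft(v0)
pifft = plan_ifft(fft(v0))
apply_fft(x, pfft, pifft) = pfft * x
apply_ifft(x, pfft, pifft) = pifft * x
f(v) = real.(apply_ifft(apply_fft(v, pfft, pifft), pfft, pifft))

# Hypothetical helper: exact Jacobian of a *linear* map g on matrices,
# built one column at a time by applying g to each unit basis matrix.
function linear_jacobian(g, x::AbstractMatrix)
    n = length(x)
    J = zeros(n, n)
    e = zeros(size(x))
    for k in 1:n
        e[k] = 1.0            # k-th unit basis matrix (linear indexing)
        J[:, k] = vec(g(e))   # column k of the Jacobian is g(e_k)
        e[k] = 0.0
    end
    return J
end

f(v0) ≈ v0                             # true
J = linear_jacobian(f, v0)
J ≈ Matrix{Float64}(I, size(J)...)     # true
```

This costs one call of `f` per grid point, so for large grids it is only useful as a reference check. The same column-by-column idea still gives an exact Jacobian for any linear map, such as an FFT-based Poisson solve.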
- Zygote: this is the only library that worked, using the custom adjoint for `apply_ifft` defined here.
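using Zygote
# Forward pass: the inverse transform itself (plans passed in their declared order).
# Pullback: the adjoint of the normalized ifft is fft / N; one cotangent
# per argument (x, pfft, pifft), with `nothing` for the plans.
Zygote.@adjoint apply_ifft(x, pfft, pifft) =
    apply_ifft(x, pfft, pifft),
    c̄ -> (1 ./ length(c̄) .* (pfft * c̄), nothing, nothing)
Zygote.jacobian(f, v0)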
However, Zygote performs poorly with heavy array indexing (which my use case requires, since the grid is big), so it turned out to be too slow.
- ReverseDiff: I can’t get it to work at all. I tried defining my own adjoint for `apply_ifft` only, but my `apply_fft` function doesn’t work on `TrackedReal`s. The only way I got it to run was to branch on the input type and return `pfft * x.value` instead when given a tracked array. However, this gave a Jacobian full of zeros, and I’m not sure I’m supposed to be using the `value` field anyway.
I also tried defining rules for both `apply_fft` and `apply_ifft`:
using ReverseDiff, ChainRulesCore
# FFT: the adjoint of the unnormalized fft is N * ifft, so multiply by N
function ChainRulesCore.rrule(::typeof(apply_fft), x, pfft, pifft)
    y = apply_fft(x, pfft, pifft)
    function apply_fft_pullback(c̄)
        return (
            ChainRulesCore.NoTangent(),
            length(c̄) .* (pifft * c̄),
            ChainRulesCore.NoTangent(),
            ChainRulesCore.NoTangent()
        )
    end
    return y, apply_fft_pullback
end
# IFFT: the adjoint of the normalized ifft is fft / N
function ChainRulesCore.rrule(::typeof(apply_ifft), x, pfft, pifft)
    y = apply_ifft(x, pfft, pifft)
    function apply_ifft_pullback(c̄)
        return (
            ChainRulesCore.NoTangent(),
            1.0 ./ length(c̄) .* (pfft * c̄),
            ChainRulesCore.NoTangent(),
            ChainRulesCore.NoTangent()
        )
    end
    return y, apply_ifft_pullback
end
ReverseDiff.@grad_from_chainrules apply_fft(x::TrackedArray, pfft, pifft)
ReverseDiff.@grad_from_chainrules apply_ifft(x::TrackedArray, pfft, pifft)
ReverseDiff.jacobian(f, v0)
But this gives me `TypeError: in TrackedReal, in V, expected V<:Real, got Type{ComplexF64}`. I’m guessing that I can’t define a custom pullback for a function whose output is complex ($\mathbb{C}$-valued)?
- ForwardDiff: the example I have isn’t minimal enough to post here, but I’ve tried writing a definition of `apply_fft` that accepts dual numbers, without any success. The end of this post suggests that this hasn’t been solved.
- Enzyme:
using Enzyme
Enzyme.jacobian(Forward, f, v0)
InvalidIRError: compiling function
#apply_fft(Matrix{Float64}, FFTW.cFFTWPlan{ComplexF64, -1, false, 2, UnitRange{Int64}}, AbstractFFTs.ScaledPlan{ComplexF64, FFTW.cFFTWPlan{ComplexF64, 1, false, 2, UnitRange{Int64}}, Float64})
resulted in invalid LLVM IR
Reason: unsupported jl_lazy_load_and_lookup
Enzyme.jacobian(Reverse, f, v0, Val(N[1]*N[2]))
InvalidIRError: compiling function
#apply_fft(Matrix{Float64}, FFTW.cFFTWPlan{ComplexF64, -1, false, 2, UnitRange{Int64}}, AbstractFFTs.ScaledPlan{ComplexF64, FFTW.cFFTWPlan{ComplexF64, 1, false, 2, UnitRange{Int64}}, Float64})
resulted in invalid LLVM IR
Reason: unsupported jl_lazy_load_and_lookup
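For reference when checking hand-written pullbacks like the ones above, the adjoints of FFTW's transforms can be verified numerically with a dot-product test, ⟨Ax, y⟩ = ⟨x, Aᴴy⟩. This is a minimal sketch using only `fft`, `ifft` and `LinearAlgebra.dot` (which conjugates its first argument). It assumes FFTW's default conventions: unnormalized `fft` and `ifft` scaled by 1/N.

```julia
using FFTW, LinearAlgebra

n = (8, 8)
Nel = prod(n)                  # number of grid points, i.e. length(x)
x = randn(ComplexF64, n...)
y = randn(ComplexF64, n...)

# Adjoint of the unnormalized forward DFT: fft' = N * ifft
lhs_fft = dot(fft(x), y)
rhs_fft = dot(x, Nel .* ifft(y))

# Adjoint of the normalized inverse DFT: ifft' = fft / N
lhs_ifft = dot(ifft(x), y)
rhs_ifft = dot(x, fft(y) ./ Nel)

lhs_fft ≈ rhs_fft      # true
lhs_ifft ≈ rhs_ifft    # true
```

So a pullback for `pfft * x` should multiply `pifft * c̄` by `length(c̄)` rather than divide by it. An `ifft` pullback of the form `(pfft * c̄) ./ length(c̄)` matches the second identity. When the primal input is real (as `v0` is), the cotangent passed back is the real part of these expressions, and ChainRulesCore's `ProjectTo` can perform that projection.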
I’m at a loss as to what to do now. I’m guessing I could still use FiniteDifferences. But for the same reason (the FFT), it’s hard to get the color vectors / sparsity pattern automatically (or even manually, I think). So while the performance is better than with Zygote, it’s still not good.
Thank you very much for taking the time to read and for your help!
(I can provide stack traces if needed; I just thought the post was long enough…)