Non-Zygote autodiff through an FFT-based Poisson solver?

Hi all,

I’m trying to figure out how to automatically differentiate through code that uses FFTs and IFFTs. I know there are a few topics on this (I’ve referenced some), and I’ve tried to get the most out of them, but they haven’t solved my problem.

I’m trying to differentiate through a Poisson solver that uses the FFT (the process is summarized here, but basically it is: do a Fourier transform, do some operations in Fourier space, then do an inverse Fourier transform).
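For reference, the three steps of such a solver look roughly like this in the simplest setting. This is a sketch, not the solver from this post: it assumes a periodic domain [0, 2π)², a right-hand side v with zero mean, and FFTW's `fftfreq`; `poisson_fft` is a hypothetical name.

```julia
using FFTW

# Hypothetical sketch: solve Δw = v on the periodic square [0, 2π)²,
# assuming v has zero mean (otherwise this solves Δw = v - mean(v)).
function poisson_fft(v::AbstractMatrix{<:Real})
    n1, n2 = size(v)
    k1 = fftfreq(n1, n1)              # integer wavenumbers 0, 1, ..., -1 (period 2π)
    k2 = fftfreq(n2, n2)
    ksq = k1 .^ 2 .+ (k2 .^ 2)'       # |k|² on the 2-D frequency grid
    v_hat = fft(v)                    # 1. Fourier transform
    w_hat = -v_hat ./ ksq             # 2. in Fourier space, Δ is multiplication by -|k|²
    w_hat[1, 1] = 0                   #    k = 0: fix the free constant so w has zero mean
    return real.(ifft(w_hat))         # 3. inverse transform (real part drops rounding noise)
end
```

As a quick check, w = sin(x)cos(2y) satisfies Δw = -5w, so on a 16×16 grid `poisson_fft(-5w)` should reproduce w up to rounding. Note that the map v ↦ w is linear.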

Simply put, I have a working function f : v -> w, where v and w are functions over R^2 (discretized as matrices), that solves the Poisson problem Δw = v. I want to get ∂f/∂v with AD. I’m using FFTW for the Fourier transform. I realized that the problem was in the transform part, so I brought it down to a very trivial MWE:

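```julia
using FFTW

N = (8, 8)
v0 = rand(N...)

pfft = plan_fft(v0)           # forward FFT plan
pifft = plan_ifft(fft(v0))    # inverse FFT plan (includes the 1/n normalization)

apply_fft(x, pfft, pifft) = pfft * x
apply_ifft(x, pfft, pifft) = pifft * x

# f(v) = real(ifft(fft(v))), i.e. the identity map up to rounding
f(v) = real.(apply_ifft(
    apply_fft(v, pfft, pifft),
    pfft, pifft
))
```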

  • Zygote is the only library that worked, using the custom adjoint for apply_ifft defined here.
```julia
using Zygote

# Custom adjoint for the IFFT: the pullback of ifft is fft / n
Zygote.@adjoint apply_ifft(x, pfft, pifft) =
    apply_ifft(x, pfft, pifft),
    c̄ -> ((pfft * c̄) ./ length(c̄), nothing, nothing)

Zygote.jacobian(f, v0)
```

However, Zygote performs badly on code that does a lot of array indexing, which my use case requires (the grid is big), so I found it too slow.
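The scaling in an FFT pullback is easy to get wrong, and a dot-product (adjoint) test checks it directly. This sketch uses only FFTW's plan API: it confirms that the adjoint of `ifft` is `fft / n`, which matches the Zygote pullback above, and that the adjoint of the unnormalized `fft` is `n * ifft` (that is, `bfft`), not `ifft / n`.

```julia
using FFTW, LinearAlgebra

N = (8, 8)
n = prod(N)
x = randn(ComplexF64, N...)
y = randn(ComplexF64, N...)
pfft = plan_fft(x)     # unnormalized forward DFT, F
pifft = plan_ifft(x)   # inverse DFT, F⁻¹ = Fᴴ / n

# dot(a, b) conjugates a; if dot(A * x, y) == dot(x, B * y) for all x, y, then B = Aᴴ.
# Checking with random x, y is the usual practical test.
@assert dot(pifft * x, y) ≈ dot(x, (pfft * y) ./ n)   # adjoint of ifft is fft / n
@assert dot(pfft * x, y) ≈ dot(x, n .* (pifft * y))   # adjoint of fft is n * ifft (= bfft)
```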


  • ReverseDiff: I can’t get it to work at all. I tried defining my own adjoint for apply_ifft only, but my apply_fft function doesn’t work on TrackedReals. The only way I got it to run was to branch on the input type and return pfft * x.value when the input was a tracked array, but this gave a Jacobian full of zeros, and I’m not sure I’m supposed to use the value property anyway.
    I also tried defining rules for both the FFT and the IFFT:
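```julia
using ReverseDiff, ChainRulesCore

# FFT: the adjoint of the unnormalized DFT is n * ifft (= bfft),
# so the pullback multiplies by length(c̄) instead of dividing by it
function ChainRulesCore.rrule(::typeof(apply_fft), x, pfft, pifft)
    y = apply_fft(x, pfft, pifft)
    project_x = ChainRulesCore.ProjectTo(x)  # maps the cotangent back to x's type (real part for real x)
    function apply_fft_pullback(c̄)
        Δ = ChainRulesCore.unthunk(c̄)
        return (
            ChainRulesCore.NoTangent(),
            project_x(length(Δ) .* (pifft * Δ)),
            ChainRulesCore.NoTangent(),
            ChainRulesCore.NoTangent()
        )
    end
    return y, apply_fft_pullback
end

# IFFT: the adjoint of ifft is fft / n
function ChainRulesCore.rrule(::typeof(apply_ifft), x, pfft, pifft)
    y = apply_ifft(x, pfft, pifft)
    project_x = ChainRulesCore.ProjectTo(x)
    function apply_ifft_pullback(c̄)
        Δ = ChainRulesCore.unthunk(c̄)
        return (
            ChainRulesCore.NoTangent(),
            project_x((pfft * Δ) ./ length(Δ)),
            ChainRulesCore.NoTangent(),
            ChainRulesCore.NoTangent()
        )
    end
    return y, apply_ifft_pullback
end

# Import both rules into ReverseDiff (only x is tracked)
ReverseDiff.@grad_from_chainrules apply_fft(x::ReverseDiff.TrackedArray, pfft, pifft)
ReverseDiff.@grad_from_chainrules apply_ifft(x::ReverseDiff.TrackedArray, pfft, pifft)

ReverseDiff.jacobian(f, v0)
```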

But this gives me a TypeError: in TrackedReal, in V, expected V<:Real, got Type{ComplexF64}. I’m guessing that I can’t define a custom pullback for a function whose output is complex?
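The `V<:Real` in that error suggests that ReverseDiff's tracked types only hold real numbers, so any rule whose output is complex cannot be tracked. One possible workaround, sketched here under that assumption, is to hide the complex intermediate inside a single real-to-real function with its own rrule. `spectral_apply` and the Fourier multiplier `m` are hypothetical names. With `m` equal to 1 the sketch reproduces the MWE, and with a real Poisson multiplier (-1/|k|², 0 at k = 0) it would be the solve itself.

```julia
using FFTW, ChainRulesCore, ReverseDiff

# Hypothetical helper: FFT -> multiply by m -> IFFT -> real part, as ONE
# real-to-real function, so ReverseDiff never has to track complex numbers.
spectral_apply(v, m) = real.(ifft(m .* fft(v)))

function ChainRulesCore.rrule(::typeof(spectral_apply), v, m)
    w = spectral_apply(v, m)
    function spectral_apply_pullback(w̄)
        # Adjoint of v ↦ real(ifft(m .* fft(v))) for real v and w:
        # w̄ ↦ real(ifft(conj(m) .* fft(w̄)))
        v̄ = real.(ifft(conj.(m) .* fft(unthunk(w̄))))
        return NoTangent(), v̄, NoTangent()
    end
    return w, spectral_apply_pullback
end

# Import the rule into ReverseDiff; only v is tracked, m is treated as a constant.
ReverseDiff.@grad_from_chainrules spectral_apply(v::ReverseDiff.TrackedArray, m)

N = (8, 8)
v0 = rand(N...)
m = ones(ComplexF64, N)   # m = 1 reproduces the MWE's f (identity map)
J = ReverseDiff.jacobian(v -> spectral_apply(v, m), v0)
```

For a real multiplier `m`, `conj.(m) == m`, so the pullback is the same spectral operation as the forward pass.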


  • ForwardDiff: the example I have isn’t minimal enough to post here, but I’ve tried writing a definition of apply_fft that accepts dual numbers, without any success. The end of this post suggests that this hasn’t been solved.
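Because the FFT is linear, one possible route is to push an array of Dual numbers through an existing Float64 plan by transforming the value array and each partial array separately and then reassembling the results. This is a sketch that assumes ForwardDiff's `Dual{T}(value, partials::Tuple)` constructor and its `value`/`partials` accessors. `apply_linear_dual` and the small helpers (`val`, `part`, `tagof`, `npart`) are hypothetical names and are not part of ForwardDiff or FFTW.

```julia
using FFTW, ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical helpers (not ForwardDiff API): value / i-th partial of a real or
# complex Dual, plus its tag type and number of partials.
val(z::Dual) = value(z)
val(z::Complex{<:Dual}) = complex(value(real(z)), value(imag(z)))
part(z::Dual, i) = partials(z, i)
part(z::Complex{<:Dual}, i) = complex(partials(real(z), i), partials(imag(z), i))
tagof(::Dual{T}) where {T} = T
tagof(z::Complex{<:Dual}) = tagof(real(z))
npart(::Dual{T,V,P}) where {T,V,P} = P
npart(z::Complex{<:Dual}) = npart(real(z))

# Apply a LINEAR map L (e.g. z -> pfft * z) to an array of Duals. By linearity,
# L(v + Σᵢ εᵢ pᵢ) = L(v) + Σᵢ εᵢ L(pᵢ): transform the value array and each
# partial array with the ordinary Float64 plan, then reassemble complex Duals.
function apply_linear_dual(L, x::AbstractArray)
    T, P = tagof(first(x)), npart(first(x))
    y0 = L(val.(x))
    ys = ntuple(i -> L(part.(x, i)), P)
    return map(LinearIndices(y0)) do j
        re_part = Dual{T}(real(y0[j]), ntuple(i -> real(ys[i][j]), P))
        im_part = Dual{T}(imag(y0[j]), ntuple(i -> imag(ys[i][j]), P))
        complex(re_part, im_part)
    end
end

N = (8, 8)
v0 = rand(N...)
pfft = plan_fft(v0)
pifft = plan_ifft(fft(v0))

# Same map as the MWE's f, but able to accept Dual inputs
f_dual(v) = real.(apply_linear_dual(z -> pifft * z,
                                    apply_linear_dual(z -> pfft * z, v)))

J = ForwardDiff.jacobian(f_dual, v0)   # f is the identity, so J ≈ I
```

This relies on linearity. Nonlinear steps elsewhere in the solver would still go through ForwardDiff normally, and only the transforms need this treatment.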

  • Enzyme:
```julia
using Enzyme

Enzyme.jacobian(Forward, f, v0)
```
```
InvalidIRError: compiling function
#apply_fft(Matrix{Float64}, FFTW.cFFTWPlan{ComplexF64, -1, false, 2, UnitRange{Int64}}, AbstractFFTs.ScaledPlan{ComplexF64, FFTW.cFFTWPlan{ComplexF64, 1, false, 2, UnitRange{Int64}}, Float64})
resulted in invalid LLVM IR

Reason: unsupported jl_lazy_load_and_lookup
```
```julia
Enzyme.jacobian(Reverse, f, v0, Val(N[1]*N[2]))
```
```
InvalidIRError: compiling function
#apply_fft(Matrix{Float64}, FFTW.cFFTWPlan{ComplexF64, -1, false, 2, UnitRange{Int64}}, AbstractFFTs.ScaledPlan{ComplexF64, FFTW.cFFTWPlan{ComplexF64, 1, false, 2, UnitRange{Int64}}, Float64})
resulted in invalid LLVM IR

Reason: unsupported jl_lazy_load_and_lookup
```

I’m at a loss as to what to do now. I’m guessing I can still use FiniteDifferences, but for the same reason (the FFT) it’s hard to get the color vectors / sparsity pattern, automatically or (I think) manually, so the performance is better than with Zygote but still not good.
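If f is linear in v, as in the MWE (and as an FFT Poisson solve with a fixed multiplier is), finite differences are not strictly needed: applying f to each unit vector gives the Jacobian column by column, exact up to rounding. Below is a sketch under that linearity assumption, with `linear_jacobian` as a hypothetical helper. For the Poisson solve itself the Jacobian should be essentially dense, because the inverse Laplacian couples every grid point to every other. Coloring therefore has little sparsity to exploit there.

```julia
using FFTW

N = (8, 8)
v0 = rand(N...)
pfft = plan_fft(v0)
pifft = plan_ifft(fft(v0))
f(v) = real.(pifft * (pfft * v))   # same map as the MWE's f

# Hypothetical helper: exact Jacobian of a LINEAR map g, one column per unit
# vector (no step size, so no truncation error; only rounding).
function linear_jacobian(g, v)
    n = length(v)
    J = zeros(n, n)
    e = zero(v)
    for j in 1:n
        e[j] = 1                # j-th unit vector (column-major order, like vec)
        J[:, j] = vec(g(e))
        e[j] = 0
    end
    return J
end

J = linear_jacobian(f, v0)
```

This costs n evaluations of f, each O(n log n) for the transforms.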

Thank you very much for taking the time to read and for your help!
(I can provide stack traces if needed; I just thought the post was long enough already…)