ForwardDiff.jl with FFTW.jl

Hello all. I am working with ForwardDiff.jl on a problem that requires computing kernel density estimates (and thus convolutions, which boil down to FFTs), and I am looking for a way to do this with automatic differentiation. I realize there are a number of AD frameworks. My problems have relatively few parameters (on the order of 100), involve control flow (while loops and if statements), and require mutation. Based on this, I’ve narrowed it down to ForwardDiff (though I’m also looking into Enzyme).

The problem is essentially of the following form. Given some parameters p, use a black box (for our purposes) to draw samples from a distribution Dist(p) parameterized by p, then use kernel density estimation to construct an approximate pdf, pdf(Dist(p), Data), evaluated at known data points. I have verified that the black box I use to draw samples is ForwardDiff-compatible; the kernel density estimator is the problem. The most efficient way to compute a kernel density estimate is to build a noisy histogram H from the samples and then smooth it by convolving with a kernel k. This takes the form
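One detail worth checking in this pipeline: with ordinary nearest-bin counting, the histogram entries are piecewise constant in the sample positions, so ForwardDiff would report a zero derivative through H almost everywhere. A common alternative in binned KDE is linear binning, where each sample splits its weight between the two neighbouring grid points. The sketch below is a swapped-in technique, not the method from the post; the name `linear_binned_histogram` and its arguments are hypothetical, and it assumes a uniform grid of at least two points and floating-point (or Dual) samples.

```julia
# Hypothetical helper: linearly binned histogram on a uniform grid of m points
# spanning [lo, hi] (m >= 2). Each sample splits unit weight between its two
# neighbouring grid points in proportion to proximity, so the weights vary
# continuously with the sample positions (unlike nearest-bin counts).
function linear_binned_histogram(samples::AbstractVector{T}, lo, hi, m::Integer) where {T<:Real}
    Δ = (hi - lo) / (m - 1)                 # grid spacing
    H = zeros(T, m)                         # eltype follows the samples
    for s in samples
        t = (s - lo) / Δ                    # fractional grid position, 0-based
        (t < 0 || t > m - 1) && continue    # drop samples outside the grid
        i = min(floor(Int, t), m - 2)       # left neighbour, clamped so i + 1 <= m - 1
        w = t - i                           # share assigned to the right neighbour
        H[i + 1] += 1 - w
        H[i + 2] += w
    end
    return H ./ (length(samples) * Δ)       # sum(H) * Δ == 1 if every sample is in range
end

H = linear_binned_histogram([0.1, 0.4, 0.9], 0.0, 1.0, 5)
```

Because `H` takes its element type from the samples, calling this inside `ForwardDiff.jacobian` should carry the Dual partials of the sample positions into the bin weights.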

f(x;p) = (H(\cdot\,;p) \star k)(x) = \int H(y;p)\, k(x - y)\, dy .

The most efficient way to evaluate this is to FFT both H and k, which turns the convolution into a pointwise multiplication, and then apply the inverse FFT. I currently use `rfft` and `irfft` from FFTW.jl for all of this, since I’m working exclusively with real-valued data.
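For reference, the FFT route can be sketched as follows for plain `Float64` vectors. The helper name `fft_conv` is hypothetical; it zero-pads both inputs so the result is the full linear convolution rather than a circular one. This is the kind of code that breaks once the histogram holds `ForwardDiff.Dual` numbers, since FFTW cannot transform them.

```julia
using FFTW

# Hypothetical helper: full linear convolution of two real vectors.
# Zero-padding both inputs to length n = length(a) + length(b) - 1 avoids the
# wrap-around of a plain circular FFT convolution.
function fft_conv(a::AbstractVector{<:Real}, b::AbstractVector{<:Real})
    n = length(a) + length(b) - 1
    A = rfft([a; zeros(n - length(a))])
    B = rfft([b; zeros(n - length(b))])
    return irfft(A .* B, n)    # irfft needs the original real length n
end

c = fft_conv([1.0, 2.0, 3.0], [0.0, 1.0, 0.5])
```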

The problem is that differentiating the full program eventually requires the derivative of f with respect to p, which means differentiating through the FFT. Unfortunately, FFTW is not ForwardDiff-compatible: it only accepts standard floating-point (real or complex) element types, so it cannot take Dual numbers. Mathematically, however, the derivative is trivial, because the FFT is linear in its input:

\frac{\partial}{\partial p} \mathrm{FFT}\left[f(\cdot\,;p)\right](w) = \mathrm{FFT}\left[\frac{\partial f}{\partial p}(\cdot\,;p)\right](w)
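The identity above suggests one way to make the convolution Dual-aware by hand: since convolution with a fixed kernel is linear in H, the derivative of the smoothed output is the same convolution applied to each partial-derivative vector of H. A minimal sketch, assuming the kernel `k` holds ordinary reals (not Duals); `fft_conv` is a hypothetical helper name, and this is not an FFTW or ForwardDiff API. The plain `Float64` method is repeated so the block runs on its own.

```julia
using FFTW, ForwardDiff
using ForwardDiff: Dual, value, partials

# Plain linear convolution of real vectors via zero-padded rfft/irfft.
function fft_conv(a::AbstractVector{<:Real}, b::AbstractVector{<:Real})
    n = length(a) + length(b) - 1
    A = rfft([a; zeros(n - length(a))])
    B = rfft([b; zeros(n - length(b))])
    return irfft(A .* B, n)
end

# Dual-aware method: d/dp conv(h, k) = conv(dh/dp, k) because conv is linear in h.
# Run the plain FFT convolution on the values and on each of the N partial
# directions, then rebuild Duals carrying the original tag T.
function fft_conv(h::AbstractVector{Dual{T,V,N}}, k::AbstractVector{<:Real}) where {T,V,N}
    vals = fft_conv(value.(h), k)
    parts = ntuple(i -> fft_conv(partials.(h, i), k), N)
    return [Dual{T}(vals[j], ntuple(i -> parts[i][j], N)) for j in eachindex(vals)]
end

# Usage: a toy p-dependent "histogram" smoothed with a fixed kernel.
k = [0.25, 0.5, 0.25]
g(p) = fft_conv([1.0, p, p^2], k)
dg = ForwardDiff.derivative(g, 2.0)   # derivative of the smoothed vector w.r.t. p
```

Each call performs one real FFT convolution for the values plus one per partial in the current ForwardDiff chunk, so the cost grows with the chunk size rather than with the total parameter count.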

So the question is: how do I get ForwardDiff to work with this? Alternatively, is there an efficient ForwardDiff-compatible implementation of convolution? That, after all, is what I actually need, although a direct convolution costs O(n^2) instead of the O(n log n) of the FFT approach.
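On the alternative: a direct convolution written in plain, type-generic Julia works with ForwardDiff without special handling, because it only applies `+` and `*` to whatever element type it receives. A minimal sketch; `direct_conv` is a hypothetical name and it assumes ordinary 1-based vectors.

```julia
using ForwardDiff

# Direct linear convolution, O(length(a) * length(b)).
# The output eltype is promoted from the inputs, so ForwardDiff.Dual
# numbers flow through unchanged. Assumes 1-based vectors.
function direct_conv(a::AbstractVector, b::AbstractVector)
    na, nb = length(a), length(b)
    T = promote_type(eltype(a), eltype(b))
    c = zeros(T, na + nb - 1)
    for i in 1:na, j in 1:nb
        c[i + j - 1] += a[i] * b[j]
    end
    return c
end

k = [0.25, 0.5, 0.25]
dc = ForwardDiff.derivative(p -> direct_conv([1.0, p, p^2], k), 2.0)
```

Its cost is O(n·m) for a length-n histogram and an m-tap kernel, so with a kernel truncated to a few bandwidths (m much smaller than n) it may be competitive with the FFT route; for wide kernels the O(n log n) FFT approach should be faster.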

I did find the following post; however, it involved complex FFTs and FFT plans, and the solution was beyond my ability to follow.


Thanks for the response. Are there any examples or docs showing how this fits into a larger ForwardDiff program? The FFT is usually just one part of a larger computation to which AD is being applied.

I’m pretty sure you just load both packages and it works… not sure what there is to document.

If it doesn’t work, file a GitHub issue.