Zygote, Flux, and Interpolations

Hey everyone!

I’m working on radio telescope imaging and am trying to do image reconstruction with an iterative maximum entropy method. Many papers suggest using conjugate gradient and work through intricate derivations of the gradients. I thought it would be neat to do the same thing, but push the model through automatic differentiation (AD) instead.

As part of this, I need to sample an interpolated FFT of the image and compute the mean squared error against the measured radio telescope data (the visibilities).
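Written out, the quantity being minimized is roughly the following, where $I$ is the image, $\hat{V}_I$ is the interpolated (fftshifted) FFT of $I$, $(u_k, v_k)$ is the $k$-th column of `uv`, and $V_k$ is `vis_data[k]`. It is the mean of the vector returned by `vis_res`:

$$
\mathrm{MSE}(I) = \frac{1}{N} \sum_{k=1}^{N} \left| \hat{V}_I(u_k, v_k) - V_k \right|^2
$$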

Something like this:

using FFTW, Interpolations

function vis_res(image, vis_data, uv)
    # Generate interpolated visibilities from the (square) image
    N = length(vis_data)
    image_fft = fft(image) |> fftshift
    freqs = fftfreq(size(image, 1)) |> fftshift
    interpolation = LinearInterpolation((freqs, freqs), image_fft)
    # uv is a 2×N matrix; column i holds the (u, v) coordinates of sample i
    vis_interp = [interpolation(uv[:, i]...) for i ∈ 1:N]
    # Squared visibility residuals (their mean is the MSE)
    return abs2.(vis_interp .- vis_data)
end

However, Flux/Zygote doesn’t seem to be happy with indexing into the interpolation object; it throws the error:

ERROR: ArgumentError: unable to check bounds for indices of type Interpolations.WeightedAdjIndex{2, Float64}

Any help in this regard would be greatly appreciated.

Funny, I just ran into the same error. I’m using Zygote to differentiate through a likelihood function that uses an interpolation.

I expect the incompatibility is within the Interpolations library rather than in the indexing you show. My likelihood function also calls a manual bilinear interpolation function I wrote, and as far as I can tell Zygote has no issue with that.
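For reference, a hand-written bilinear interpolation along those lines might look like the sketch below. It is not the poster's actual function: `bilinear` is a hypothetical name, and it assumes both axes are increasing, uniformly spaced `AbstractRange`s (as `fftshift(fftfreq(n))` is meant to produce). It uses only plain indexing and arithmetic with no mutation, which is the kind of code Zygote should be able to trace through, and it works for complex grid values such as an FFT.

```julia
# Hypothetical helper (not part of Interpolations.jl): bilinear interpolation
# of `vals`, sampled on increasing, uniformly spaced axes `xs` (rows) and
# `ys` (columns), evaluated at the point (x, y).
function bilinear(vals::AbstractMatrix, xs::AbstractRange, ys::AbstractRange, x::Real, y::Real)
    # Refuse to extrapolate: the query point must lie inside the grid
    (first(xs) <= x <= last(xs) && first(ys) <= y <= last(ys)) ||
        throw(DomainError((x, y), "point lies outside the grid"))
    # Fractional 1-based grid coordinates of the query point
    fx = (x - first(xs)) / step(xs) + 1
    fy = (y - first(ys)) / step(ys) + 1
    # Lower-left corner of the enclosing cell; clamping keeps i+1, j+1 in bounds
    i = clamp(floor(Int, fx), 1, length(xs) - 1)
    j = clamp(floor(Int, fy), 1, length(ys) - 1)
    # Position inside the cell, each in [0, 1]
    tx = fx - i
    ty = fy - j
    # Weighted sum of the four surrounding grid values
    return (1 - tx) * (1 - ty) * vals[i, j] +
           tx * (1 - ty) * vals[i + 1, j] +
           (1 - tx) * ty * vals[i, j + 1] +
           tx * ty * vals[i + 1, j + 1]
end
```

In the original `vis_res`, this would stand in for the `Interpolations` object, e.g. `[bilinear(image_fft, freqs, freqs, uv[1, k], uv[2, k]) for k in 1:N]`.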

You might also want to try ForwardDiff, unless you need Zygote for some reason.

Yeah, I dug deep into this, and it actually does work on the current master branch of Interpolations; the tagged release just hasn’t been bumped in a few months. There is still an odd indexing problem stemming from the dimensionality of the gradient. I solved my specific case by providing an explicit gradient.
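The explicit gradient used above isn't shown in the thread. As a rough idea of what one can look like for the bilinear case, here is a sketch under stated assumptions: all names (`cell_weights`, `interp`, `loss`, `loss_grad`) are hypothetical, the axes are increasing uniform ranges, and every sample point lies inside the grid. Because each interpolated sample is a fixed weighted sum of four grid values, the gradient of the squared-error loss with respect to the grid is just the residuals scattered back onto those four values with their weights.

```julia
# Hypothetical sketch: squared-error loss of bilinearly interpolated samples
# and a hand-derived gradient with respect to the grid values `vals`.
# `xs`, `ys` are increasing uniform ranges; `pts` is a 2×N matrix of (x, y).

# Lower-left cell index and the four bilinear weights (point assumed in-grid)
function cell_weights(xs::AbstractRange, ys::AbstractRange, x::Real, y::Real)
    fx = (x - first(xs)) / step(xs) + 1
    fy = (y - first(ys)) / step(ys) + 1
    i = clamp(floor(Int, fx), 1, length(xs) - 1)
    j = clamp(floor(Int, fy), 1, length(ys) - 1)
    tx, ty = fx - i, fy - j
    return i, j, ((1 - tx) * (1 - ty), tx * (1 - ty), (1 - tx) * ty, tx * ty)
end

# Interpolated value: weighted sum of the four surrounding grid values
function interp(vals, xs, ys, x, y)
    i, j, w = cell_weights(xs, ys, x, y)
    return w[1] * vals[i, j] + w[2] * vals[i + 1, j] +
           w[3] * vals[i, j + 1] + w[4] * vals[i + 1, j + 1]
end

# Loss L = Σ_k |interp(vals, point_k) - data_k|^2
loss(vals, xs, ys, pts, data) =
    sum(abs2(interp(vals, xs, ys, pts[1, k], pts[2, k]) - data[k]) for k in eachindex(data))

# Explicit gradient: ∂L/∂vals[m] = Σ_k 2 * w_k[m] * r_k, with residual r_k.
# For complex vals this is ∂L/∂Re(vals) + im * ∂L/∂Im(vals).
function loss_grad(vals, xs, ys, pts, data)
    g = zeros(float(promote_type(eltype(vals), eltype(data))), size(vals))
    for k in eachindex(data)
        x, y = pts[1, k], pts[2, k]
        i, j, w = cell_weights(xs, ys, x, y)
        r = interp(vals, xs, ys, x, y) - data[k]
        # Scatter the residual back onto the four grid values it came from
        g[i, j]         += 2 * w[1] * r
        g[i + 1, j]     += 2 * w[2] * r
        g[i, j + 1]     += 2 * w[3] * r
        g[i + 1, j + 1] += 2 * w[4] * r
    end
    return g
end
```

To hand this to Zygote, one would typically register it as a custom rule (for example a ChainRulesCore `rrule`); that wiring, and the remaining step back through the FFT to the image pixels, are not shown here.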

It would help to have a full stacktrace as well. As it is (having no knowledge of how Interpolations works), I have no idea which line is even causing the error.

Looks like Interpolations didn’t gain AD support until after the latest release, so that makes sense. Edit: I see you already commented on a linked issue :). It’s worth reporting an issue if you’re getting incorrect gradient values.