Spatial transformer (i.e. GPU-friendly interpolations with gradients) with Flux.jl?

I originally posted this on the flux-bridged Slack channel, but thought this might be a good place to post it too. I need a spatial transformer, or something equivalent, for a project that so far is written entirely in Julia. I thought Interpolations.jl might be a good solution, but it isn't GPU-friendly yet. There's a sample implementation, but it's old and not GPU-friendly, and I'm not sure how to adapt it.

What I really need is a differentiable way to do this (from the paper):
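For reference, the spatial transformer paper (Jaderberg et al., 2015) splits the operation into two parts: an affine grid generator that maps each target-grid point $G_i = (x_i^t, y_i^t)$ to a source location, and a bilinear sampler that reads the input feature map $U$ (height $H$, width $W$) at that location to produce the output $V$. In the paper's notation:

```latex
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
  = \mathcal{T}_\theta(G_i)
  = A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
  = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\
                    \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
    \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}

V_i^c = \sum_{n}^{H} \sum_{m}^{W} U_{nm}^c \,
        \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|)
```

Both steps are (sub-)differentiable with respect to $U$ and to $(x_i^s, y_i^s)$, and hence to $\theta$. That is why a GPU grid-sampling kernel with a gradient, plus a vectorized grid generator, is all that's needed.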

The other options are (a) moving everything to PyTorch, which has the parts I'm missing built in (I'd rather not 🙁), or (b) somehow getting a PyTorch model to work with Zygote and the rest of my model upstream. Is there any hope?

I also think it would be nice to provide an example (maybe in the Model Zoo or elsewhere) for people who want a spatial transformer in Julia, or at least some of its functionality, without having to do all this soul-searching 🙂

Any help greatly appreciated!


Coincidentally, a PR for grid sampling was just opened against NNlib (CUDA version here). That leaves the affine grid generation, for which the code at https://github.com/thebhatman/Spatial-Transformer-Network/blob/master/src/stn.jl#L43-L58 might be workable with some minor changes, since it's already 100% vectorized. Also have a look at the NNlib section of the Flux docs to see whether you can use anything there.
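To make concrete what a grid-sampling kernel computes, here is a minimal, package-free CPU sketch of bilinear sampling for a single-channel image. It is not NNlib's API: `bilinear_sample` is a hypothetical name, and the corner-aligned mapping from $[-1, 1]$ to pixel indices plus zero padding are assumptions, so check the conventions of whichever sampler you actually use.

```julia
# Hypothetical helper (not NNlib's API): bilinearly sample a single-channel
# image `img` (width × height) at source coordinates `xs`, `ys` given in the
# normalized range [-1, 1]. Points outside the image contribute zero.
function bilinear_sample(img::AbstractMatrix{T}, xs::AbstractArray,
                         ys::AbstractArray) where {T<:AbstractFloat}
    W, H = size(img)
    out = zeros(T, size(xs))
    for i in eachindex(xs, ys)
        # Map -1 to pixel 1 and +1 to pixel W (or H): corner-aligned convention.
        px = (xs[i] + 1) * (W - 1) / 2 + 1
        py = (ys[i] + 1) * (H - 1) / 2 + 1
        x0, y0 = floor(Int, px), floor(Int, py)
        acc = zero(T)
        # Only the four neighbouring pixels have nonzero weight
        # max(0, 1 - |px - m|) * max(0, 1 - |py - n|), as in the paper.
        for m in x0:x0+1, n in y0:y0+1
            if 1 <= m <= W && 1 <= n <= H
                w = max(0, 1 - abs(px - m)) * max(0, 1 - abs(py - n))
                acc += w * img[m, n]
            end
        end
        out[i] = acc
    end
    return out
end
```

For example, sampling a 2×2 image at the normalized point `(0, 0)` averages all four pixels, and sampling on the regular grid returns the image itself. A scalar loop like this is fine for checking shapes and conventions, but it is exactly what a vectorized GPU kernel with a defined gradient (like the one in the NNlib PR) replaces.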

As for a Model Zoo example: the Julia ML community is orders of magnitude smaller than the TensorFlow or PyTorch communities, so the likelihood of newer or more specialized models like spatial transformers being added to the zoo is pretty low (case in point: this is the first time I've heard of this architecture!). The best way to remedy that is for folks with the know-how to contribute PRs to the zoo, so those are very much welcome 🙂


Yes, absolutely! Happy to put this up if it works; I just need a few more steps to get there.

The PR works! A note for anyone who wants to try it: you need to merge the PR into your local copies of both NNlib.jl and NNlibCUDA.jl and add/precompile the packages from there (I think?) for the functionality to work. Otherwise you may get a "this intrinsic must be compiled to be called" error when trying to take a gradient.

Now the only remaining piece is the grid generator. Here is my lightly adapted version of thebhatman/Spatial-Transformer-Network. It works so far; I'll put the full implementation up soon!


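using Flux  # batched_mul comes from NNlib and is re-exported by Flux

# theta: 2 × 3 × batch_size, one affine matrix per sample.
# Returns source coordinates x_s, y_s, each of size width × height × batch_size.
# Note: the constant grid below is built on the CPU; for GPU training it has to
# live on the same device as theta.
function affine_grid_generator(theta, width, height)
    T = eltype(theta)
    batch_size = size(theta, 3)
    x = LinRange(-one(T), one(T), width)
    y = LinRange(-one(T), one(T), height)

    # Flatten the target grid so that x varies fastest (column-major order).
    x_t_flat = reshape(repeat(x, height), 1, height * width)
    y_t_flat = reshape(repeat(transpose(y), width), 1, height * width)
    all_ones = ones(T, 1, height * width)

    # Homogeneous coordinates (3 × N), copied once per sample (3 × N × batch).
    sampling_grid = vcat(x_t_flat, y_t_flat, all_ones)
    sampling_grid = repeat(sampling_grid, 1, 1, batch_size)

    # (2 × 3 × B) times (3 × N × B) gives (2 × N × B) source coordinates.
    batch_grids = batched_mul(theta, sampling_grid)
    x_s = reshape(batch_grids[1, :, :], width, height, batch_size)
    y_s = reshape(batch_grids[2, :, :], width, height, batch_size)
    return x_s, y_s
end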
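To sanity-check the coordinate layout without NNlib or a GPU, here is a package-free sketch of the same affine grid computation, with an explicit loop in place of `batched_mul`. The name `affine_grid_cpu` is hypothetical; it assumes the same conventions as the code above (theta is 2 × 3 × batch, outputs are width × height × batch, x varies along the first dimension).

```julia
# Hypothetical, package-free version of the affine grid generator.
# theta has size 2 × 3 × batch_size, one affine matrix per sample.
function affine_grid_cpu(theta::AbstractArray{T,3}, width::Integer,
                         height::Integer) where {T}
    batch_size = size(theta, 3)
    xs = LinRange(-one(T), one(T), width)
    ys = LinRange(-one(T), one(T), height)
    x_s = zeros(T, width, height, batch_size)
    y_s = zeros(T, width, height, batch_size)
    for b in 1:batch_size, j in 1:height, i in 1:width
        # Map the homogeneous target point (x_t, y_t, 1) through theta[:, :, b].
        x_s[i, j, b] = theta[1, 1, b] * xs[i] + theta[1, 2, b] * ys[j] + theta[1, 3, b]
        y_s[i, j, b] = theta[2, 1, b] * xs[i] + theta[2, 2, b] * ys[j] + theta[2, 3, b]
    end
    return x_s, y_s
end
```

With the identity matrix `reshape(Float32[1, 0, 0, 1, 0, 0], 2, 3, 1)` the output is just the normalized grid, and setting `theta[1, 3, 1] = 0.5f0` shifts every x coordinate by 0.5. Comparing these cases against the vectorized version is a quick way to catch transposed or swapped axes.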

Would love to see an updated spatial transformer in the zoo too!

Ooh forgot about this, I’ll try and do it over the break!


Pardon the delay! I just opened a PR adding a spatial transformer example to the Model Zoo: https://github.com/FluxML/model-zoo/pull/361
