I originally posted this on the Flux-bridged Slack channel, but thought this might be a good place for it too. I need a spatial transformer, or something equivalent, for a project that so far is written entirely in Julia. I thought Interpolations.jl might be a good solution, but it isn’t GPU-friendly yet. There’s a sample implementation, but it’s old and not GPU-friendly either, and I’m not sure how to adapt it.
What I really need is a way to do this differentiably (from the paper):
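For reference, assuming the part of the paper meant here is the sampling step of Jaderberg et al. (2015), "Spatial Transformer Networks", that step is written in the paper's notation below. U is the H×W input map (channel c), V is the output, (x_i^t, y_i^t) is a point on the regular target grid, and (x_i^s, y_i^s) is where it samples from in the input. The affine grid generator comes first, then the bilinear sampling kernel, then the two (sub)gradients that make the module trainable. The gradient with respect to y_i^s is analogous to the one for x_i^s, and both flow back to θ through A_θ.

$$
\begin{aligned}
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix} &= A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix},
\qquad A_\theta = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix} \\
V_i^c &= \sum_{n=1}^{H} \sum_{m=1}^{W} U_{nm}^c \, \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|) \\
\frac{\partial V_i^c}{\partial U_{nm}^c} &= \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|) \\
\frac{\partial V_i^c}{\partial x_i^s} &= \sum_{n=1}^{H} \sum_{m=1}^{W} U_{nm}^c \, \max(0, 1 - |y_i^s - n|)
\begin{cases} 0 & \text{if } |m - x_i^s| \ge 1 \\ 1 & \text{if } m \ge x_i^s \\ -1 & \text{if } m < x_i^s \end{cases}
\end{aligned}
$$

Every term is built from elementwise max/abs operations and sums over U. In principle, it can therefore be expressed with broadcasting and gather-style indexing rather than scalar loops, which is what a GPU-friendly, Zygote-differentiable version would need.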
The other options are a) moving everything to PyTorch, which has the parts I’m missing baked in (I’d rather not), or b) getting a PyTorch model to somehow work with Zygote and the rest of my upstream model. Is there any hope?
Also, I think it would be nice to provide an example (maybe in the Model Zoo or wherever) for people who want a spatial transformer in Julia, or at least some of its functionality, without having to do all this soul-searching.
Any help greatly appreciated!