Image Rotation Algorithm for CUDA and Zygote

Hey,

today I wanted to use an image rotation algorithm to rotate a 3D array around one axis (so basically a stack of 2D images).
For my application, it needs to be fast and also fully differentiable with Zygote.

I tried ImageTransformations.jl, but its algorithms don't play well with CUDA (I ran into some Interpolations.jl errors).
Consequently, I was looking for potential algorithms to implement:

I'm not totally sure how to achieve this because I'm not really familiar with CUDA (last week I bought my first GPU since my old AGP 8x card :smiley:). I would be happy if someone could point me in the right direction or to relevant algorithms.

Or is there already Julia code that offers a rotation?

Thanks,

Felix

As far as I know, there is nothing out of the box that works with both GPUs and gradients. Since rotations are just a special case of affine transformations, one could try porting the grid sampler (docs, GPU source) and the affine grid generator from PyTorch. Together they handle arbitrary affine transformations and are used to build spatial transformer networks. However, this is a lot of complex code, and there must be a more Julian way to handle this elegantly, but I can't think of one right now.
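To make the affine-grid idea concrete, here is a rough CPU sketch in plain Julia of what such a port computes: for every output pixel, map its coordinates through an inverse affine transform (the "grid generator") and read the input with bilinear interpolation and zero padding (the "sampler"). The names sample_bilinear, affine_warp, and rotate_affine are hypothetical; this is not PyTorch's API, and it has not been tested with Zygote or CUDA.

```julia
# Bilinear lookup of `img` at the fractional position (y, x);
# pixels outside the image count as zero (zero padding).
function sample_bilinear(img::AbstractMatrix{<:Real}, y::Real, x::Real)
    H, W = size(img)
    y0, x0 = floor(Int, y), floor(Int, x)
    wy, wx = y - y0, x - x0
    val = 0.0
    for (dy, fy) in ((0, 1 - wy), (1, wy)), (dx, fx) in ((0, 1 - wx), (1, wx))
        yi, xi = y0 + dy, x0 + dx
        if 1 <= yi <= H && 1 <= xi <= W
            val += fy * fx * img[yi, xi]
        end
    end
    return val
end

# "Grid generator" + "sampler": output pixel (i, j) reads the input at
# A * ((i, j) - center) + center + t, built as a non-mutating comprehension.
function affine_warp(img::AbstractMatrix{<:Real}, A::AbstractMatrix, t = (0.0, 0.0))
    H, W = size(img)
    cy, cx = (H + 1) / 2, (W + 1) / 2
    return [sample_bilinear(img,
                            A[1, 1] * (i - cy) + A[1, 2] * (j - cx) + cy + t[1],
                            A[2, 1] * (i - cy) + A[2, 2] * (j - cx) + cx + t[2])
            for i in 1:H, j in 1:W]
end

# Rotation by θ: sample the input with the inverse rotation matrix R(-θ).
rotate_affine(img, θ) = affine_warp(img, [cos(θ) sin(θ); -sin(θ) cos(θ)])
```

On a square image, rotate_affine(img, π/2) should match rotl90(img) up to rounding. Interpolation smooths the image at every call, which is one reason FFT-based shears are attractive for repeated rotations.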


Thanks for your answer!

In the meantime, I wrote an FFT-based rotation algorithm for 3D arrays that covers a single special case.
It's pretty fast (it basically needs six fft(arr, [1]) calls) and also fast with Zygote (since the gradient of the FFT is known).
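Why FFTs are enough here: shifting a signal, even by a fractional amount, is just a multiplication of its spectrum by a linear phase, so it costs one forward and one inverse transform. A minimal 1D sketch of this Fourier shift theorem using FFTW's rfft, irfft, and rfftfreq; the name fourier_shift is hypothetical and not from any package.

```julia
using FFTW

# Circularly shift a real signal by `s` samples (s may be fractional) using
# the Fourier shift theorem: x[n - s]  <->  X(f) * exp(-2πi f s).
function fourier_shift(x::AbstractVector{<:Real}, s::Real)
    f = rfftfreq(length(x))                 # frequencies in cycles/sample: 0, 1/N, ...
    phase = exp.(-2π * im .* f .* s)        # linear phase ramp
    return irfft(rfft(x) .* phase, length(x))
end
```

For an integer s, fourier_shift(x, s) should agree with circshift(x, s) up to rounding. For an even length and a fractional s, irfft ignores the imaginary part of the Nyquist coefficient, so that single frequency is not shifted exactly.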

The code is not public at the moment, and more work is needed to generalize it to an arbitrary rotation axis.

But if desired, I could post my special case code here.


Yes, please! :slight_smile:


I'll try to boil it down to a 2D example. It should be easier to understand.

I am one of the co-authors of Larkin et al. that is mentioned in the link that you provided. My (unregistered) package Eigenbroetler.jl has a Julia implementation of the 2D algorithm. I’m sure it will need changing to suit your needs, but it’s there for you to look at and modify.


Here is the code for a 2D FFT-based rotation.
fftpos is taken from PhysicalOptics.jl, and I copied it for simplicity.

I didn't include a CUDA version here, but the code should work pretty much the same way with CUDA (I tested it last week).
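The rotate function relies on the standard factorization of a 2D rotation matrix into three shears, with the same α and β as in the code. Each factor is a shear along a single axis: shear implements it along dimension 2, and permutedims switches to the other axis. With one forward and one inverse rfft per shear, that gives the six 1D transforms mentioned earlier.

$$
R(\theta) = \begin{pmatrix}\cos\theta & -\sin\theta\\ \sin\theta & \cos\theta\end{pmatrix}
= \begin{pmatrix}1 & \alpha\\ 0 & 1\end{pmatrix}
  \begin{pmatrix}1 & 0\\ \beta & 1\end{pmatrix}
  \begin{pmatrix}1 & \alpha\\ 0 & 1\end{pmatrix},
\qquad \alpha = -\tan\frac{\theta}{2}, \quad \beta = \sin\theta .
$$

For example, the top-left entry of the product is $1 + \alpha\beta = 1 - \tan\frac{\theta}{2}\sin\theta = 1 - 2\sin^2\frac{\theta}{2} = \cos\theta$.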

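using FFTW

# Centered sample positions on an interval of length l with spacing l/N;
# the center sample (position 0) sits at index N ÷ 2 + 1 (adapted from PhysicalOptics.jl).
function fftpos(l, N)
    dx = l / N
    if N % 2 == 0
        return range(-l/2, l/2 - dx, length=N)
    else
        return range(-l/2 + dx/2, l/2 - dx/2, length=N)
    end
end

# Shear along dimension 2: each row i is shifted (circularly, with sub-pixel accuracy)
# by an amount proportional to Δx and to its normalized position along dimension 1.
# The shift is applied as a linear phase ramp in Fourier space (Fourier shift theorem).
function shear(arr, Δx)
    ĩ = 2
    T = eltype(arr)
    c = T(2π * Δx)

    ϕ_1D_shift = c .* rfftfreq(size(arr, ĩ), one(T))'      # 1×M row: frequencies along dim 2
    ϕ_shift_strength = fftpos(one(T), size(arr, 1))         # N1 column: positions along dim 1
    ϕ_2D = exp.(1im .* ϕ_1D_shift .* ϕ_shift_strength)      # N1×M phase ramp

    arr_ft = rfft(arr, [ĩ])
    return irfft(arr_ft .* ϕ_2D, size(arr, ĩ), [ĩ])
end

# Rotation by θ (radians) as three shears, transposing in between
# so that the middle shear acts along the other axis.
function rotate(arr, θ)
    α = -tan(θ/2)
    β = sin(θ)

    arr = shear(arr, α * size(arr, 1))
    arr = permutedims(arr, (2, 1))
    arr = shear(arr, β * size(arr, 1))
    arr = permutedims(arr, (2, 1))
    arr = shear(arr, α * size(arr, 1))

    return arr
end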
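Because the 2D phase ramp broadcasts over trailing dimensions, the same three-shear idea should also rotate a whole 3D stack at once, which was the original use case. A sketch under that assumption: rotate_stack and rotate_about_axis are hypothetical helper names, the orientation (sign) convention is not checked, and the shift strength uses size(arr, 1) so that non-square slices get the right per-row positions.

```julia
using FFTW

# Centered positions on an interval of length l, spacing l/N, center at index N ÷ 2 + 1.
function fftpos(l, N)
    dx = l / N
    return iseven(N) ? range(-l/2, l/2 - dx, length=N) :
                       range(-l/2 + dx/2, l/2 - dx/2, length=N)
end

# Phase-ramp shear along dimension 2; the N1×M ramp broadcasts over any
# trailing dimensions, so every slice arr[:, :, k] is sheared in one call.
function shear(arr, Δx)
    T = eltype(arr)
    ϕ_1D_shift = T(2π * Δx) .* rfftfreq(size(arr, 2), one(T))'
    ϕ_shift_strength = fftpos(one(T), size(arr, 1))
    ϕ_2D = exp.(1im .* ϕ_1D_shift .* ϕ_shift_strength)
    return irfft(rfft(arr, [2]) .* ϕ_2D, size(arr, 2), [2])
end

# Rotate every slice arr[:, :, k, ...] in the (1, 2) plane by θ (three shears).
function rotate_stack(arr, θ)
    α, β = -tan(θ / 2), sin(θ)
    perm = (2, 1, ntuple(i -> i + 2, ndims(arr) - 2)...)   # swap dims 1 and 2 only
    arr = shear(arr, α * size(arr, 1))
    arr = permutedims(arr, perm)
    arr = shear(arr, β * size(arr, 1))
    arr = permutedims(arr, perm)
    return shear(arr, α * size(arr, 1))
end

# Rotate a 3D array about dimension `axis`: move that axis last, rotate, move it back.
function rotate_about_axis(arr::AbstractArray{<:Real,3}, θ, axis::Integer)
    others = Tuple(d for d in 1:3 if d != axis)
    perm = (others..., axis)
    rotated = rotate_stack(permutedims(arr, perm), θ)
    return permutedims(rotated, invperm(collect(perm)))
end
```

Each extra permutedims copies the array, so for large GPU arrays it may pay off to choose the data layout such that the rotation axis is already the last dimension.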

Here is a full Pluto example:

### A Pluto.jl notebook ###
# v0.12.19

using Markdown
using InteractiveUtils

# This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error).
macro bind(def, element)
    quote
        local el = $(esc(element))
        global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : missing
        el
    end
end

# ╔═╡ 7560fc84-6862-11eb-0f99-bfa76ae8a694
using Revise, FFTW, FFTResampling, TestImages, Colors, PlutoUI

# ╔═╡ 16e4aa72-6863-11eb-35ad-21c3ba47ebad
function fftpos(l, N)
    dx = l / N
    if N % 2 == 0
        return range(-l/2, l/2-dx, length=N)
    else
        # odd N: same spacing l/N as the even case, symmetric around 0
        return range(-l/2+dx/2, l/2-dx/2, length=N)
    end
end

# ╔═╡ 79d641c0-6862-11eb-15db-77d9548980c1
function shear(arr, Δx)
	ĩ = 2
	c = eltype(arr)(2π * Δx)
	
	ϕ_1D_shift = c .* rfftfreq(size(arr)[ĩ], one(eltype(arr)))'
	ϕ_shift_strength = fftpos(one(eltype(arr)), size(arr)[1])
	ϕ_2D = exp.(1im .* ϕ_1D_shift .* ϕ_shift_strength)
	
	arr_ft = rfft(arr, [ĩ])
	
	return irfft(arr_ft .* ϕ_2D, size(arr)[ĩ], [ĩ])
end

# ╔═╡ 0323a81a-6864-11eb-1c5f-eb449060c54b
function rotate(arr, θ)
	α = -tan(θ/2)
    β = sin(θ)
    
	arr = shear(arr, α * size(arr)[1])
	arr = permutedims(arr, (2,1))
	arr = shear(arr, β * size(arr)[1])
	arr = permutedims(arr, (2,1))
	arr = shear(arr, α * size(arr)[1])
	
	return arr
end

# ╔═╡ 79b6251e-6862-11eb-0260-a71fc759825b
begin
	img = Float32.(testimage("fabip_gray_256"))
	img_pad = FFTResampling.center_set!(zeros(eltype(img), (400, 400)), img)
end

# ╔═╡ 4c4f8fde-6866-11eb-3e26-1b5be5dfd4fa
md"""
$(@bind θ Slider(-180:180))
"""

# ╔═╡ b20c01c8-6866-11eb-33c1-89d7c45484ed
img_s = rotate(img_pad, θ / 180 * π)

# ╔═╡ b4b405c6-6866-11eb-1b55-173146524f32
md"""

$ \theta= $ $(θ)°
"""

# ╔═╡ 95d4846c-6863-11eb-11c5-5568bca3e754
Gray.(img_s)

# ╔═╡ Cell order:
# ╠═7560fc84-6862-11eb-0f99-bfa76ae8a694
# ╠═16e4aa72-6863-11eb-35ad-21c3ba47ebad
# ╠═79d641c0-6862-11eb-15db-77d9548980c1
# ╠═0323a81a-6864-11eb-1c5f-eb449060c54b
# ╠═79b6251e-6862-11eb-0260-a71fc759825b
# ╠═b20c01c8-6866-11eb-33c1-89d7c45484ed
# ╠═4c4f8fde-6866-11eb-3e26-1b5be5dfd4fa
# ╠═b4b405c6-6866-11eb-1b55-173146524f32
# ╠═95d4846c-6863-11eb-11c5-5568bca3e754