I have been trying to make a Spatial Transformer Network in Flux (https://arxiv.org/pdf/1506.02025.pdf) mostly for my own education and because I think it would be slick to make one in Julia via Flux + Interpolations.
I am having trouble defining the custom gradient, however. I’ve posted an example below, based off of MNIST example in the model zoo. The forward pass goes fine, but when I do train!
it gives me a MethodError: objects of type Float64 are not callable
.
After inspecting, it seems to me that the custom gradient I defined for FilledExtrapolation
objects is not being called, and it is instead trying to do the gradient itself, rather than by the Interpolations version.
What can I do to figure out more about what is being called in the backward pass? If I define a custom gradient, will it automatically work with broadcasting, as I’ve done here? What else am I missing?
Minimal excerpt:
using Flux
using Flux.Tracker: TrackedArray, TrackedReal, track, @grad, data
import Interpolations; const ITP = Interpolations;
# Fix overly-restrictive call signature in Interpolations (will submit PR)
@inline ITP.lispyprod(p, v::AbstractVector{T}, rest...) where T = ITP.lispyprod(p*zero(T), rest...)
(itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c, n) = track(itp, x, y, c, n)
@grad function (itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::Int, n::Int)
g = ITP.gradient(itp, x, y, c, n)
return (itp(data(x), data(y), data(c), data(n)), D -> (D*g[1], D*g[2], nothing, nothing))
end
z = rand(5, 5, 1, 5)
itp = ITP.extrapolate(ITP.interpolate(z,
(QuadInterp, QuadInterp, ITP.NoInterp())), zero(eltype(z)))
y, back = Flux.Tracker.forward(itp, 1, 2, 1, 1) # this gives an error within Interpolations, suggesting custom gradient wasn't called
Full example:
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Flux.Tracker: TrackedArray, TrackedReal, track, @grad, data
using Base.Iterators: repeated, partition
import Interpolations; const ITP = Interpolations;
# using CuArrays
# Classify MNIST digits with a Spatial Transformer Network
# "Spatial Transformer Networks" M Jaderberg, K Simonyan, A Zisserman, K Kavukcuoglu
# ArXiv https://arxiv.org/abs/1506.02025
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)
# Partition into batches of size 1,000
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
for i in partition(1:60_000, 1000)]
train = gpu.(train)
# Prepare test set (first 1,000 images)
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu
# Fix overly-restrictive call signature in Interpolations (will submit PR)
@inline ITP.lispyprod(p, v::AbstractVector{T}, rest...) where T = ITP.lispyprod(p*zero(T), rest...)
# Define Interpolations gradients for Flux
const QuadInterp = ITP.BSpline(ITP.Quadratic(ITP.Line(ITP.OnCell())))
(itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::TrackedReal, n::TrackedReal) = track(itp, x, y, c, n)
@grad function (itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::Int, n::Int)
g = ITP.gradient(itp, x, y, c, n)
return (itp(data(x), data(y), data(c), data(n)), D -> (D*g[1], D*g[2], nothing, nothing))
end
mutable struct STN{NNT, GT}
localizer::NNT
grid::GT
end
function (m::STN)(x)
img_transform = m.localizer(x)
interp_grid_x = reshape(m.grid * img_transform[:, 1, :], size(x))
interp_grid_y = reshape(m.grid * img_transform[:, 2, :], size(x))
itp = ITP.extrapolate(ITP.interpolate(x,
(QuadInterp, QuadInterp, ITP.NoInterp(), ITP.NoInterp())), zero(eltype(x)))
itp.(interp_grid_x, interp_grid_y, 1, reshape(1:size(x, 4), 1, 1, 1, size(x, 4)))
end
Flux.@treelike STN
# this is the "localizer", which learns the affine transform to perform
localizer = Chain(x -> reshape(x, :, size(x, 4)),
Dense(28^2, 32, relu),
Dense(32, 6),
x -> reshape(x, 3, 2, size(x, 2))) |> gpu
# this makes a meshgrid, which is multiplied by the learned affine transform
grid = [i == 3 ? 1 : i == 1 ? (x-1)%28+1 : x÷28+1 for x=1:(28^2), i=1:3]
m = Chain(
STN(localizer, grid),
Conv((2,2), 1=>16, relu),
x -> maxpool(x, (2,2)),
Conv((2,2), 16=>8, relu),
x -> maxpool(x, (2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax) |> gpu
m(train[1][1][:, :, :, 1:10])
loss(x, y) = crossentropy(m(x), y)
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
opt = ADAM(params(m))
Flux.train!(loss, train, opt, cb = evalcb)