Setting Flux custom gradients

I have been trying to make a Spatial Transformer Network in Flux (https://arxiv.org/pdf/1506.02025.pdf) mostly for my own education and because I think it would be slick to make one in Julia via Flux + Interpolations.

I am having trouble defining the custom gradient, however. I’ve posted an example below, based on the MNIST example in the model zoo. The forward pass goes fine, but when I call train! I get a MethodError: objects of type Float64 are not callable.

After some inspection, it seems that the custom gradient I defined for FilledExtrapolation objects is not being called; Flux is instead trying to compute the gradient itself rather than using the Interpolations version.

What can I do to figure out more about what is being called in the backward pass? If I define a custom gradient, will it automatically work with broadcasting, as I’ve done here? What else am I missing?

Minimal excerpt:

using Flux
using Flux.Tracker: TrackedArray, TrackedReal, track, @grad, data
import Interpolations; const ITP = Interpolations;
# Fix overly-restrictive call signature in Interpolations (will submit PR)
@inline ITP.lispyprod(p, v::AbstractVector{T}, rest...) where T = ITP.lispyprod(p*zero(T), rest...)

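# Intercept itp calls on tracked coordinates and provide the pullback via Interpolations.gradient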
(itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c, n) = track(itp, x, y, c, n)
@grad function (itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::Int, n::Int)
  g = ITP.gradient(itp, x, y, c, n)
  return (itp(data(x), data(y), data(c), data(n)), D -> (D*g[1], D*g[2], nothing, nothing))
end

z = rand(5, 5, 1, 5)
const QuadInterp = ITP.BSpline(ITP.Quadratic(ITP.Line(ITP.OnCell())))  # as in the full example below
itp = ITP.extrapolate(ITP.interpolate(z,
           (QuadInterp, QuadInterp, ITP.NoInterp(), ITP.NoInterp())), zero(eltype(z)))

y, back = Flux.Tracker.forward(itp, 1, 2, 1, 1) # this gives an error within Interpolations, suggesting the custom gradient wasn't called

Full example:

using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Flux.Tracker: TrackedArray, TrackedReal, track, @grad, data
using Base.Iterators: repeated, partition
import Interpolations; const ITP = Interpolations;
# using CuArrays

# Classify MNIST digits with a Spatial Transformer Network
# "Spatial Transformer Networks" M Jaderberg, K Simonyan, A Zisserman, K Kavukcuoglu
# ArXiv https://arxiv.org/abs/1506.02025

imgs = MNIST.images()

labels = onehotbatch(MNIST.labels(), 0:9)

# Partition into batches of size 1,000
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
         for i in partition(1:60_000, 1000)]

train = gpu.(train)

# Prepare test set (first 1,000 images)
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

# Fix overly-restrictive call signature in Interpolations (will submit PR)
@inline ITP.lispyprod(p, v::AbstractVector{T}, rest...) where T = ITP.lispyprod(p*zero(T), rest...)

# Define Interpolations gradients for Flux
const QuadInterp = ITP.BSpline(ITP.Quadratic(ITP.Line(ITP.OnCell())))
(itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::TrackedReal, n::TrackedReal) = track(itp, x, y, c, n)
@grad function (itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedReal, y::TrackedReal, c::Int, n::Int)
  g = ITP.gradient(itp, x, y, c, n)
  return (itp(data(x), data(y), data(c), data(n)), D -> (D*g[1], D*g[2], nothing, nothing))
end

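# STN layer: a localizer network that predicts the transform parameters, plus a fixed sampling grid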
mutable struct STN{NNT, GT}
  localizer::NNT
  grid::GT
end
function (m::STN)(x)
  img_transform = m.localizer(x)
  interp_grid_x = reshape(m.grid * img_transform[:, 1, :], size(x))
  interp_grid_y = reshape(m.grid * img_transform[:, 2, :], size(x))

  itp = ITP.extrapolate(ITP.interpolate(x,
    (QuadInterp, QuadInterp, ITP.NoInterp(), ITP.NoInterp())), zero(eltype(x)))

  itp.(interp_grid_x, interp_grid_y, 1, reshape(1:size(x, 4), 1, 1, 1, size(x, 4)))
end
Flux.@treelike STN

# this is the "localizer", which learns the affine transform to perform
localizer = Chain(x -> reshape(x, :, size(x, 4)),
                  Dense(28^2, 32, relu),
                  Dense(32, 6),
                  x -> reshape(x, 3, 2, size(x, 2))) |> gpu
# this makes a meshgrid, which is multiplied by the learned affine transform
grid = [i == 3 ? 1 : i == 1 ? (x-1)%28+1 : x÷28+1 for x=1:(28^2), i=1:3]
m = Chain(
  STN(localizer, grid),
  Conv((2,2), 1=>16, relu),
  x -> maxpool(x, (2,2)),
  Conv((2,2), 16=>8, relu),
  x -> maxpool(x, (2,2)),
  x -> reshape(x, :, size(x, 4)),
  Dense(288, 10), softmax) |> gpu

m(train[1][1][:, :, :, 1:10])

loss(x, y) = crossentropy(m(x), y)

accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
opt = ADAM(params(m))

Flux.train!(loss, train, opt, cb = evalcb)

Maybe you will get a more detailed answer, but for a start: I’m not certain that @grad understands callable types which aren’t Functions. Can you make this work in the simplest possible example? Perhaps you need to define forward(f, args...) yourself?
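For instance, something along these lines (a minimal, untested sketch with a made-up Doubler type, using the same Tracker API as in your excerpt):

using Flux
using Flux.Tracker: track, @grad, data, TrackedReal

struct Doubler end  # made-up callable type, just to test @grad dispatch on non-Functions
(d::Doubler)(x::TrackedReal) = track(d, x)
@grad function (d::Doubler)(x)
  return 2 * data(x), D -> (2 * D,)
end

y, back = Flux.Tracker.forward(Doubler(), 3.0)  # forward value should be 6.0
back(1)                                         # should give (2.0,)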

Yes, if I do the following, using itp defined above:

foo = (x, y) -> itp(x, y, 1, 1)
y, back = Flux.Tracker.forward(foo, 1, 2)

then it works, but I’m not sure how to “broadcast” over the closure (by which I mean not broadcasting Flux.Tracker.forward itself, but broadcasting over a TrackedArray x and y, as well as over the last integer argument to itp, which in this case represents the mini-batch number).

OK, so what I said isn’t the problem; good to know that that works.

What exactly is the error in your minimal example, and can you reproduce this from some simple type you make up, not from another package?

In the full example, note that Flux’s broadcasting machinery uses ForwardDiff dual numbers, not its own TrackedReal, for broadcasting operations:

g(x) = (println(typeof(x)); x^2)
Flux.gradient(x -> sum(g.(x)), rand(2))
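
(Each element there should print as a ForwardDiff.Dual type, not as a TrackedReal.)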

Thus the gradient you define for itp won’t be used. I wonder if it wouldn’t be simpler to define one custom gradient for m::STN.


Thanks, I will try making a pared-down example. In retrospect the minimal example is a bit too minimal and misses the issue that comes up in the full example.

Thus the gradient you define for itp won’t be used. I wonder if it wouldn’t be simpler to define one custom gradient for m::STN.

This is a good insight. I tried to define a custom gradient on a non-broadcasted itp (so the call would be itp(x::TrackedArray, y::TrackedArray, c, n)), but couldn’t get it to work at the time. I will experiment some more.
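Concretely, something along these lines is what I have in mind (a rough, untested sketch mirroring the scalar definition above; the c and n arguments again get no gradient):

(itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedArray, y::TrackedArray, c, n) = track(itp, x, y, c, n)
@grad function (itp::ITP.FilledExtrapolation{<:Any, 4, <:ITP.AbstractInterpolation, <:Any, <:Any})(x::TrackedArray, y::TrackedArray, c, n)
  xd, yd = data(x), data(y)
  vals = itp.(xd, yd, c, n)                     # plain broadcast, nothing tracked here
  gs   = ITP.gradient.(Ref(itp), xd, yd, c, n)  # per-element gradients from Interpolations
  return vals, D -> (D .* first.(gs), D .* getindex.(gs, 2), nothing, nothing)
end

The call in (m::STN)(x) would then be itp(interp_grid_x, interp_grid_y, 1, reshape(1:size(x, 4), 1, 1, 1, size(x, 4))) with no dot, so the whole-array method and its @grad get hit instead of Flux’s dual-number broadcast.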

I tried the following:

function (m::STN)(x)
  itp = ITP.extrapolate(ITP.interpolate(x,
    (QuadInterp, QuadInterp, ITP.NoInterp(), ITP.NoInterp())), zero(eltype(x)))

  interp_grid_x = x[:, :, :, :]
  interp_grid_y = x[:, :, :, :]

  itp.(interp_grid_x, interp_grid_y, 1, reshape(1:size(x, 4), 1, 1, 1, size(x, 4)))
end

And this actually worked, which suggests to me that the problem has something to do with the localizer network. I also tried other variants that forced itp to extrapolate, and I tested on subsets of x (like x[1:3, 1:2, :, :]) with the corresponding modifications in the Chain definition. Those all seemed to work.

I started running into trouble here:

function (m::STN)(x)
  itp = ITP.extrapolate(ITP.interpolate(x,
    (QuadInterp, QuadInterp, ITP.NoInterp(), ITP.NoInterp())), zero(eltype(x)))

  y = m.localizer(x)
  interp_grid_x = reshape(y, 3, 2, 1, size(x, 4))
  interp_grid_y = reshape(y, 3, 2, 1, size(x, 4))
    
  itp.(interp_grid_x, interp_grid_y, 1, reshape(1:size(x, 4), 1, 1, 1, size(x, 4)))
end

This gave me a DimensionMismatch("arrays could not be broadcast to a common size") error during the backward pass, even though size(interp_grid_x) is the same in both cases and the forward pass runs fine.

The main difference is that I call y = m.localizer(x), where localizer is a neural network itself. Maybe I shouldn’t be calling it in this function, and should structure it differently?

EDIT: If I use

function (m::STN)(x)
  y = m.localizer(x)
  reshape(y, 3, 2, 1, size(x, 4))
end

then there are no issues, so the problem seems to be something about the combination of broadcasting and the call to the localizer?

No great ideas… but just to be clear, m.localizer(x) is a Flux network, and the plan now is to let ForwardDiff act on itp and not provide any derivatives by hand?

Note that you have reshape(y, 3, 2, 1, size(x, 4)) twice, so interp_grid_x and interp_grid_y are identical. Also, zero(eltype(x)) may not match what the broadcasting backward pass will feed to itp. I don’t know whether these matter.