Type unstable gradients in Zygote (@code_warntype)

Ah, now that all the code is together I see the issue. x is a non-constant global, so the compiler cannot rely on its type, which means any access to it inside a closure may be type unstable:

julia> runcb(f) = f()
runcb (generic function with 1 method)

julia> @code_warntype runcb(() -> x + x)
MethodInstance for runcb(::var"#15#16")
  from runcb(f) @ Main REPL[15]:1
Arguments
  #self#::Core.Const(runcb)
  f::Core.Const(var"#15#16"())
Body::Any
1 ─ %1 = (f)()::Any
└──      return %1

There are a few ways to avoid this:

  1. Declare x as const
  2. Bind a local variable, e.g. using the let trick:
let x = x
  @code_warntype gradient(model -> sum(model(x)), model)
end
  3. Use a function barrier and pass x to the function as an argument

You’ll likely use 2) or 3) in practice.
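
For completeness, here is a minimal sketch of options 1) and 3) that runs without Zygote. The names x_global, x_const, unstable, stable_const and loss are made up for illustration and are not from the code in this thread; sum(data .+ data) stands in for the real model call. Base.return_types is used instead of @code_warntype just to print the inferred return type compactly:

```julia
# Hypothetical stand-ins for the globals in this thread (not the original model or data).
x_global = rand(Float32, 4)        # non-constant, untyped global: inferred as Any
const x_const = rand(Float32, 4)   # option 1: const global, so its type is known

runcb(f) = f()

# The closure reads the non-constant global directly, so inference gives up.
unstable() = runcb(() -> sum(x_global .+ x_global))

# Same closure, but reading the const global.
stable_const() = runcb(() -> sum(x_const .+ x_const))

# Option 3, function barrier: the data arrives as an argument, so the method
# is specialized on the concrete type of `data`.
loss(data) = runcb(() -> sum(data .+ data))

# Inspect what inference concludes for each method.
@show Base.return_types(unstable, Tuple{})
@show Base.return_types(stable_const, Tuple{})
@show Base.return_types(loss, Tuple{Vector{Float32}})
```

Calling loss(x_global) from the top level still pays for one dynamic dispatch at the call site, but everything inside loss is compiled for the concrete array type, which is usually what matters.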