Type unstable gradients in Zygote (@code_warntype)

Hello there,
I’m trying to write efficient code for custom automatic differentiation in Julia.

I noticed that even for a simple case, Zygote gives type unstable gradients. Can you please help me understand why?

Here’s the example: I define a model called GaussianModel which takes a single parameter as input. This model can also be evaluated to return the value of its parameter. This is inspired by implicit gradients in Zygote’s documentation.

using LinearAlgebra, Random, Zygote

abstract type Model end

mutable struct GaussianModel{T<:AbstractFloat} <: Model
    σ::T
end

(model::GaussianModel)() = model.σ

function logprob(x::T, y::T, model::GaussianModel{T}) where {T<:AbstractFloat}
    σ = model()
    return -(y - x)^2 / (2σ^2) - log(2 * π * σ^2) / 2
end

Now with this code

x = 0.2
y = 0.1
mymodel = GaussianModel(0.5)
logprob(x, y, mymodel)
gradient(model -> logprob(x, y, model), mymodel)

I get the correct result, but this

@code_warntype gradient(model -> logprob(x, y, model), mymodel)

gives me

MethodInstance for Zygote.gradient(::var"#129#130", ::GaussianModel{Float64})
  from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#129#130"())
  args::Tuple{GaussianModel{Float64}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#129#130"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13

which looks like a type instability. I attached a screenshot for clarity.
Thank you in advance!

Due to fundamental limitations in how Zygote itself is designed, taking gradients with respect to mutable structs will almost always be type unstable. If you’d like GaussianModel to be type stable under AD, the easiest approach would be to make it immutable and use a library such as Setfield.jl or Accessors.jl to do “mutation”.
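A minimal sketch of the immutable approach, assuming a fresh Julia session (redefining `GaussianModel` as immutable would clash with the mutable definition above). The helper `with_sigma` is hypothetical; Accessors.jl's `@set model.σ = 0.7` performs the same copy-with-one-field-replaced update generically. The last line only checks inference of the primal `logprob`, not of the Zygote gradient itself.

```julia
abstract type Model end

# Immutable: no `mutable` keyword, so instances cannot be changed in place
struct GaussianModel{T<:AbstractFloat} <: Model
    σ::T
end

(model::GaussianModel)() = model.σ

function logprob(x::T, y::T, model::GaussianModel{T}) where {T<:AbstractFloat}
    σ = model()
    return -(y - x)^2 / (2σ^2) - log(2 * π * σ^2) / 2
end

# Hypothetical helper: "mutation" by constructing a new model with σ replaced.
# With Accessors.jl the generic equivalent would be `@set model.σ = σ`.
with_sigma(model::GaussianModel, σ) = GaussianModel(σ)

m1 = GaussianModel(0.5)
m2 = with_sigma(m1, 0.7)   # m1 is left untouched

# Inferred return type of the primal function for concrete argument types
rt = only(Base.return_types(logprob, (Float64, Float64, GaussianModel{Float64})))
```

Because the struct is immutable, any update produces a new value instead of modifying shared state, which is the style this answer recommends for AD.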

Maybe I’m wrong, but I’m not sure that’s the issue. I did some tests and the problem remains even for immutable structs.

This example is copied directly from the home page of Zygote’s documentation:

struct Linear
    W
    b
end

(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2))
x = rand(5)

Now

@code_warntype gradient(model -> sum(model(x)), model)

gives

MethodInstance for Zygote.gradient(::var"#11#12", ::Linear)
  from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#11#12"())
  args::Tuple{Linear}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#11#12"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13

That’s for a completely different reason. Notice that Linear doesn’t have any type annotations on its fields. This means any access to them will be type unstable, so it has nothing to do with AD:

julia> @code_warntype model(x)
MethodInstance for (::Linear)(::Vector{Float64})
  from (l::Linear)(x) @ Main REPL[2]:1
Arguments
  l::Linear
  x::Vector{Float64}
Body::Any
1 ─ %1 = Main.:+::Core.Const(+)
│   %2 = Base.getproperty(l, :W)::Any
│   %3 = (%2 * x)::Any
│   %4 = Base.getproperty(l, :b)::Any
│   %5 = Base.broadcasted(%1, %3, %4)::Any
│   %6 = Base.materialize(%5)::Any
└──      return %6

This is a well-known pitfall, covered in the Performance Tips section of the Julia manual. I think the original author of that tutorial opted to omit the types to make it less intimidating for Julia newcomers, but we could revisit that decision.
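To confirm that the instability comes from the untyped fields and not from AD, one can compare inferred return types without Zygote at all. A minimal sketch: the names `LinearUntyped` and `LinearTyped` are hypothetical, and parameterizing on the array types (rather than only on the element type) is one common choice among several.

```julia
# Untyped fields: each field is `Any`, so every access infers as `Any`
struct LinearUntyped
    W
    b
end
(l::LinearUntyped)(x) = l.W * x .+ l.b

# Parametric fields: the concrete array types become part of the struct type
struct LinearTyped{M<:AbstractMatrix, V<:AbstractVector}
    W::M
    b::V
end
(l::LinearTyped)(x) = l.W * x .+ l.b

W, b = rand(2, 5), rand(2)

# Base.return_types gives one inferred return type per matching method;
# `only` extracts it since each callable has a single method here
untyped_rt = only(Base.return_types(LinearUntyped(W, b), (Vector{Float64},)))
typed_rt   = only(Base.return_types(LinearTyped(W, b), (Vector{Float64},)))
```

Both versions compute the same values; only what the compiler can infer about them differs.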

Thank you for the reply. But even with concretely typed fields, the following

struct Linear{T<:AbstractFloat}
    W::Matrix{T}
    b::Vector{T}
end

(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2))
x = rand(5)
@code_warntype gradient(model -> sum(model(x)), model)

gives

MethodInstance for Zygote.gradient(::var"#13#14", ::Linear{Float64})
  from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#13#14"())
  args::Tuple{Linear{Float64}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#13#14"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13

even though

@code_warntype model(x)

is type stable:

MethodInstance for (::Linear{Float64})(::Vector{Float64})
  from (l::Linear)(x::Vector{T}) where T<:AbstractFloat in Main at [...]
Static Parameters
  T = Float64
Arguments
  l::Linear{Float64}
  x::Vector{Float64}
Body::Vector{Float64}
1 ─ %1 = Base.getproperty(l, :W)::Matrix{Float64}
│   %2 = (%1 * x)::Vector{Float64}
│   %3 = Base.getproperty(l, :b)::Vector{Float64}
│   %4 = Base.broadcasted(Main.:+, %2, %3)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Vector{Float64}, Vector{Float64}}}
│   %5 = Base.materialize(%4)::Vector{Float64}
└──      return %5

Ah, now that all the code is together, I see the issue: x is a non-constant global, which means any access to it inside a closure may be type unstable:

julia> runcb(f) = f()
runcb (generic function with 1 method)

julia> @code_warntype runcb(() -> x + x)
MethodInstance for runcb(::var"#15#16")
  from runcb(f) @ Main REPL[15]:1
Arguments
  #self#::Core.Const(runcb)
  f::Core.Const(var"#15#16"())
Body::Any
1 ─ %1 = (f)()::Any
└──      return %1

There are a few ways to avoid this:

  1. Declare x as const
  2. Bind a local variable, e.g. using the let trick:
     let x = x
       @code_warntype gradient(model -> sum(model(x)), model)
     end
  3. Use a function barrier and pass x to the function as an argument
You’ll likely use 2) or 3) in practice.
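The three options can be checked without Zygote by reusing the `runcb` example. A minimal sketch using `@inferred` from the `Test` standard library, which throws an `ErrorException` when the inferred return type differs from the actual one; `barrier` is a hypothetical function name. With Zygote, the same patterns would be applied to the closure passed to `gradient`.

```julia
using Test  # standard library; provides @inferred and @test_throws

runcb(f) = f()   # toy higher-order function from the example above

x = 0.2          # non-constant global

# Unstable: the closure reads the non-constant global `x`, so the call
# infers as `Any` and `@inferred` throws an ErrorException
@test_throws ErrorException @inferred runcb(() -> x + x)

# 1) Constant global: its type can no longer change, so Float64 is inferred
const xc = 0.2
r_const = @inferred runcb(() -> xc + xc)

# 2) `let` trick: the closure captures a local binding with a concrete type
r_let = let x = x
    @inferred runcb(() -> x + x)
end

# 3) Function barrier: `x` enters as an argument, so it is concretely typed
#    inside the function body (this `x` shadows the global one)
barrier(x) = @inferred runcb(() -> x + x)
r_barrier = barrier(x)
```

Options 2) and 3) avoid turning every tunable value into a `const`, which is why they tend to be preferred in real code.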