Simple function using Zygote gradient is type-unstable

Hi there, I have been trying to figure out why, when I use Zygote to calculate the gradient of a simple function, it returns the correct value, yet @code_warntype suggests the function is type-unstable, with an inferred return type of ::Union{Nothing, Tuple{Any}}.

The code:

using Zygote

func2 = x -> x^2

function grad(x)
    gradient(var -> func2(var), x)
end

If I call grad(a) I get the answer I expect, but I thought it would be type-stable. If I call the gradient function directly, @code_warntype doesn't show any problems.

I feel like I’ve missed something obvious with type inference and gradients. Any pointers would be most appreciated, thanks!

@code_warntype grad(2f0)
MethodInstance for grad(::Float32)
  from grad(x) @ Main
Arguments
  #self#::Core.Const(grad)
  x::Float32
Locals
  #153::var"#153#154"
Body::Union{Nothing, Tuple{Any}}
1 ─      (#153 = %new(Main.:(var"#153#154")))
│   %2 = #153::Core.Const(var"#153#154"())
│   %3 = Main.gradient(%2, x)::Union{Nothing, Tuple{Any}}
└──      return %3

func2 is not a function definition as such; it is a global non-const variable that, at the moment, binds to an anonymous function.


In other words, if you change

func2 = x -> x^2

to

func2_better(x) = x^2

it should become type-stable.

The reason @jling hinted at is that your anonymous function func2 is bound to a non-const global variable. This means that grad must allow for the possibility that func2 is reassigned at any time, possibly to a value of a different type. Hence the instability.
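
A minimal sketch of that fix (grad_better and func2c are just illustrative names; func2_better is the name from the snippet above): the key is that the callee is reachable through a constant binding.

using Zygote

func2_better(x) = x^2            # named definition: the name becomes a const binding
const func2c = x -> x^2          # alternative: bind the anonymous function to a const global

grad_better(x) = gradient(func2_better, x)   # passing the function directly, no extra wrapper

@code_warntype grad_better(2f0)  # should now report a concrete return type such as Tuple{Float32}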


gradient(func2, x) reads better than gradient(var -> func2(var), x); there is no need to wrap a function in another anonymous function just to pass it along. This is also covered in the Manual's style guide section.


Doh :grimacing:.
That was a silly slip on my part, thank you both.
That question was really a prelude to a slightly more involved example where I was seeing similar behaviour, this time without my faux pas with the function declaration.

using Zygote
using Statistics
using Optimisers

a1 = 1f0 .+ randn(Float32, (1024, 1024))
y1 = ones(Float32, (1024, 1024))

iter = 100
rule = Optimisers.Adam()
state = Optimisers.setup(rule, a1)

function loss(x, y)
    return mean(abs.((x.^2 .- y.^2).^2))
end

function train_instance(a1, y1, state)
    grads = gradient(loss, a1, y1)[1]      # gradient with respect to a1 only
    Optimisers.update!(state, a1, grads)   # returns (new_state, new_a1)
end

function train(a1, y1, iter, state)
    for i in 1:iter
        state, a1 = train_instance(a1, y1, state)
    end
    return a1
end

p = train(a1, y1, iter, state)

This is still a pared-down version of what I have, but it includes more of the things I'm trying to do. It is contrived and does not perform well, but I am mostly interested in understanding where the type instability comes from (are my global, non-constant a1 and y1 causing this? I found @descend quite confusing on my first try). This is my first time using Optimisers and Zygote, having come from JAX, so I'm sure I've missed some really basic things; please point them out!

When I look at @code_warntype for the train function I get the following readout:

MethodInstance for train(::Matrix{Float32}, ::Matrix{Float32}, ::Int64, ::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}})
  from train(a1, y1, iter, state) @ Main 
Arguments
  #self#::Core.Const(train)
  a1@_2::Matrix{Float32}
  y1::Matrix{Float32}
  iter::Int64
  state@_5::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}        
Locals
  @_6::Union{Nothing, Tuple{Int64, Int64}}
  @_7::Int64
  i::Int64
  a1@_9::Any
  state@_10::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}       
Body::Any
1 ─       (a1@_9 = a1@_2)
│         (state@_10 = state@_5)
│   %3  = (1:iter)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_6 = Base.iterate(%3))
│   %5  = (@_6 === nothing)::Bool
│   %6  = Base.not_int(%5)::Bool
└──       goto #4 if not %6
2 ┄ %8  = @_6::Tuple{Int64, Int64}
│         (i = Core.getfield(%8, 1))
│   %10 = Core.getfield(%8, 2)::Int64
│   %11 = Main.train_instance(a1@_9, y1, state@_10)::Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Any}
│   %12 = Base.indexed_iterate(%11, 1)::Core.PartialStruct(Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Int64}, Any[Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Core.Const(2)])
│         (state@_10 = Core.getfield(%12, 1))
│         (@_7 = Core.getfield(%12, 2))
│   %15 = Base.indexed_iterate(%11, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(3)])
│         (a1@_9 = Core.getfield(%15, 1))
│         (@_6 = Base.iterate(%3, %10))
│   %18 = (@_6 === nothing)::Bool
│   %19 = Base.not_int(%18)::Bool
└──       goto #4 if not %19
3 ─       goto #2
4 ┄       return a1@_9

To expand on that, function blocks (and the short-form f(x) = ... definition syntax) do a couple of things at global scope when the function doesn't already exist: (1) they declare the provided name const, and (2) the function's display includes that name. You can actually pull off (1) after an anonymous function exists by assigning it to a const global variable, which you can even use to add methods, but the function is still anonymous because (2) never happened.

julia> func2 = x -> x^2  # non-const (unstable) global, anonymous function
#1 (generic function with 1 method)

julia> const f2 = func2  # const (stable) global, anonymous function
#1 (generic function with 1 method)

julia> f2() = 2  # add method via const global, still anonymous
#1 (generic function with 2 methods)

julia> func3(x) = x^2  # const global, named function
func3 (generic function with 1 method)

In a local scope there is no const, so (1) doesn't happen; you could reassign the variable, but it's better to leave it alone and use a different name for other things. When you return locally defined functions, you usually see the # typical of anonymous functions, except that the name shows up if you used the named-function syntax. The display isn't entirely consistent, though.

julia> function gh(n)
         g() = n
         h = () -> n
         g, h
       end
gh (generic function with 1 method)

julia> g2, h2 = gh(1) # display in returned tuple
(var"#g#14"{Int64}(1), var"#13#15"{Int64}(1))

julia> g2  # slightly different display
(::var"#g#14"{Int64}) (generic function with 1 method)

julia> h2  # slightly different display
#13 (generic function with 1 method)

julia> g, h = gh(1)  # variable name affects display
(g, var"#13#15"{Int64}(1))

Analyzing this call with JET.@report_opt shows many run time dispatches, and Test.@inferred shows that the return value is not inferred. Seems like an Optimisers bug, assuming you’re using the package as documented (I’m not familiar with it).
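
For reference, this is roughly how those checks can be run (assuming the definitions from the snippet above are in scope):

using JET, Test

JET.@report_opt train(a1, y1, iter, state)   # reports call sites that fall back to run-time dispatch
Test.@inferred train(a1, y1, iter, state)    # errors if the inferred return type (here Any) does not match the actual result type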

Yeah, from looking at their code a bit, it seems like there's no intention to either

  1. have fast code by preventing run-time dispatch, or
  2. make the return type inferable by Julia
Your type instability seems to be caused by the return-type inference failure, which in turn traces back to the internal Leaf update method in Optimisers.jl, where the result is stored in, and read back out of, an untyped IdDict cache.

I'm pretty sure you could fix the type inference failure by modifying that method to something like the following:

ret = if haskey(grads, ℓ)
  ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
  subtract!(x, x̄′)
else
  x # no gradient seen
end
params[(ℓ,x)] = ret
ret

Furthermore, it might make sense to instantiate that IdDict with a proper type parameter, if that is possible and easy enough.
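
To illustrate the point with a self-contained toy example (not Optimisers code; cache_untyped, cache_typed, get_untyped and get_typed are made-up names): a value read back out of an untyped IdDict infers as Any, whereas returning the locally computed value, or giving the dict concrete type parameters, keeps the return type inferable.

const cache_untyped = IdDict()                # IdDict{Any,Any}
const cache_typed   = IdDict{Int,Float64}()   # concrete key and value types

# The ternary reads the cached value when present; with the untyped dict that branch is ::Any.
get_untyped(k::Int) = haskey(cache_untyped, k) ? cache_untyped[k] : (cache_untyped[k] = k / 2)
get_typed(k::Int)   = haskey(cache_typed, k)   ? cache_typed[k]   : (cache_typed[k] = k / 2)

Base.return_types(get_untyped, (Int,))   # Any[Any]
Base.return_types(get_typed, (Int,))     # Any[Float64]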

Thanks for digging into this and providing some suggestions. I will take a look and see how far I get!


Yeah, I suggest you open a PR or file a bug report with Optimisers.jl.
