Type stability with Flux gradient of loss function requiring parameters

Hello,

I’m using Flux to replicate an economics application of machine learning (originally done in Python/TensorFlow here).

In this problem, neural networks are used to solve a functional equation of the form g(f(x), P) = h(f(x), P), where f(x) is an unknown function to be approximated by a neural network and P is a set of parameters that may need to be adjusted periodically.

The loss function used to train the neural network computes the residuals of the functional equation and requires simulating f(x). In short, it’s a complicated function that requires several parameters.

A type stability problem arises when I compute the gradient using code like:

grads = Flux.gradient(m -> loss(m, P), model)

where P is a struct containing various parameters.

This arises from the fact that P is a non-constant global captured by the closure m -> loss(m, P).

One remedy is to declare P as a const. This works okay. However, I typically use Parameters.jl so that I can use the @unpack macro. When I do that, the type instability returns even if I declare P as a const. Below is a minimal example to illustrate the problem:

using Flux
using Parameters

@with_kw struct TestParas
    expon::Float64 = 2.0
end

const P = TestParas()

mod1 = Dense(2 => 1)

function Loss1(model, P)
    data = ones(2)
    resid = sum(model(data).^P.expon)
    return resid
end

function Loss2(model, P)
    @unpack expon = P
    data = ones(2)
    resid = sum(model(data).^expon)
    return resid
end

Calling @code_warntype Flux.gradient(m -> Loss1(m, P), mod1) produces no type instabilities.

But calling @code_warntype Flux.gradient(m -> Loss2(m, P), mod1) does. Here’s the output of both calls for comparison (Loss1 first, then Loss2):

MethodInstance for Zygote.gradient(::var"#8#9", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
  from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#8#9"())
  args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
  @_4::Int64
  grad::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
  back::Zygote.var"#60#61"{typeof(∂(#8))}
  y::Float64
Body::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#8#9"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Core.PartialStruct(Tuple{Float64, Zygote.var"#60#61"{typeof(∂(#8))}}, Any[Float64, Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
  expon: Float64 2.0
)])])])])])])
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61"{typeof(∂(#8))}, Int64}, Any[Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
  expon: Float64 2.0
)])])])])]), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Core.Const(1.0)
│         (grad = (back::Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
  expon: Float64 2.0
)])])])])]))(%8))
│   %10 = Zygote.isnothing(grad)::Core.Const(false)
└──       goto #3 if not %10
2 ─       Core.Const(:(return Zygote.nothing))
3 ┄ %13 = Zygote.map(Zygote._project, args, grad)::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
└──       return %13

MethodInstance for Zygote.gradient(::var"#10#11", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
  from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#10#11"())
  args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#60#61"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#10#11"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#60#61"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61", Int64}, Any[Zygote.var"#60#61", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13

Is there another efficient syntax for obtaining the gradient with respect to a particular argument of a function with multiple arguments?

Open to other suggestions as well. Just trying to avoid having to use the P. syntax everywhere.

Thanks in advance!

It may not directly answer your question, but here are a few observations about the approach:

  1. Consistent float precision throughout the model. Although types can be inferred, it is usually desirable to keep the same float precision throughout the model. In this case, both the data ones(2) and the parameters P are Float64, while the model itself is Float32 (the default for Flux layers).
    Hence, it’s likely desirable to instead have the data defined as ones(Float32, 2).
    For the parameters, if there are many and a common precision is desired, a possible approach could be to enforce that common precision through struct parametrization:
@with_kw struct TestParas2{T<:AbstractFloat}
    expon::T = 2f0
    ...
end
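For instance, a small usage sketch of the parametrized struct above (my addition; it assumes expon is the only field, for illustration):

P32 = TestParas2()             # T inferred as Float32 from the 2f0 default
P64 = TestParas2(expon = 2.0)  # should infer T == Float64 when Float64 values are passed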

  2. Although it may be just for the sake of the example, it wouldn’t typically be advised to have the data generated within the loss function definition.

For this example, the following may be a preferable pattern:

function loss1(m, x, P)
    resid = sum(m(x) .^ P.expon)
    return resid
end
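For instance, a small usage sketch (my addition, reusing mod1 and P from the example above; x is simply the data built once outside the loss):

x = ones(Float32, 2)
grads = Flux.gradient(m -> loss1(m, x, P), mod1)   # grads[1] holds the gradient w.r.t. the model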

Otherwise, I think some type instability may have come from evaluating the gradients directly in the REPL, which would explain the need to specify a const. If the gradient evaluation is performed within a function, as will likely happen under real running conditions, then there’s no need to specify P as a const.

That said, although I understand that getting rid of the P. prefix was the original intention, I think it does, on the other hand, provide some clarity regarding the origin of these parameters (i.e., expon).

A potential alternative could be to specify all parameters as keyword arguments in the loss function itself, which could then be passed as a simple Dict, i.e.:

function loss2(m, x; expon)
    resid = sum(m(x) .^ expon)
    return resid
end
P = Dict(:expon => 2f0)
loss2(m1, x1; P...)
@code_warntype loss2(m1, x1; P...)

> Although it may be just for the sake of the example, it wouldn’t typically be advised to have the data generated within the loss function definition.

The example I gave doesn’t make this clear, but the training process for the neural network involves taking random draws from the domain of the function to be approximated (similar to stochastic gradient descent). These draws are used to evaluate the left-hand side of the functional equation, which is a nonlinear stochastic difference equation, using the current “guess” for the neural network parameters. This is what requires the data to be generated within the loss function: the neural network itself is needed to generate the data over which the loss is evaluated. It would have been more accurate to write the functional equation as something like:

g(f(x), P) = h(f(j(f(x))), P), where j(·) is a fairly simple function that maps the time-t values of the unknown function, f(x_t), into the time-(t+1) values of the function arguments, x_{t+1}.
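Written out with explicit time subscripts (my reading of the description above; the placement of P is an assumption):

$$ g\bigl(f(x_t), P\bigr) = h\bigl(f(x_{t+1}), P\bigr), \qquad x_{t+1} = j\bigl(f(x_t)\bigr). $$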

> Consistent float precision throughout the model. Although types can be inferred, it is usually desirable to keep the same float precision throughout the model.

Noted. I’ll have to look into how to force Float64 for the model and also test whether I can get away with Float32. I’ve had previous experience with other solution methods where accuracy suffered when using Float32, but I can’t know for certain without testing.
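For reference, a minimal sketch of switching the model’s precision (my addition, assuming Flux’s exported f64/f32 conversion utilities):

mod64 = f64(mod1)                        # copy of mod1 with Float64 weights and biases
Flux.gradient(m -> Loss1(m, P), mod64)   # the data built inside the loss should then also be Float64
mod32 = f32(mod64)                       # and back to Float32 to test whether that precision suffices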

> I think some type instability may have come from evaluating the gradients directly in the REPL, which would explain the need to specify a const. If the gradient evaluation is performed within a function, as will likely happen under real running conditions, then there’s no need to specify P as a const.

My investigations into these type instabilities began from a quest to diagnose what appeared to be relatively poor performance of my code compared to the Python/TensorFlow code I was trying to replicate. As you suggest, in the full application the call to gradient was within another function and I was originally using @unpack. However, I still found I could obtain a roughly 15%–20% speedup by declaring all my parameters as consts (each one as a separate global constant).

I experimented with this a bit by enclosing the gradient call in the code above inside a function and testing the two alternative approaches (also setting everything to Float32):

@with_kw struct TestParas
    expon::Float32 = 2.0
end

const P = TestParas()  #declared as constant

P2 = TestParas()  #declared normally

function Loss1(model, P)
    data = ones(Float32, 2)
    resid = sum(model(data).^P.expon)
    return resid
end

function Loss2(model, P)
    @unpack expon = P
    data = ones(Float32, 2)
    resid = sum(model(data).^expon)
    return resid
end

function testgrad1(k, P)
    for i= 1:k
        Flux.gradient(m -> Loss1(m,P), mod1)
    end
end

function testgrad2(k, P)
    for i= 1:k
        Flux.gradient(m -> Loss2(m,P), mod1)
    end
end
@btime testgrad1(1000, P)
1.956 ms (38000 allocations: 2.93 MiB)
@btime testgrad2(1000, P)
6.957 ms (72000 allocations: 5.14 MiB)
@btime testgrad1(1000, P2)
1.942 ms (38000 allocations: 2.93 MiB)
@btime testgrad2(1000, P2)
6.970 ms (72000 allocations: 5.14 MiB)

These results show (as you suggested) that enclosing the call in a function alleviates the need to declare the struct as a const. But they also seem to suggest that a type instability persists when @unpack is used. I don’t know exactly how to diagnose it, though. My understanding is that @code_warntype only detects type inference issues within the top function it is called on (i.e., testgrad1 or testgrad2) and not within any internal function calls (i.e., the call to Flux.gradient inside them).
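As an aside (my addition, not something from the thread): one way to check the inner call directly is Test.@inferred, which throws an error when a call’s return type is not concretely inferred:

using Test
@inferred Flux.gradient(m -> Loss1(m, P), mod1)   # passes: the return type is concrete
@inferred Flux.gradient(m -> Loss2(m, P), mod1)   # errors: inferred as Union{Nothing, Tuple{Any}}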

I’ll try out the Dict approach and provide some timings for that tomorrow.

Thanks for these very helpful pointers!

Yes and no. If the internal function can hide the type instability (e.g. by having a well-typed return value), then @code_warntype won’t capture it in the outer function. However, this is usually not the case, and instabilities in internal functions bubble up to the outer function as well.

Back to the subject of this thread, what you’re seeing is a bad interaction between how Julia lowers some references to properties (in this case functions) on modules and how Zygote (Flux’s default AD, from which it re-exports functions like gradient) differentiates code. @unpack expands to a call to UnPack.unpack with some (for our purposes unimportant) supporting code. Ordinarily this wouldn’t be a problem and Zygote would just differentiate the unpack call. However, quirks of how this particular macro and lowering in general work in Julia mean that Zygote actually sees the following:

unpack = getproperty(UnPack, :unpack)
unpack(...)

Instead of:

UnPack.unpack(...)

Zygote is not smart enough to see that the getproperty call is really a function lookup that can be resolved statically, so it treats it as a generic property access on an arbitrary value. This leads to the generation of type unstable code to handle the getproperty call, where really we’d like to just inline it instead.

What can be done about this? At least for your example, we can try to detect this type of module field lookup and tell Zygote not to differentiate it. That should eliminate the aforementioned type instability. The big question is whether doing so would break anything else, so I’ve opened Don't differentiate getproperty on const module fields by ToucheSir · Pull Request #1371 · FluxML/Zygote.jl · GitHub to check.
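As an aside (my addition, not part of this reply): a user-side sketch that keeps the unpacked-name convenience without going through the macro is Julia 1.7+ property destructuring, which lowers to an ordinary getproperty on P, just like the stable Loss1:

function Loss3(model, P)
    (; expon) = P                 # equivalent to expon = P.expon; no module-level function lookup
    data = ones(Float32, 2)
    return sum(model(data) .^ expon)
end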


Thanks for doing that!

-Patrick