Type stability with Flux gradient of loss function requiring parameters

Hello,

I’m using Flux to replicate an economics application of machine learning (originally done in Python/TensorFlow here).

In this problem, neural networks are used to solve a functional equation of the form g(f(x), P) = h(f(x), P), where f(x) is an unknown function to be approximated by a neural network and P is a set of parameters that may need to be adjusted periodically.

The loss function used to train the neural network computes the residuals to the functional equation and requires simulation of f(x). In short, it’s a complicated function that requires several parameters.

A type stability problem arises when I compute the gradient using code like:

```julia
grads = Flux.gradient(m -> loss(m, P), model)
```

where `P` is a struct containing various parameters.

This arises from the fact that `P` is a non-constant global captured in the closure `m -> loss(m, P)`.

One remedy is to declare `P` as a `const`. This works okay. However, I typically use Parameters.jl so that I can use the `@unpack` macro. When I do that, the type instability returns even if I declare `P` as a `const`. Below is a minimal example to illustrate the problem:

```julia
using Flux
using Parameters

@with_kw struct TestParas
    expon::Float64 = 2.0
end

const P = TestParas()

mod1 = Dense(2 => 1)

function Loss1(model, P)
    data = ones(2)
    resid = sum(model(data) .^ P.expon)
    return resid
end

function Loss2(model, P)
    @unpack expon = P
    data = ones(2)
    resid = sum(model(data) .^ expon)
    return resid
end
```

Calling `@code_warntype Flux.gradient(m -> Loss1(m, P), mod1)` produces no type instabilities.

But calling `@code_warntype Flux.gradient(m -> Loss2(m, P), mod1)` does. Here’s the output for both calls (`Loss1` first, then `Loss2`):

```
MethodInstance for Zygote.gradient(::var"#8#9", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
f::Core.Const(var"#8#9"())
args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
@_4::Int64
grad::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
back::Zygote.var"#60#61"{typeof(∂(#8))}
y::Float64
Body::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#8#9"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Core.PartialStruct(Tuple{Float64, Zygote.var"#60#61"{typeof(∂(#8))}}, Any[Float64, Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])])])
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61"{typeof(∂(#8))}, Int64}, Any[Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])]), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Core.Const(1.0)
│         (grad = (back::Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])]))(%8))
└──       goto #3 if not %10
2 ─       Core.Const(:(return Zygote.nothing))
3 ┄ %13 = Zygote.map(Zygote._project, args, grad)::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
└──       return %13

MethodInstance for Zygote.gradient(::var"#10#11", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
f::Core.Const(var"#10#11"())
args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
@_4::Int64
back::Zygote.var"#60#61"
y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#10#11"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#60#61"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61", Int64}, Any[Zygote.var"#60#61", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13
```

Is there another efficient syntax for obtaining the gradient with respect to a particular argument of a function with multiple arguments?

Open to other suggestions as well. Just trying to avoid having to use the `P.` syntax everywhere.
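(To clarify what I mean about multiple arguments: `gradient` can take several arguments and returns one gradient entry per argument, so something like the following sketch, where `loss` is just a stand-in, is roughly what I have in mind. I’m unsure whether discarding the unwanted entries is actually efficient.)

```julia
using Zygote  # Flux's AD; Flux.gradient forwards to this

loss(a, b) = sum(a .* b)

# One gradient entry per argument, in argument order
ga, gb = gradient(loss, [1.0, 2.0], [3.0, 4.0])
# ga == [3.0, 4.0] (∂loss/∂a), gb == [1.0, 2.0] (∂loss/∂b)
```

So `gradient(loss, a, b)[1]` would give the gradient with respect to the first argument alone, at the cost of Zygote also differentiating through `b`.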

1. Consistent float precision throughout the model. Although types can be inferred, it is usually desirable to keep the same float precision throughout the model. In this case, both the data `ones(2)` and the parameters `P` are `Float64`, while the model itself is `Float32` (the default for Flux layers).
Hence, it’s likely desirable to instead define the data as `ones(Float32, 2)`.
For the parameters, if there are many and a common precision is desired, a possible approach could be to enforce that common precision through struct parametrization:
```julia
@with_kw struct TestParas2{T<:AbstractFloat}
    expon::T = 2f0
    ...
end
```

Although it may be just for the sake of the example, generating the data inside the loss function definition wouldn’t typically be advised.

For this example, the following may be a preferable pattern:

```julia
function loss1(m, x, P)
    resid = sum(m(x) .^ P.expon)
    return resid
end
```

Otherwise, I think some type instability may have come from evaluating the gradients directly in the REPL, which would explain the need to specify a `const`. If the gradient evaluation is performed within a function, as will likely happen under real running conditions, then there’s no need to declare `P` as a `const`.
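The same point can be seen without Flux: a function barrier gives the parameters a concrete type inside the call, so the `const` is no longer load-bearing. A toy sketch (using a NamedTuple in place of the real parameter struct):

```julia
# A non-const global: code that reads it at top level sees P_glob::Any
P_glob = (expon = 2.0,)

unstable() = sum([1.0, 2.0] .^ P_glob.expon)   # type-unstable: global lookup

# Passing P as an argument makes it a well-typed local inside the function
stable(P) = sum([1.0, 2.0] .^ P.expon)

stable(P_glob)   # inference inside stable no longer depends on the global
```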

Otherwise, although I understand that getting rid of the `P.` prefix was the original intention, I think it can, on the other hand, provide some clarity regarding the origin of these parameters (i.e. `expon`).

A potential alternative could be to specify all kwargs in the loss function itself, which could then be passed as a simple `Dict`, i.e.:

```julia
function loss2(m, x; expon)
    resid = sum(m(x) .^ expon)
    return resid
end

P = Dict(:expon => 2f0)
loss2(m1, x1; P...)
@code_warntype loss2(m1, x1; P...)
```

> Although it may be just for the sake of the example, it wouldn’t typically be advised to have the data generated within the loss function definition.

The example I gave doesn’t make this clear, but the training process for the neural network involves taking random draws from the domain of the function to be approximated (similar to stochastic gradient descent). These are used to evaluate the left-hand side of the functional equation, which is really a nonlinear stochastic difference equation in which the neural network itself is a "guess" for the unknown function. This is what requires the data to be generated within the loss function: the neural network is needed to generate the data over which the loss is evaluated. It would have been more accurate to write the functional equation as something like:

g(f(x), P) = h[f(j(f(x))), P], where j() is a fairly simple function that maps the time-t values of the unknown function, f(x_t), into the time-t+1 values of the function arguments, x_{t+1}.
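Purely to illustrate the structure (all of `g`, `h`, and `j` below are made-up toy stand-ins, not the functions from the real application), the residual of such a difference equation might look like:

```julia
# Toy stand-ins (hypothetical; chosen only so a known f gives zero residual)
g(y, P) = y^P              # left-hand side
h(y, P) = y^P              # right-hand side
j(y) = y / 2               # maps f(x_t) into next-period arguments x_{t+1}

# Residual of g(f(x), P) = h(f(j(f(x))), P) for a candidate function f
function eq_residual(f, x, P)
    y     = f(x)           # time-t value of the approximated function
    ynext = f(j(y))        # time-t+1 value, via the transition j
    return g(y, P) - h(ynext, P)
end

f_guess(x) = 2x
eq_residual(f_guess, 1.0, 2.0)   # f_guess solves this toy equation: residual 0.0
```

In the real problem the squared residuals over random draws of x would form the loss, with `f` played by the neural network.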

> Consistent float precision throughout the model. Although types can be inferred, it is usually desirable to keep the same float precision throughout the model.

Noted. I’ll have to look into how to force Float64 for the model and also test whether I can get away with Float32. I’ve had previous experience with other solution methods where accuracy suffered when using Float32, but I can’t know for certain without testing.

> I think some type instability may have come from evaluating the gradients directly in the REPL, which would explain the need to specify a `const`. If the gradient evaluation is performed within a function, as will likely happen under real running conditions, then there’s no need to specify `P` as a `const`.

My investigations into these type instabilities began from a quest to diagnose what appeared to be relatively poor performance of my code compared to the Python/TensorFlow code I was trying to replicate. As you suggest, in the full application the call to `gradient` was within another function, and I was originally using `@unpack`. However, I still found I could obtain a roughly 15%-20% speedup by declaring all my parameters as `const`s (each one as a separate global constant).

I experimented with this a bit by enclosing the gradient call in the code above inside a function and testing the two alternate approaches (also set everything to Float32):

```julia
@with_kw struct TestParas
    expon::Float32 = 2.0
end

const P = TestParas()   # declared as constant

P2 = TestParas()        # declared normally

function Loss1(model, P)
    data = ones(Float32, 2)
    resid = sum(model(data) .^ P.expon)
    return resid
end

function Loss2(model, P)
    @unpack expon = P
    data = ones(Float32, 2)
    resid = sum(model(data) .^ expon)
    return resid
end

function testgrad1(k, P)
    for i = 1:k
        Flux.gradient(m -> Loss1(m, P), mod1)
    end
end

function testgrad2(k, P)
    for i = 1:k
        Flux.gradient(m -> Loss2(m, P), mod1)
    end
end
```
```julia
@btime testgrad1(1000, P)
  1.956 ms (38000 allocations: 2.93 MiB)

@btime testgrad2(1000, P)
  6.957 ms (72000 allocations: 5.14 MiB)

@btime testgrad1(1000, P2)
  1.942 ms (38000 allocations: 2.93 MiB)

@btime testgrad2(1000, P2)
  6.970 ms (72000 allocations: 5.14 MiB)
```

These results show (as you suggested) that enclosing the call in a function alleviates the need to declare the struct as a `const`. But they also seem to suggest that a type instability persists when `@unpack` is used. I don’t know exactly how to diagnose it, though. My understanding is that `@code_warntype` only detects type inference issues within the top-level function it is called on (i.e., `testgrad1`/`testgrad2`) and not in any internal function calls (i.e., the call to `Flux.gradient` within them).

I’ll try out the Dict approach and provide some timings for that tomorrow.

Thanks for these very helpful pointers!

Yes and no. If the internal function can hide the type instability (e.g. by having a well-typed return value), then `@code_warntype` won’t capture it in the outer function. However, this is usually not the case, and instabilities in internal functions bubble up to the outer function as well.
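A tiny sketch of both cases (made-up functions, just to illustrate the point):

```julia
# Unstable inner function: inferred return type is Union{Int64, Float64}
inner(flag) = flag ? 1 : 1.0

# The instability bubbles up: @code_warntype outer(true) flags the Union
outer(flag) = inner(flag) + 1

# A well-typed return value hides the inner instability from the caller:
# @code_warntype hidden(true) looks clean even though inner is unstable
hidden(flag) = Float64(inner(flag)) + 1
```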

Back to the subject of this thread: what you’re seeing is a bad interaction between how Julia lowers some references to properties (in this case, functions) of modules and how Zygote (Flux’s default AD, from which it re-exports functions like `gradient`) differentiates code. `@unpack` expands to a call to `UnPack.unpack` plus some (for our purposes unimportant) supporting code. Ordinarily this wouldn’t be a problem and Zygote would just differentiate the `unpack` call. However, quirks of how this particular macro and lowering in general work in Julia mean that Zygote actually sees the following:

``````unpack = getproperty(Unpack, :unpack)
unpack(...)
``````

``````Unpack.unpack(...)
Zygote is not smart enough to see that the `getproperty` call is really a function lookup that can be resolved statically, so it treats it as a generic property access on an arbitrary value. This leads to the generation of type unstable code to handle the `getproperty` call, where really we’d like to just inline it instead.