Hello,
I’m using Flux to replicate an economics application of machine learning (originally done in Python/Tensor flow here).
In this problem, neural networks are used to solve a functional equation of the form g(f(x),p) = h(f(x),P), where f(x) is an unknown function to be approximated by a neural network and P are some parameters that may need to be adjusted periodically.
The loss function used to train the neural network computes the residuals to the functional equation and requires simulation of f(x). In short, it’s a complicated function that requires several parameters.
A type stability problem arises when I compute the gradient using code like:
grads = Flux.gradient(m -> loss(m,P))
where P
is a struct containing various parameters.
This arises from the fact that P
is a non-constant global in the closure m -> loss (m,P)
.
One remedy is to declare P as a const. This works okay. However, I typically use Parameters.jl so that I can use the @unpack macro. When I do that, the type instability returns even if I declare P as a const. Below is a minimal example to illustrate the problem:
using Flux
using Parameters
@with_kw struct TestParas
expon::Float64 = 2.0
end
const P = TestParas()
mod1 = Dense(2 => 1)
function Loss1(model, P)
data = ones(2)
resid = sum(model(data).^P.expon)
return resid
end
function Loss2(model, P)
@unpack expon = P
data = ones(2)
resid = sum(model(data).^expon)
return resid
end
Calling @code_warntype Flux.gradient(m -> Loss1(m, P), mod1)
produces no type instabilities.
But calling @code_warntype Flux.gradient(m -> Loss2(m, P), mod1)
does. Here’s the output:
MethodInstance for Zygote.gradient(::var"#8#9", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
#self#::Core.Const(Zygote.gradient)
f::Core.Const(var"#8#9"())
args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
@_4::Int64
grad::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
back::Zygote.var"#60#61"{typeof(∂(#8))}
y::Float64
Body::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
1 ─ %1 = Core.tuple(f)::Core.Const((var"#8#9"(),))
│ %2 = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Core.PartialStruct(Tuple{Float64, Zygote.var"#60#61"{typeof(∂(#8))}}, Any[Float64, Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])])])
│ %3 = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│ (y = Core.getfield(%3, 1))
│ (@_4 = Core.getfield(%3, 2))
│ %6 = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61"{typeof(∂(#8))}, Int64}, Any[Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])]), Core.Const(3)])
│ (back = Core.getfield(%6, 1))
│ %8 = Zygote.sensitivity(y)::Core.Const(1.0)
│ (grad = (back::Core.PartialStruct(Zygote.var"#60#61"{typeof(∂(#8))}, Any[Core.PartialStruct(typeof(∂(#8)), Any[Core.PartialStruct(Tuple{typeof(∂(Loss1)), Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}}, Any[Core.PartialStruct(typeof(∂(Loss1)), Any[Core.PartialStruct(Tuple{typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}}, Any[typeof(∂(materialize)), Zygote.var"#2791#back#542"{Zygote.var"#538#540"{Vector{Float64}}}, typeof(∂(λ)), Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:expon, Zygote.Context{false}, TestParas, Float64}}, typeof(∂(broadcasted)), Core.Const(Zygote.ZBack{ChainRules.var"#ones_pullback#797"{Tuple{Int64}}}(ChainRules.var"#ones_pullback#797"{Tuple{Int64}}((2,))))])]), Core.PartialStruct(Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}}, Any[Core.PartialStruct(Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, TestParas}, Any[Zygote.Context{false}, Core.Const(:(Main.P)), Core.Const(TestParas
expon: Float64 2.0
)])])])])]))(%8))
│ %10 = Zygote.isnothing(grad)::Core.Const(false)
└── goto #3 if not %10
2 ─ Core.Const(:(return Zygote.nothing))
3 ┄ %13 = Zygote.map(Zygote._project, args, grad)::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}
└── return %13
MethodInstance for Zygote.gradient(::var"#10#11", ::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
from gradient(f, args...) in Zygote at C:\Users\Patrick\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:95
Arguments
#self#::Core.Const(Zygote.gradient)
f::Core.Const(var"#10#11"())
args::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
Locals
@_4::Int64
grad::Union{Nothing, Tuple}
back::Zygote.var"#60#61"
y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1 = Core.tuple(f)::Core.Const((var"#10#11"(),))
│ %2 = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#60#61"}
│ %3 = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│ (y = Core.getfield(%3, 1))
│ (@_4 = Core.getfield(%3, 2))
│ %6 = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#60#61", Int64}, Any[Zygote.var"#60#61", Core.Const(3)])
│ (back = Core.getfield(%6, 1))
│ %8 = Zygote.sensitivity(y)::Any
│ (grad = (back)(%8))
│ %10 = Zygote.isnothing(grad)::Bool
└── goto #3 if not %10
2 ─ return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└── return %13
Is there another efficient syntax for obtaining the gradient with respect to a particular argument of a function with multiple arguments?
Open to other suggestions as well. Just trying to avoid having to use the P.
syntax everywhere.
Thanks in advance!