Hello,
I would like to train a neural network $s_\theta : \mathbb{R}^{n_x} \times \mathbb{R}^{n_y} \to \mathbb{R}^{n_x}$ with parameters $\theta$, defined on the space of joint variables $(x, y) \in \mathbb{R}^{n_x} \times \mathbb{R}^{n_y}$.
The loss function involves the divergence of $s_\theta(x, y)$ with respect to $x$, i.e. $\nabla_x \cdot s_\theta(x, y) = \sum_{k=1}^{n_x} \partial s_{\theta,k}(x, y) / \partial x_k$, where $s_{\theta,k}$ denotes the $k$-th component of $s_\theta$.
Note that the divergence is also the trace of the Jacobian with respect to $x$: $\nabla_x \cdot s_\theta(x, y) = \operatorname{trace}(\nabla_x s_\theta(x, y))$.
Current issues:
- How should a function of several variables ($x$ and $y$) be handled in Flux? For now, I am concatenating these variables, but there might be a better solution.
- I am struggling with try/catch errors when computing the gradient of the loss function with Zygote. Is it possible to compute only some of the partial derivatives with `Zygote.jacobian`?
I include an MWE:
using Flux
using LinearAlgebra
using Zygote
# Problem dimensions: x has Nx entries, y has Ny entries
Nx = 5
Ny = 10
# Synthetic dataset: M samples of x and y, one sample per column
M = 100
X = randn(Nx, M)
Y = randn(Ny, M)
# Stack x on top of y for every sample and store the result in single precision
XY = Matrix{Float32}(vcat(X, Y));
# Network taking the stacked input [x; y] (Nx + Ny entries) and returning Nx outputs
model = Chain(
    Dense(Nx + Ny => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => Nx),
)
# Trainable parameters of the model, later handed to Flux.withgradient
params = Flux.params(model)
"""
    loss_sample(f, z)

Return the per-sample loss for one augmented sample `z = [x; y]`.

The first `Nx` entries of `z` are taken as `x` and the remaining entries as `y`
(`Nx` is read from global scope). The result is the trace of the Jacobian of
`xprime -> f([xprime; y])` evaluated at `x`, i.e. the divergence of `f` with
respect to the `x` block only, with `y` held fixed.

NOTE(review): the Jacobian is computed with `Zygote.jacobian`, so taking a
gradient of this function with Zygote means differentiating through Zygote
itself. This looks like the origin of the reported "try/catch is not supported"
error — confirm.
"""
function loss_sample(f, z)
    state = z[1:Nx]
    observation = z[Nx+1:end]
    # Only `xprime` is a differentiation variable; `observation` is captured as a constant.
    network_in_x(xprime) = f(vcat(xprime, observation))
    # `Zygote.jacobian` returns one Jacobian per argument; there is a single argument here.
    jacobians = Zygote.jacobian(network_in_x, state)
    return tr(jacobians[1])
end
"""
    model_loss(f, Z)

Return the dataset loss: the sum of `loss_sample(f, z)` over every column `z` of `Z`.
"""
function model_loss(f, Z)
    # Fix the model so that the summand depends on the sample only.
    per_sample(z) = loss_sample(f, z)
    return sum(per_sample, eachcol(Z))
end
# Evaluate the dataset loss and its gradient with respect to the collected model parameters
loss, grad = Flux.withgradient(() -> model_loss(model, XY), params)
Error message:
Compiling Tuple{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
[2] _pullback(::Zygote.Context{true}, ::ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101
[3] _pullback
@ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
[4] _pullback
@ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:237 [inlined]
[5] _pullback(ctx::Zygote.Context{true}, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[6] _pullback
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
[7] (::Zygote.var"#661#665"{Zygote.Context{true}, typeof(Zygote.wrap_chainrules_output)})(args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:183
[8] map
@ ./tuple.jl:275 [inlined]
[9] ∇map(cx::Zygote.Context{true}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:183
[10] adjoint
@ ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:209 [inlined]
[11] _pullback
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
[12] _pullback
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
[13] _pullback(ctx::Zygote.Context{true}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[14] _pullback
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
[15] _pullback(ctx::Zygote.Context{true}, f::Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[16] _pullback
@ ./In[13]:7 [inlined]
**Truncated part**
[50] collect(itr::Base.Generator{Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}, ChainRules.var"#1659#1664"{Zygote.ZygoteRuleConfig{Zygote.Context{true}}, var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
@ Base ./array.jl:782
[51] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(sum), f::var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, xs::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/mapreduce.jl:102
[52] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(sum), f::var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, xs::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
@ ChainRules ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/mapreduce.jl:76
[53] chain_rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::Function, ::Function, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:223
[54] macro expansion
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
[55] _pullback
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
[56] _pullback
@ ./In[14]:1 [inlined]
[57] _pullback(::Zygote.Context{true}, ::typeof(model_loss), ::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[58] _pullback
@ ./In[15]:2 [inlined]
[59] _pullback(::Zygote.Context{true}, ::var"#17#18")
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[60] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:414
[61] withgradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:154