Divergence loss term Partial Derivative Flux Zygote

Hello,

I would like to train a neural network $s_\theta : \mathbb{R}^{n_x} \times \mathbb{R}^{n_y} \to \mathbb{R}^{n_x}$, parameterized by $\theta$ and defined on the space of joint variables $(x, y) \in \mathbb{R}^{n_x} \times \mathbb{R}^{n_y}$.

The loss function involves the divergence of $s_\theta(x,y)$ with respect to $x$, i.e. $\nabla_x \cdot s_\theta(x,y) = \sum_{k=1}^{n_x} \partial s_{\theta,k}(x,y) / \partial x_k$, where $s_{\theta,k}$ denotes the $k$-th component of $s_\theta$.

Note that $\nabla_x \cdot s_\theta(x,y)$ is also the trace of the Jacobian of $s_\theta$ with respect to $x$: $\nabla_x \cdot s_\theta(x,y) = \operatorname{trace}(\nabla_x s_\theta(x,y))$.

Current issues:

  • How should I handle a function of several variables (x and y) in Flux? For now, I am concatenating these variables, but there might be a better solution.

  • I run into a "try/catch is not supported" error when computing the gradient of the loss function with Zygote. Is it possible to compute only some of the partial derivatives with Zygote.jacobian?

I include a MWE:

using Flux
using LinearAlgebra
using Zygote

# Dimensions of the two variable groups: x has Nx entries, y has Ny entries
Nx = 5
Ny = 10

# Synthetic dataset: M standard-normal samples, stored one sample per column
M = 100
X = randn(Nx, M)
Y = randn(Ny, M)

# Stack x on top of y for every sample and store the result as Float32
XY = Float32.(vcat(X, Y));

# Create model to train
# Fully connected network: input z = [x; y] of length Nx + Ny, output of length Nx
model = Chain(
    Dense((Nx + Ny) => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => Nx, identity),
)

# Implicit-style collection of the model's trainable arrays
params = Flux.params(model)

# We use the variable z to denote the augmented variable [x; y]

# Loss function for one sample z = [x; y]
"""
    loss_sample(f, z)

Return the divergence of `f` with respect to `x` at a single sample `z = [x; y]`,
computed as the trace of the Jacobian of `xprime -> f([xprime; y])` evaluated at `x`.

The first `Nx` entries of `z` are taken as `x` and the remaining ones as `y`;
`Nx` is read from the enclosing (global) scope.
"""
function loss_sample(f, z)
    x = z[1:Nx]
    y = z[(Nx + 1):end]

    # Only the x-part is varied; y is captured by the closure and held fixed.
    f_of_x(xprime) = f(vcat(xprime, y))

    # `Zygote.jacobian` returns one Jacobian per argument; there is a single argument here.
    jac_x = first(Zygote.jacobian(f_of_x, x))

    return tr(jac_x)
end

# Loss function for the dataset
"""
    model_loss(f, Z)

Return the sum of `loss_sample(f, z)` over the columns `z` of `Z`.
"""
function model_loss(f, Z)
    return sum(eachcol(Z)) do z
        loss_sample(f, z)
    end
end

# Evaluate the dataset loss and its gradient with respect to the collected parameters
loss, grad = Flux.withgradient(() -> model_loss(model, XY), params)

Error message:

Compiling Tuple{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations


Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
  [2] _pullback(::Zygote.Context{true}, ::ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101
  [3] _pullback
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
  [4] _pullback
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:237 [inlined]
  [5] _pullback(ctx::Zygote.Context{true}, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [6] _pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
  [7] (::Zygote.var"#661#665"{Zygote.Context{true}, typeof(Zygote.wrap_chainrules_output)})(args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:183
  [8] map
    @ ./tuple.jl:275 [inlined]
  [9] ∇map(cx::Zygote.Context{true}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:183
 [10] adjoint
    @ ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:209 [inlined]
 [11] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [12] _pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
 [13] _pullback(ctx::Zygote.Context{true}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1415#1420"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, Vector{Float32}}}, ChainRules.var"#1414#1419"{Tuple{UnitRange{Int64}}, Vector{Float32}}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
 [15] _pullback(ctx::Zygote.Context{true}, f::Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [16] _pullback
    @ ./In[13]:7 [inlined]

**Truncated part**


 [50] collect(itr::Base.Generator{Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}, ChainRules.var"#1659#1664"{Zygote.ZygoteRuleConfig{Zygote.Context{true}}, var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
    @ Base ./array.jl:782
 [51] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(sum), f::var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, xs::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/mapreduce.jl:102
 [52] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(sum), f::var"#15#16"{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, xs::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
    @ ChainRules ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/mapreduce.jl:76
 [53] chain_rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:223
 [54] macro expansion
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
 [55] _pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
 [56] _pullback
    @ ./In[14]:1 [inlined]
 [57] _pullback(::Zygote.Context{true}, ::typeof(model_loss), ::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [58] _pullback
    @ ./In[15]:2 [inlined]
 [59] _pullback(::Zygote.Context{true}, ::var"#17#18")
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [60] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:414
 [61] withgradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:154