It’s easier to see the code:
using Lux
using Zygote
using ComponentArrays
using Random
x_test_flat = rand(16)
x_test_p = [rand(4) for _ in 1:4]
NNmodel_1 = Lux.Chain(Lux.Dense(16 => 18, sigmoid),
Lux.Dense(18 => 6, sigmoid),
Lux.Dense(6 => 1, sigmoid),
x -> x.*0.5)
NNmodel_2 = Lux.Chain(Lux.Dense(4 => 5, sigmoid),
Lux.Dense(5 => 3, sigmoid),
Lux.Dense(3 => 1, sigmoid),
x -> x.*0.2)
rng = Random.default_rng()
Random.seed!(rng, 0)
# Initialize Model
ps_NN_1, st_1 = Lux.setup(rng, NNmodel_1)
ps_NN_2, st_2 = Lux.setup(rng, NNmodel_2)
# Parameters must be a ComponentArray or an Array,
# Zygote Jacobian won't loop through NamedTuple
ps_NN_1 = ps_NN_1 |> ComponentArray
ps_NN_2 = ps_NN_2 |> ComponentArray
st = (st_1, st_2)
ps_NN = ComponentArray(ps_NN_1=ps_NN_1, ps_NN_2=ps_NN_2)
function NNmodel_total_flat(x, ps, st; NNmodel_1=NNmodel_1, NNmodel_2=NNmodel_2)
ps_NN_1 = ps[:ps_NN_1]
ps_NN_2 = ps[:ps_NN_2]
st_1, st_2 = st
x_2 = [getindex(x, [i,i+4,i+8,i+12]) for i=1:4]
y_1 = NNmodel_1(x, ps_NN_1, st_1)[1]
y_2 = [NNmodel_2(x_2_, ps_NN_2, st_2)[1] for x_2_ in x_2]
(vcat(y_1, y_2...),)
end
function NNmodel_total_p(x, ps, st; NNmodel_1=NNmodel_1, NNmodel_2=NNmodel_2)
ps_NN_1 = ps[:ps_NN_1]
ps_NN_2 = ps[:ps_NN_2]
st_1, st_q = st
y_1 = NNmodel_1(vcat(x...), ps_NN_1, st_1)[1]
y_2 = [NNmodel_2(x_, ps_NN_2, st_2)[1] for x_ in x]
(vcat(y_1, y_2...),) # <- This is where the crash occurs
end
NNmodel_total_flat(x_test_flat, ps_NN, st)[1]
NNmodel_total_p(x_test_p, ps_NN, st)[1] # Both working, different resuls (obv.) but working
Zygote.gradient(ps -> sum(NNmodel_total_flat(x_test_flat, ps, st)[1]), ps_NN) # Works
Zygote.gradient(ps -> sum(NNmodel_total_p(x_test_p, ps, st)[1]), ps_NN) # Crash
When the “NN input” is flat, Zygote calculates the gradient properly, when the input isn’t flat it crashes.
Specifically, the crash occurs at the vcat
(after some troubleshooting). Stacktrace:
ERROR: MethodError: no method matching +(::Vector{Vector{Float64}}, ::NTuple{4, Vector{Float64}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:578
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)
@ InitialValues /Net/Groups/BGI/people/mchettouh/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
+(::ChainRulesCore.Tangent{P}, ::P) where P
@ ChainRulesCore /Net/Groups/BGI/people/mchettouh/.julia/packages/ChainRulesCore/0t04l/src/tangent_arithmetic.jl:146
...
Stacktrace:
[1] accum(x::Vector{Vector{Float64}}, y::NTuple{4, Vector{Float64}})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/lib.jl:17
[2] Pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:105 [inlined]
[3] Pullback
@ ~(file here):(line where NNmodel_total_p is defined) [inlined]