Lux (And Flux), "parallel" Network Input. When Input is flat, Zygote gradient works, when input is not flat it doesn't

It’s easier to see the code:

using Lux
using Zygote
using ComponentArrays
using Random

x_test_flat = rand(16)
x_test_p = [rand(4) for _ in 1:4]

NNmodel_1 = Lux.Chain(Lux.Dense(16 => 18, sigmoid),
                        Lux.Dense(18 => 6, sigmoid),
                        Lux.Dense(6 => 1, sigmoid),
                        x -> x.*0.5)

NNmodel_2 = Lux.Chain(Lux.Dense(4 => 5, sigmoid),
                          Lux.Dense(5 => 3, sigmoid),
                          Lux.Dense(3 => 1, sigmoid),
                          x -> x.*0.2)
rng = Random.default_rng()
Random.seed!(rng, 0)
# Initialize Model
ps_NN_1, st_1 = Lux.setup(rng, NNmodel_1)
ps_NN_2, st_2 = Lux.setup(rng, NNmodel_2)
# Parameters must be a ComponentArray or an Array,
# Zygote Jacobian won't loop through NamedTuple
ps_NN_1 = ps_NN_1 |> ComponentArray
ps_NN_2 = ps_NN_2 |> ComponentArray


st = (st_1, st_2)
ps_NN = ComponentArray(ps_NN_1=ps_NN_1, ps_NN_2=ps_NN_2)

function NNmodel_total_flat(x, ps, st; NNmodel_1=NNmodel_1, NNmodel_2=NNmodel_2)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    st_1, st_2 = st
    x_2 = [getindex(x, [i,i+4,i+8,i+12]) for i=1:4]
    y_1 = NNmodel_1(x, ps_NN_1, st_1)[1]
    y_2 = [NNmodel_2(x_2_, ps_NN_2, st_2)[1] for x_2_ in x_2]
    (vcat(y_1, y_2...),)
end

function NNmodel_total_p(x, ps, st; NNmodel_1=NNmodel_1, NNmodel_2=NNmodel_2)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    st_1, st_q = st
    y_1 = NNmodel_1(vcat(x...), ps_NN_1, st_1)[1]
    y_2 = [NNmodel_2(x_, ps_NN_2, st_2)[1] for x_ in x]
    (vcat(y_1, y_2...),) # <- This is where the crash occurs
end

NNmodel_total_flat(x_test_flat, ps_NN, st)[1]
NNmodel_total_p(x_test_p, ps_NN, st)[1] # Both working, different resuls (obv.) but working

Zygote.gradient(ps -> sum(NNmodel_total_flat(x_test_flat, ps, st)[1]), ps_NN) # Works
Zygote.gradient(ps -> sum(NNmodel_total_p(x_test_p, ps, st)[1]), ps_NN) # Crash

When the “NN input” is flat, Zygote calculates the gradient properly, when the input isn’t flat it crashes.
Specifically, the crash occurs at the vcat (after some troubleshooting). Stacktrace:

ERROR: MethodError: no method matching +(::Vector{Vector{Float64}}, ::NTuple{4, Vector{Float64}})                                                                                                                                                                               
                                                                                                                                                                                                                                                                                
Closest candidates are:                                                                                                                                                                                                                                                         
  +(::Any, ::Any, ::Any, ::Any...)                                                                                                                                                                                                                                              
   @ Base operators.jl:578                                                                                                                                                                                                                                                      
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)                                                                                                                                                                       
   @ InitialValues /Net/Groups/BGI/people/mchettouh/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154                                                                                                                                                                
  +(::ChainRulesCore.Tangent{P}, ::P) where P                                                                                                                                                                                                                                   
   @ ChainRulesCore /Net/Groups/BGI/people/mchettouh/.julia/packages/ChainRulesCore/0t04l/src/tangent_arithmetic.jl:146                                                                                                                                                         
  ...                                                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                
Stacktrace:                                                                                                                                                                                                                                                                     
  [1] accum(x::Vector{Vector{Float64}}, y::NTuple{4, Vector{Float64}})                                                                                                                                                                                                          
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/lib.jl:17                                                                                                                                                                                    
  [2] Pullback                                                                                                                                                                                                                                                                  
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:105 [inlined]                                                                                                                                                                    
  [3] Pullback                                                                                                                                                                                                                                                                  
    @ ~(file here):(line where NNmodel_total_p is defined) [inlined]

Splatting wont work (vcat(y_1, y_2...),) # <- This is where the crash occurs. do vcat(y_1, reduce(vcat, y_2))

Sadly still crashes, same stacktrace.

When I re-write y_2 to become a simple vector it still crashes:

function NNmodel_total_p(x, ps, st; NNmodel_1=NNmodel_1, NNmodel_2=NNmodel_2)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    st_1, st_q = st
    y_1 = NNmodel_1(vcat(x...), ps_NN_1, st_1)[1]
    y_2 = [NNmodel_2(x_, ps_NN_2, st_2)[1][1] for x_ in x]
    (vcat(y_1, y_2),) 
end

Update:

Flux gives the same error as well:


NNmodel_1_Flux = Flux.Chain(Flux.Dense(16 => 18, sigmoid),
                        Flux.Dense(18 => 6, sigmoid),
                        Flux.Dense(6 => 1, sigmoid)
                        x -> x.*0.5)

NNmodel_2_Flux = Flux.Chain(Flux.Dense(4 => 5, sigmoid),
                          Flux.Dense(5 => 3, sigmoid),
                          Flux.Dense(3 => 1, sigmoid)
                          x -> x.*0.2)
ps_Flux = ComponentArray(ps_NN_1=Flux.params(NNmodel_1_Flux), ps_NN_2=Flux.params(NNmodel_2_Flux))

function NNmodel_total_flat_Flux(x, ps; NNmodel_1=NNmodel_1_Flux, NNmodel_2=NNmodel_2_Flux)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    Flux.params!(ps_NN_1, NNmodel_1_Flux)
    Flux.params!(ps_NN_2, NNmodel_2_Flux)
    x_2 = [getindex(x, [i,i+4,i+8,i+12]) for i=1:4]
    y_1 = NNmodel_1(x)
    y_2 = [NNmodel_2(x_2_) for x_2_ in x_2]
    vcat(y_1, y_2...)
end

function NNmodel_total_p_Flux(x, ps; NNmodel_1=NNmodel_1_Flux, NNmodel_2=NNmodel_2_Flux)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    Flux.params!(ps_NN_1, NNmodel_1_Flux)
    Flux.params!(ps_NN_2, NNmodel_2_Flux)
    y_1 = NNmodel_1(vcat(x...))[1]
    y_2 = [NNmodel_2(x_) for x_ in x]
    vcat(y_1, y_2...) # <- This is still where the crash occurs
end

Zygote.gradient(ps -> sum(NNmodel_total_flat_Flux(x_test_flat, ps)), ps_Flux) # Works
Zygote.gradient(ps -> sum(NNmodel_total_p_Flux(x_test_p, ps)), ps_Flux) # Crash

Using Flux.parallel still gives the same error; Code:


function NNmodel_total_fluxp(x, ps; NNmodel_1=NNmodel_1_Flux, NNmodel_2=NNmodel_2_Flux)
    ps_NN_1 = ps[:ps_NN_1]
    ps_NN_2 = ps[:ps_NN_2]
    Flux.params!(ps_NN_1, NNmodel_1_Flux)
    Flux.params!(ps_NN_2, NNmodel_2_Flux)
    y_1 = NNmodel_1(vcat(x...))[1]
    return Flux.Parallel(vcat, NNmodel_1, NNmodel_2, NNmodel_2, NNmodel_2, NNmodel_2)(vcat(x...),x[1], x[2], x[3], x[4])
end

Flux.params is soft deprecated at this point and should not be used. Have a look through the intro sections of the docs to see how to work with model parameters now. In short, there should be little difference between Flux and Lux. If you must use a flat vector of parameters, I’d recommend solving your original problem/MWE first. Otherwise, do you have to have all the parameters stored in a single flat vector?

do you have to have all the parameters stored in a single flat vector?
Nope, however not having them in a “flat” vector and using “idiomatic” Flux is still giving the same error.

I also tried a new approach:

struct One2Many{T1, T2}
    primary_path::T1
    parallel_paths::T2
end 

One2Many(primary_path, paths...) = One2Many(primary_path, paths)
Flux.@functor One2Many

function (mj::One2Many)(xs::Vector{Vector{T}}) where{T} 
          vcat(mj.primary_path(vcat(xs...)),map((f, x) -> f(x), mj.parallel_paths, xs)...)
end

NNmodel_total_p_Flux = Flux.Chain(One2Many(
                                x -> NNmodel_1_Flux(vcat(x...)),
                                NNmodel_2_Flux,
                                NNmodel_2_Flux,
                                NNmodel_2_Flux,
                                       NNmodel_2_Flux))
Flux.gradient(() -> sum(NNmodel_total_p_Flux(x_test_p)), Flux.params(NNmodel_total_p_Flux))

this still gives the same error, a bit more “cryptic”, but the core of it is that it there is no + method between something (changes depending on implementation) and NTuple{4...}

Per my last message, I would remove all uses of params and instead follow the docs on how to take gradients of models. If you’re still running into issues then a MWE and full stacktrace will be required, per Please read: make it easier to help you.

1 Like

Full MWE for Flux (without Params):

using Zygote
using ComponentArrays
using Random
using Flux

x_test = [rand(4) for _ in 1:4]

NNmodel_1_Flux = Flux.Chain(Flux.Dense(16 => 18, sigmoid),
                        Flux.Dense(18 => 6, sigmoid),
                        Flux.Dense(6 => 1, sigmoid),
                        x -> x.*0.5)

NNmodel_2_Flux = Flux.Chain(Flux.Dense(4 => 5, sigmoid),
                          Flux.Dense(5 => 3, sigmoid),
                          Flux.Dense(3 => 1, sigmoid),
                          x -> x.*0.2)

struct One2Many{T1, T2}
    primary_path::T1
    parallel_paths::T2
end 

One2Many(primary_path, paths...) = One2Many(primary_path, paths)
Flux.@functor One2Many

function (mj::One2Many)(xs::Vector{Vector{T}}) where{T} 
          vcat(mj.primary_path(vcat(xs...)),map((f, x) -> f(x), mj.parallel_paths, xs)...)
end

model = Flux.Chain(One2Many(
                                x -> NNmodel_1_Flux(vcat(x...)),
                                NNmodel_2_Flux,
                                NNmodel_2_Flux,
                                NNmodel_2_Flux,
                                       NNmodel_2_Flux))
function loss(y)
    return sum(y)
end

Flux.gradient(m -> loss(m(x_test)), model)

With the StackTrace:

julia> Flux.gradient(m -> loss(m(x_test)), model)
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{…}, axes::Tuple{…}}})(::NTuple{4, Float64})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:121
  (::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
   @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:200
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Number)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:253
  ...

Stacktrace:
  [1] (::ChainRules.var"#1396#1401"{ChainRulesCore.ProjectTo{…}, Tuple{…}, ChainRulesCore.Tangent{…}})()
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/array.jl:310
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/zoCjl/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#1395#1400"{…}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:110 [inlined]
  [5] map (repeats 2 times)
    @ ./tuple.jl:294 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:111 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:211 [inlined]
  [8] (::Zygote.var"#291#292"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1397"{…}}})(Δ::NTuple{16, Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:206
  [9] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.ZBack{…}}})(Δ::NTuple{16, Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
 [10] One2Many
    @ ./REPL[21]:2 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [12] macro expansion
    @ ~/.julia/packages/Flux/UsEXa/src/layers/basic.jl:53 [inlined]
 [13] _applychain
    @ ~/.julia/packages/Flux/UsEXa/src/layers/basic.jl:53 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [15] Chain
    @ ~/.julia/packages/Flux/UsEXa/src/layers/basic.jl:51 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [17] #21
    @ ./REPL[29]:1 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:45
 [20] gradient(f::Function, args::Chain{Tuple{One2Many{var"#17#18", NTuple{4, Chain{Tuple{Dense{…}, Dense{…}, Dense{…}, var"#8#9"}}}}}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:97
 [21] top-level scope
    @ REPL[29]:1


function (mj::One2Many)(xs::AbstractVector{<:AbstractArray})
    primary = mj.primary_path(reduce(vcat, xs))
    parallel = map(|>, xs, mj.parallel_paths)
    return vcat(primary, parallel...)
end

model = Chain(
    One2Many(
        NNmodel_1_Flux,
        NNmodel_2_Flux,
        NNmodel_2_Flux,
        NNmodel_2_Flux,
        NNmodel_2_Flux
    )
)

These changes are enough to get rid of the errors. The reason is that it avoids both splatting xs... and mapping xs in the same function. Because the former treats it like a tuple and the latter like an array, Zygote gets confused and can’t reconcile the differently-typed gradients being returned.

You’ll notice I also changed the definition of model. There was a redundant vcat(x...) there, and x -> NNmodel_1_Flux(vcat(x...)) won’t work because we can’t extract parameters from NNmodel_1_Flux when it’s wrapped in an anonymous function like this.

An alternative approach to the above is to make x_test a tuple instead of an array, i.e. using xs_test = ntuple(_ - > rand(Float32, 4), 4). This should be faster because it’s more type stable, but may incur some compilation overhead if the length of x_test keeps changing. Using a tuple might also yet you make use of more built-in layers such as Flux.Parallel.