Why is Flux.destructure type unstable?

I was building a simple model and at some point I needed to “unroll” it to get all the parameters in an array.

So I tried Flux.destructure. I got some type instability, so I checked the documentation and tried the example provided there:

using Flux
model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
@code_warntype Flux.destructure(model)

But this gives a type instability as well!

Flux.destructure(model) = (Float32[0.27410066, 0.6508191, 0.0, 0.16767712, 0.0], Restructure(Chain, ..., 5))
MethodInstance for Optimisers.destructure(::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from destructure(x) @ Optimisers
Arguments
  #self#::Core.Const(Optimisers.destructure)
  x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  @_3::Int64
  len::Int64
  off::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}
  flat::AbstractVector
Body::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
1 ─ %1  = Optimisers._flatten(x)::Tuple{AbstractVector, NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}
│   %2  = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{AbstractVector, Int64}, Any[AbstractVector, Core.Const(2)])
│         (flat = Core.getfield(%2, 1))
│         (@_3 = Core.getfield(%2, 2))
│   %5  = Base.indexed_iterate(%1, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}, Any[NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Core.Const(3)])
│         (off = Core.getfield(%5, 1))
│         (@_3 = Core.getfield(%5, 2))
│   %8  = Base.indexed_iterate(%1, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (len = Core.getfield(%8, 1))
│   %10 = flat::AbstractVector
│   %11 = Optimisers.Restructure(x, off, len)::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})
│   %12 = Core.tuple(%10, %11)::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
└──       return %12

What am I missing?

Is there any safe alternative? I would also like to know if there is something similar to "unrolling" the parameters in JAX.

I highly recommend Lux.jl if you’re trying to do this kind of thing. You just make the parameters a ComponentArray and you’re done.


Hi, thanks for your answer.

Could you please explain to me how to do that with ComponentArrays?

I would like to “unpack” all the parameters as a single vector, do some operations on this vector and finally “repack” everything into the original structure. Is it possible to do this?
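
For concreteness, this is the round trip I have in mind, as a minimal sketch with Flux.destructure:

using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, re = Flux.destructure(model)   # flat parameter vector and a Restructure closure

flat2 = 2 .* flat                    # some operation on the flat vector
model2 = re(flat2)                   # "repack" into a model with the original structure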

Chris is suggesting you re-start in a different package (from his group), in which the parameters are never part of the original structure in the first place. Then you can add ComponentArrays to handle the structured-to-flat transition in a slightly different way.

This is the same question as the issue here; please link things if you post the same question in several places. As in my answer there, I’m not sure what the real question is. Do you want to know details about why some things work the way they do, taking your title literally? Or does something you are trying to do not work, have you encountered an error which you think is caused by this, or are you trying to solve some performance problem?


If you have a Lux neural network:

using Lux, Random, ComponentArrays
rng = Random.Xoshiro()
U = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 2))

p, st = Lux.setup(rng, U)   # parameters and state live outside the model itself

Then p is already the parameters! By default it is a nested named tuple:

julia> p
(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))

but you can transform it into a ComponentArray like:

julia> _p = ComponentArray(p)
ComponentVector{Float32}(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))

Now it’s a flat vector with structured indexing on top. This means that linear algebra is trivial:

using LinearAlgebra
_p .* _p' * rand(87, 87)   # 87 = total number of parameters in U

87×87 Matrix{Float64}:
 -0.161462  -0.15627  -0.174776  -0.153733  -0.149389  -0.206052  …  -0.127946  -0.161347  -0.130704  -0.133134
  2.8337     2.74257   3.06736    2.69805    2.62181    3.61626       2.24549    2.83168    2.29388    2.33653
  1.16773    1.13017   1.26401    1.11183    1.08041    1.49021       0.925332   1.16689    0.945271   0.962847
  ⋮                                                     ⋮         ⋱                         ⋮         
 -1.04498   -1.01138  -1.13115   -0.994957  -0.966841  -1.33356      -0.828067  -1.04423   -0.845911  -0.861639
  0.0        0.0       0.0        0.0        0.0        0.0       …   0.0        0.0        0.0        0.0
  0.0        0.0       0.0        0.0        0.0        0.0           0.0        0.0        0.0        0.0

You can then use that ComponentArray with any optimization loop and make custom modifications to the neural network weights in a type-stable way.
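
For example, a minimal hand-rolled loop could look like this (a sketch only; Zygote and the toy least-squares loss are my own choices here, any AD or optimiser package would work the same way):

using Zygote   # assumed choice of AD; Optimisers.jl update rules would also work

xdata = rand(Float32, 2, 32)   # toy inputs
ydata = rand(Float32, 2, 32)   # toy targets

# loss written over the flat-but-structured parameter vector
loss(p) = sum(abs2, first(U(xdata, p, st)) .- ydata)

lr = 1f-2
for _ in 1:100
    g = Zygote.gradient(loss, _p)[1]   # gradient keeps the ComponentArray structure
    _p .-= lr .* g                     # plain in-place vector update
end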

There is no need to restart; there are tools to auto-convert Flux models using FromFluxAdaptor:

import Flux
using Adapt, Lux, Metalhead, Random   # Metalhead provides ResNet

m = ResNet(18)                          # a Flux model
m2 = adapt(FromFluxAdaptor(), m.layers) # or FromFluxAdaptor()(m.layers)

That makes switching existing code over essentially a one-liner.
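
Assuming that conversion goes through, m2 is then used like any other Lux model, so the same ComponentArray trick from earlier in the thread applies (a sketch):

using ComponentArrays

ps, st = Lux.setup(Random.default_rng(), m2)   # fresh parameters/state for the converted model
flat_ps = ComponentArray(ps)                   # flat, indexable parameter vector as before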


Thanks a lot, this works very well. Also, ComponentArrays makes it very easy to access elements with, say, _p.layer_1.weight.

Is there a reason why you decided not to include the parameters in the model?

It makes these use cases much simpler to program and improves the efficiency in many cases where you don’t want to keep reconstructing this array. Also, it makes the neural network a pure function which takes all data as inputs, which has some nice properties for serialization, parallelization, and other applications.
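
Concretely, the call signature makes this explicit (a small sketch reusing U, _p, and st from the example above):

# Lux layers are pure functions: inputs, parameters, and state are all passed in
# as arguments, and the (possibly updated) state is returned with the output.
xin = rand(Float32, 2, 10)
yout, st_new = U(xin, _p, st)   # no parameters hidden inside U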
