I was building a simple model and at some point I needed to “unroll” it to get all the parameters in an array.
So I tried `Flux.destructure`. It turned out to be type-unstable, so I checked the documentation and tried the example provided there:
```julia
using Flux
model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
@code_warntype Flux.destructure(model)
```
But this gives a type instability as well!
```
Flux.destructure(model) = (Float32[0.27410066, 0.6508191, 0.0, 0.16767712, 0.0], Restructure(Chain, ..., 5))
MethodInstance for Optimisers.destructure(::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from destructure(x) @ Optimisers
Arguments
  #self#::Core.Const(Optimisers.destructure)
  x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  @_3::Int64
  len::Int64
  off::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}
  flat::AbstractVector
Body::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
1 ─ %1 = Optimisers._flatten(x)::Tuple{AbstractVector, NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}
│ %2 = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{AbstractVector, Int64}, Any[AbstractVector, Core.Const(2)])
│ (flat = Core.getfield(%2, 1))
│ (@_3 = Core.getfield(%2, 2))
│ %5 = Base.indexed_iterate(%1, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}, Any[NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Core.Const(3)])
│ (off = Core.getfield(%5, 1))
│ (@_3 = Core.getfield(%5, 2))
│ %8 = Base.indexed_iterate(%1, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│ (len = Core.getfield(%8, 1))
│ %10 = flat::AbstractVector
│ %11 = Optimisers.Restructure(x, off, len)::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})
│ %12 = Core.tuple(%10, %11)::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
└── return %12
```
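One thing worth checking: in the output, the non-concrete parts are `flat::AbstractVector` and the partially inferred `Restructure{..., S}`, but that is only what the compiler can prove inside `destructure`. A minimal sketch, assuming the current Flux/Optimisers behaviour, of how to see the concrete runtime type and confine the instability behind a function barrier (`param_norm` is a hypothetical helper, not Flux API):

```julia
using Flux
using InteractiveUtils: @code_warntype

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, re = Flux.destructure(model)

# Inference inside `destructure` only proves `flat::AbstractVector`,
# but the value actually returned is a concrete array:
println(typeof(flat))

# Function barrier (hypothetical helper): this method is compiled for the
# concrete runtime type of `p`, so its body sees p::Vector{Float32}.
param_norm(p) = sqrt(sum(abs2, p))
@code_warntype param_norm(flat)

# `re` rebuilds a Chain from a flat vector; given the same vector it
# reproduces the original model's outputs.
x = rand(Float32, 2, 8)   # illustrative batch: 2 features, 8 samples
println(re(flat)(x) ≈ model(x))
```

If the hot code lives in functions that receive `flat` and `re` as arguments, the dynamic dispatch happens once at the call boundary rather than on every operation inside.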
What am I missing?
Is there a type-stable alternative? I would also like to know whether there is something similar to “unroll” in JAX.
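For comparison, JAX's `jax.flatten_util.ravel_pytree` returns a flat array together with an `unravel` function, which is the same pair that `destructure` returns as `(flat, re)`. If a type-stable flat vector is what matters, one option is to write the flattening by hand for a fixed architecture. The sketch below assumes exactly this two-layer `Chain`; `flatten_params` and `rebuild` are hypothetical helpers, not part of Flux or Optimisers:

```julia
using Flux
using Test: @inferred

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))

# Hypothetical helper for this two-layer model only: concatenate every
# parameter array in a fixed order (weight, bias, layer by layer). All inputs
# are concrete Matrix{Float32} / Vector{Float32}, so vcat infers Vector{Float32}.
function flatten_params(m::Chain)
    d1, d2 = m.layers
    return vcat(vec(d1.weight), d1.bias, vec(d2.weight), d2.bias)
end

# Hypothetical inverse: rebuild the same architecture from a flat vector,
# reading slices in the same order that `flatten_params` wrote them.
function rebuild(m::Chain, p::AbstractVector)
    d1, d2 = m.layers
    n1w, n1b, n2w = length(d1.weight), length(d1.bias), length(d2.weight)
    i = 0
    W1 = reshape(p[i+1:i+n1w], size(d1.weight)); i += n1w
    b1 = p[i+1:i+n1b]; i += n1b
    W2 = reshape(p[i+1:i+n2w], size(d2.weight)); i += n2w
    b2 = p[i+1:end]
    return Chain(Dense(W1, b1, d1.σ), Dense(W2, b2, d2.σ))
end

# @inferred throws if the inferred return type differs from the runtime type.
flat = @inferred flatten_params(model)
```

Because the slicing order mirrors the field order, `flat` should match the first output of `Flux.destructure(model)` for this model. The trade-off is that the helpers must be rewritten whenever the architecture changes, which is exactly what `destructure` avoids by walking the model generically.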