Nowadays neural networks are represented as trees with parameters as leaves; this is how PyTorch, JAX, and Julia’s Flux all do it. This representation doesn’t let one update all parameters at once like x = x - lr * grad(objective)(x), so libraries instead walk the “model tree” using jax.tree_util.tree_map or Functors.fmap. However, traditional optimizers like IPOPT only work with vectors of parameters.
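For example, with Functors.fmap a single SGD step becomes a structural walk rather than one vector operation. A minimal sketch, assuming grads is a tree with the same structure as model:

using Functors

# one SGD step by walking the model tree: fmap pairs up the leaves of
# `model` and `grads` and applies the update to each pair of arrays
sgd_step(model, grads; lr = 0.01) = fmap((p, g) -> p .- lr .* g, model, grads)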
How do I optimize/fit/train a neural network with an optimizer that requires a vector (not a tree!) of parameters?
Inspired by jax.flatten_util.ravel_pytree, I tried extracting all parameters into a vector using Functors and ComponentArrays:
julia> using Functors, ComponentArrays

julia> function functor_vec(layer)
           # split the layer into a NamedTuple of parameter arrays and a rebuild closure
           params_tuple, rebuild = Functors.functor(layer)
           # lay the parameters out in one flat ComponentVector and remember its axes
           params_component_vec = ComponentVector(params_tuple)
           ax = getaxes(params_component_vec)
           # return the flat vector plus a function that undoes the flattening
           collect(params_component_vec), x -> ComponentVector(x, ax) |> NamedTuple |> rebuild
       end;
julia> struct MyLayer{T<:Real}
           W::AbstractMatrix{T}
           b::AbstractVector{T}
       end
julia> Functors.@functor MyLayer
julia> l = MyLayer(randn(5, 3), zeros(5));
julia> par_vec, rebuild = functor_vec(l);
julia> par_vec
20-element Vector{Float64}:
0.519643525693406
1.3357073627871439
0.8184611283560812
0.8694805669330345
1.3867422066794435
-0.09304748637230531
-1.9661017560634666
-0.5763398740637272
0.24773712869934397
-0.6204684719160402
-0.3922590305057905
-0.5495348092236947
-0.016525852923050595
-0.09057283663964283
-0.215598824258112
0.0
0.0
0.0
0.0
0.0
julia> rebuild(par_vec)
MyLayer{Float64}([0.519643525693406 -0.09304748637230531 -0.3922590305057905; 1.3357073627871439 -1.9661017560634666 -0.5495348092236947; … ; 0.8694805669330345 0.24773712869934397 -0.09057283663964283; 1.3867422066794435 -0.6204684719160402 -0.215598824258112], [0.0, 0.0, 0.0, 0.0, 0.0])
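The point of returning this pair is that a vector solver only ever sees par_vec, while the objective rebuilds the layer internally. Roughly what I’m aiming for (my_loss and ipopt_solve are placeholders, not real APIs):

# sketch of the intended use of the (par_vec, rebuild) pair
objective(x) = my_loss(rebuild(x))       # the solver sees only the flat vector
# x_opt = ipopt_solve(objective, par_vec)
# fitted = rebuild(x_opt)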
It works here, but breaks when a struct contains MyLayer:
julia> struct MyChain{T<:Real}
           l1::MyLayer{T}
           l2::MyLayer{T}
       end
julia> Functors.@functor MyChain
julia> ch = MyChain(MyLayer(randn(5, 3), zeros(5)), MyLayer(randn(5, 3), zeros(5)));
julia> par_vec, rebuild = functor_vec(ch);
ERROR: MethodError: no method matching length(::MyLayer{Float64})
Closest candidates are:
length(::Pkg.Types.Manifest)
@ Pkg ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/Pkg/src/Types.jl:315
length(::Core.SimpleVector)
@ Base essentials.jl:765
length(::Base.MethodSpecializations)
@ Base reflection.jl:1164
...
Stacktrace:
[1] recursive_length(x::MyLayer{Float64})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/utils.jl:40
[2] MappingRF
@ Base ./reduce.jl:100 [inlined]
[3] afoldl(::Base.MappingRF{…}, ::Int64, ::MyLayer{…}, ::MyLayer{…})
@ Base ./operators.jl:544
[4] _foldl_impl(op::Base.MappingRF{…}, init::Int64, itr::Tuple{…})
@ Base ./reduce.jl:68
[5] foldl_impl(op::Base.MappingRF{…}, nt::Int64, itr::Tuple{…})
@ Base ./reduce.jl:48
[6] mapfoldl_impl(f::typeof(ComponentArrays.recursive_length), op::typeof(+), nt::Int64, itr::Tuple{MyLayer{…}, MyLayer{…}})
@ Base ./reduce.jl:44
[7] mapfoldl(f::Function, op::Function, itr::Tuple{MyLayer{Float64}, MyLayer{Float64}}; init::Int64)
@ Base ./reduce.jl:175
[8] mapfoldl
@ Base ./reduce.jl:175 [inlined]
[9] #mapreduce#302
@ Base ./reduce.jl:307 [inlined]
[10] recursive_length(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/utils.jl:43
[11] make_idx(data::Vector{Any}, nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}}, last_val::Int64)
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:157
[12] make_carray_args(A::Type{Vector}, nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:151
[13] make_carray_args(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:144
[14] ComponentArray(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:64
[15] (ComponentVector)(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:84
[16] functor_vec(layer::MyChain{Float64})
@ Main ./REPL[10]:3
[17] top-level scope
@ REPL[32]:1
Is there a better way to extract the parameters of a tree-like model into a vector? And is there another way to optimize a neural network with solvers that work on vectors of parameters, like IPOPT?
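For reference, the most manual fallback I can think of is recursing with fmap myself, along these lines (a rough sketch I haven’t stress-tested; it relies on fmap visiting the leaves in a stable order):

using Functors

function naive_flatten(model)
    arrays = Any[]
    fmap(x -> (push!(arrays, x); x), model)  # first pass: collect every array leaf
    flat = reduce(vcat, vec.(arrays))
    function unflatten(v)
        i = Ref(0)  # second pass: carve `v` back into the original leaf shapes
        fmap(model) do x
            n = length(x)
            out = reshape(v[i[]+1:i[]+n], size(x))
            i[] += n
            out
        end
    end
    return flat, unflatten
end

But this feels like reinventing ravel_pytree, so I’d rather use something established if it exists.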