How to optimize a neural network with a "traditional" optimizer?

Nowadays neural networks are typically represented as trees with parameters as leaves; this is how PyTorch, JAX, and Julia's Flux do it. That representation doesn't let one update all parameters at once with something like x = x - lr * grad(objective)(x), so libraries instead walk the "model tree" using jax.tree_util.tree_map or Functors.fmap. Traditional optimizers like IPOPT, however, only work with vectors of parameters.
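
For example, a tree-aware SGD step written with Functors looks roughly like the sketch below (illustrative only; sgd_step is a hypothetical helper, and model and grads are assumed to share the same tree structure):

using Functors

# Minimal sketch of a tree-walking SGD step: fmap traverses the model tree and
# a gradient tree of the same shape in lockstep and updates every array leaf.
function sgd_step(model, grads; lr = 0.01)
    return Functors.fmap(model, grads) do p, g
        p .- lr .* g
    end
end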

How do I optimize/fit/train a neural network with an optimizer that requires a vector (not tree!) of parameters?

Inspired by jax.flatten_util.ravel_pytree, I tried extracting all parameters into a vector using Functors and ComponentArrays:

julia> using Functors, ComponentArrays

julia> function functor_vec(layer)
        # split the layer into a NamedTuple of its fields and a reconstructor
        params_tuple, rebuild = Functors.functor(layer)
        # flatten that NamedTuple into a ComponentVector and remember its axes
        params_component_vec = ComponentVector(params_tuple)
        ax = getaxes(params_component_vec)
        # return a plain Vector plus a closure that rebuilds the original layer
        collect(params_component_vec), x -> ComponentVector(x, ax) |> NamedTuple |> rebuild
       end;

julia> struct MyLayer{T<:Real}
        W::AbstractMatrix{T}
        b::AbstractVector{T}
       end

julia> Functors.@functor MyLayer

julia> l = MyLayer(randn(5, 3), zeros(5));

julia> par_vec, rebuild = functor_vec(l);

julia> par_vec
20-element Vector{Float64}:
  0.519643525693406
  1.3357073627871439
  0.8184611283560812
  0.8694805669330345
  1.3867422066794435
 -0.09304748637230531
 -1.9661017560634666
 -0.5763398740637272
  0.24773712869934397
 -0.6204684719160402
 -0.3922590305057905
 -0.5495348092236947
 -0.016525852923050595
 -0.09057283663964283
 -0.215598824258112
  0.0
  0.0
  0.0
  0.0
  0.0

julia> rebuild(par_vec)
MyLayer{Float64}([0.519643525693406 -0.09304748637230531 -0.3922590305057905; 1.3357073627871439 -1.9661017560634666 -0.5495348092236947; … ; 0.8694805669330345 0.24773712869934397 -0.09057283663964283; 1.3867422066794435 -0.6204684719160402 -0.215598824258112], [0.0, 0.0, 0.0, 0.0, 0.0])

It works here, but it breaks as soon as a struct's fields are themselves MyLayers:

julia> struct MyChain{T<:Real}
        l1::MyLayer{T}
        l2::MyLayer{T}
       end

julia> Functors.@functor MyChain

julia> ch = MyChain(MyLayer(randn(5, 3), zeros(5)), MyLayer(randn(5, 3), zeros(5)));

julia> par_vec, rebuild = functor_vec(ch);
ERROR: MethodError: no method matching length(::MyLayer{Float64})

Closest candidates are:
  length(::Pkg.Types.Manifest)
   @ Pkg ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/Pkg/src/Types.jl:315
  length(::Core.SimpleVector)
   @ Base essentials.jl:765
  length(::Base.MethodSpecializations)
   @ Base reflection.jl:1164
  ...

Stacktrace:
  [1] recursive_length(x::MyLayer{Float64})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/utils.jl:40
  [2] MappingRF
    @ Base ./reduce.jl:100 [inlined]
  [3] afoldl(::Base.MappingRF{…}, ::Int64, ::MyLayer{…}, ::MyLayer{…})
    @ Base ./operators.jl:544
  [4] _foldl_impl(op::Base.MappingRF{…}, init::Int64, itr::Tuple{…})
    @ Base ./reduce.jl:68
  [5] foldl_impl(op::Base.MappingRF{…}, nt::Int64, itr::Tuple{…})
    @ Base ./reduce.jl:48
  [6] mapfoldl_impl(f::typeof(ComponentArrays.recursive_length), op::typeof(+), nt::Int64, itr::Tuple{MyLayer{…}, MyLayer{…}})
    @ Base ./reduce.jl:44
  [7] mapfoldl(f::Function, op::Function, itr::Tuple{MyLayer{Float64}, MyLayer{Float64}}; init::Int64)
    @ Base ./reduce.jl:175
  [8] mapfoldl
    @ Base ./reduce.jl:175 [inlined]
  [9] #mapreduce#302
    @ Base ./reduce.jl:307 [inlined]
 [10] recursive_length(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/utils.jl:43
 [11] make_idx(data::Vector{Any}, nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}}, last_val::Int64)
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:157
 [12] make_carray_args(A::Type{Vector}, nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:151
 [13] make_carray_args(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:144
 [14] ComponentArray(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:64
 [15] (ComponentVector)(nt::@NamedTuple{l1::MyLayer{Float64}, l2::MyLayer{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/7wTG6/src/componentarray.jl:84
 [16] functor_vec(layer::MyChain{Float64})
    @ Main ./REPL[10]:3
 [17] top-level scope
    @ REPL[32]:1

Is there a better way of extracting parameters of a tree-like model into a vector? Is there another way of optimizing a neural network with solvers that optimize over vectors, like IPOPT?

You could simply use Lux.jl, which natively works with ComponentArrays (see "Training a Neural ODE to Model Gravitational Waveforms" in the LuxDL docs), and then use Optimization.jl for IPOPT.
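
For concreteness, that pipeline could look roughly like the sketch below (untested; the model, toy data, and LBFGS solver are placeholders, and the IPOPT comment assumes OptimizationMOI and Ipopt are installed with an AD backend that can supply the derivatives IPOPT requests):

using Lux, ComponentArrays, Optimization, OptimizationOptimJL, Random, Zygote

# Lux keeps the model definition separate from its parameters and state.
model = Chain(Dense(3 => 5, tanh), Dense(5 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

# ComponentArray gives a flat vector view of the nested parameter NamedTuple,
# which is exactly what vector-based solvers expect.
ps_flat = ComponentArray(ps)

# Toy data and a least-squares loss written over the flat parameter vector.
xdata, ydata = randn(rng, 3, 100), randn(rng, 1, 100)
function loss(p, _)
    ypred, _ = Lux.apply(model, xdata, p, st)
    return sum(abs2, ypred .- ydata)
end

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ps_flat)

# Any vector-based solver can take it from here; with OptimizationMOI and Ipopt
# loaded, solve(prob, Ipopt.Optimizer()) hands the same problem to IPOPT.
sol = solve(prob, LBFGS())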
