If you have a Lux neural network:
```julia
using Lux, Random, ComponentArrays

rng = Random.Xoshiro()
U = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 2))
p, st = Lux.setup(rng, U)
```
Then `p` already contains the parameters! By default, it is a nested named tuple:
```julia
julia> p
(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))
```
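As a quick sketch of how these pieces fit together (the example input `x` below is an assumption, not part of the original snippet): the model is called explicitly with the input, parameters, and state, and each layer's parameters are reachable by name:

```julia
# A minimal usage sketch; `x` is an assumed example input. Lux models are
# called as (input, parameters, state) and return (output, updated_state).
x = rand(rng, Float32, 2)      # a sample 2-dimensional input
y, st = U(x, p, st)            # forward pass through the chain
p.layer_1.weight               # the 5×2 weight matrix of the first layer
```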
You can transform this nested named tuple into a `ComponentArray` like this:
```julia
julia> _p = ComponentArray(p)
ComponentVector{Float32}(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))
```
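For illustration, the same object now supports both flat and named access; a small sketch (the specific indices are arbitrary, and 87 is this network's total parameter count):

```julia
# The ComponentArray exposes the parameters as one flat vector while
# keeping the named layer structure layered on top.
length(_p)           # 87, the total parameter count
_p[1:10]             # ordinary linear indexing into the flat vector
_p.layer_1.weight    # named access still returns the 5×2 weight matrix
```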
Since it is a plain vector underneath, linear algebra is trivial:
```julia
julia> using LinearAlgebra

julia> _p .* _p' * rand(87, 87)
87×87 Matrix{Float64}:
 -0.161462  -0.15627   -0.174776  -0.153733  -0.149389  -0.206052  …  -0.127946  -0.161347  -0.130704  -0.133134
  2.8337     2.74257    3.06736    2.69805    2.62181    3.61626       2.24549    2.83168    2.29388    2.33653
  1.16773    1.13017    1.26401    1.11183    1.08041    1.49021       0.925332   1.16689    0.945271   0.962847
  ⋮                                                                ⋱                         ⋮
 -1.04498   -1.01138   -1.13115   -0.994957  -0.966841  -1.33356      -0.828067  -1.04423   -0.845911  -0.861639
  0.0        0.0        0.0        0.0        0.0        0.0        …   0.0        0.0        0.0        0.0
  0.0        0.0        0.0        0.0        0.0        0.0            0.0        0.0        0.0        0.0
```
You can then use that ComponentArray with any optimization loop to apply custom modifications to the neural network weights in a type-stable way.
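As an illustrative sketch of such a loop, here is one way to do it with Optimization.jl and Zygote, using a hypothetical least-squares loss that is not part of the original text:

```julia
# A minimal sketch of an optimization loop over the flat parameters.
# The loss is hypothetical: it drives the network output toward zero
# on a fixed example input.
using Optimization, OptimizationOptimisers, Zygote

x = Float32[1.0, 2.0]
loss(θ, _) = sum(abs2, first(U(x, θ, st)))

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, _p)
sol = solve(prob, OptimizationOptimisers.Adam(0.01); maxiters = 100)
θ_opt = sol.u  # still a ComponentArray with the same named structure
```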
There is no need to start over if you have existing Flux models: they can be auto-converted to Lux using `FromFluxAdaptor`:
```julia
import Flux
using Adapt, Lux, Metalhead, Random  # Metalhead.jl provides ResNet

m = ResNet(18)
m2 = adapt(FromFluxAdaptor(), m.layers) # or FromFluxAdaptor()(m.layers)
```
That makes switching over existing code essentially a one-line change.
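To sanity-check the conversion, the resulting model can be set up and called like any native Lux model; a sketch, assuming ResNet's standard 224×224 RGB input shape:

```julia
# Sketch: the converted model behaves like a native Lux model.
p, st = Lux.setup(Random.default_rng(), m2)
x = randn(Float32, 224, 224, 3, 1)   # one 224×224 RGB image (assumed shape)
y, _ = m2(x, p, st)
```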