Why is Flux.destructure type unstable?

If you have a Lux neural network:

using Lux, Random, ComponentArrays
rng = Random.Xoshiro()
U = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 2))

p, st = Lux.setup(rng, U)
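
For reference, Lux calls take the parameters and state explicitly; the input below is just a made-up batch for illustration:

x = rand(Float32, 2, 16)   # 2 input features × 16 samples
y, st = U(x, p, st)        # y is the 2×16 network output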

Then p already holds all of the parameters! By default it is a nested named tuple:

julia> p
(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))

but you can transform it into a ComponentArray like so:

julia> _p = ComponentArray(p)
ComponentVector{Float32}(layer_1 = (weight = Float32[0.032702986 -0.8001958; -0.5739451 -0.68308777; … ; -0.44908214 0.3974201; -0.23948132 0.3856664], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.58535755 -0.24707751 … -0.7356083 -0.7143577; 0.5912068 -0.3420531 … -0.044286303 -0.6271868; … ; 0.27350903 0.6479984 … -0.58502835 0.6814906; -0.47887245 0.37634572 … -0.3932655 0.6188778], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.5810871 0.41245216 … -0.589331 0.52293205; -0.440483 0.4243137 … 0.026632357 0.65401685; … ; 0.3230968 -0.73377603 … -0.5547266 0.3263391; 0.6041936 -0.4778424 … 0.07565704 -0.60510886], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.86475086 -0.8281523 … -0.57153064 0.04483523; 0.2813924 0.7921467 … -0.29173294 0.2116531], bias = Float32[0.0; 0.0;;]))

Now it’s a flat vector that still supports named indexing, which means linear algebra on it is trivial:

using LinearAlgebra
_p .* _p' * rand(87,87)

87×87 Matrix{Float64}:
 -0.161462  -0.15627  -0.174776  -0.153733  -0.149389  -0.206052  …  -0.127946  -0.161347  -0.130704  -0.133134
  2.8337     2.74257   3.06736    2.69805    2.62181    3.61626       2.24549    2.83168    2.29388    2.33653
  1.16773    1.13017   1.26401    1.11183    1.08041    1.49021       0.925332   1.16689    0.945271   0.962847
  ⋮                                                     ⋮         ⋱                         ⋮         
 -1.04498   -1.01138  -1.13115   -0.994957  -0.966841  -1.33356      -0.828067  -1.04423   -0.845911  -0.861639
  0.0        0.0       0.0        0.0        0.0        0.0       …   0.0        0.0        0.0        0.0
  0.0        0.0       0.0        0.0        0.0        0.0           0.0        0.0        0.0        0.0
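
You can also still index it by layer name or as a flat vector, for example:

_p.layer_1.weight   # the 5×2 weight matrix of the first Dense layer
_p[1:10]            # the first ten scalar parameters as a plain slice
length(_p)          # 87 parameters in total, matching the 87×87 above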

You can then use that ComponentArray with any optimization loop and make custom modifications to the neural network weights, all while staying type stable.
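
As a minimal sketch of what such a loop looks like (assuming Zygote for the gradients and a made-up least-squares objective; swap in your actual loss):

using Zygote

# made-up objective: drive the network output toward zero on a fixed batch
x = rand(Float32, 2, 16)
loss(θ) = sum(abs2, first(U(x, θ, st)))

θ = copy(_p)
for _ in 1:100
    g = Zygote.gradient(loss, θ)[1]   # the gradient comes back as a ComponentArray too
    θ .-= 1f-2 .* g                   # plain broadcasted gradient-descent step
end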

There is no need to start over from scratch, either: there are tools to auto-convert Flux models using FromFluxAdaptor:

import Flux
using Adapt, Lux, Metalhead, Random

m = ResNet(18)                           # a Flux model from Metalhead.jl
m2 = adapt(FromFluxAdaptor(), m.layers)  # or FromFluxAdaptor()(m.layers)

That makes the switch essentially a one-line change for existing code.
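
Once converted, m2 is a regular Lux layer, so the same setup and ComponentArray workflow from above should carry over:

using ComponentArrays

p, st = Lux.setup(Random.Xoshiro(), m2)
_p = ComponentArray(p)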
