Accessing a ComponentArray's fields in a type-stable manner

I’m struggling with iterating over the named fields of a ComponentArray in a compiler-friendly way.

Take the following MWE, which emulates a small neural network:

  1. I create the nested structure layers, which holds the weight matrices. Each layer is itself a NamedTuple with a lower-triangular and an upper-triangular component.

  2. I feed it to the function feedforward, which applies each layer to the input in sequence.

When layers is a NamedTuple, I can simply iterate over its elements and it works relatively well:

Code iterating over a NamedTuple
using NamedTupleTools, LinearAlgebra, ComponentArrays

function create_layers(size, number)
    # One layer: a lower- and an upper-triangular weight matrix
    l = (lower_weight = LowerTriangular(rand(size, size)),
         upper_weight = UpperTriangular(rand(size, size)))

    layer_keys = Symbol.(["layer_$i" for i in 1:number])
    layer_data = [l for _ in 1:number]
    layers = namedtuple(layer_keys, layer_data)
    return layers
end

function feedforward(x::AbstractArray, layers)
    fwd = x
    # Apply each layer to the running result, in order
    for L in layers
        fwd = (L.lower_weight*L.upper_weight*fwd)
    end
    return fwd
end

layers = create_layers(3,3)
input = rand(3)
feedforward(input,layers) #Works OK

However, when I turn layers into a ComponentArray, iterating over it yields its individual scalar elements instead of the layers, breaking everything:

Incorrect attempt with ComponentArrays
layers_CA = ComponentArray(layers)
feedforward(input,layers_CA) #ERROR: type Float64 has no field lower_weight

A possible workaround is to collect layers_CA’s keys and then iterate over them, but inside the loop each key is just a runtime Symbol, so the compiler can’t infer the type of ps[k] and the result is type instability:

Code iterating over keys
function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in keys(ps)
        L = ps[k]  # k is a runtime Symbol here
        fwd = (L.lower_weight*L.upper_weight*fwd)
    end
    return fwd
end

feedforward(input,layers_CA) #Works
@code_warntype feedforward(input,layers_CA)

The field accesses are inferred as Any, and so is the return type:

Result of @code_warntype
MethodInstance for feedforward(::Vector{Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}})
  from feedforward(x::AbstractArray, ps::ComponentArray) @ Main Untitled-4:44
  ps::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}}
  @_4::Union{Nothing, Tuple{Symbol, Int64}}
1 ─       (fwd = x)
│   %2  = Main.keys(ps)::Core.Const((:layer_1, :layer_2, :layer_3))
│         (@_4 = Base.iterate(%2))
│   %4  = (@_4::Core.Const((:layer_1, 2)) === nothing)::Core.Const(false)
│   %5  = Base.not_int(%4)::Core.Const(true)
└──       goto #4 if not %5
2 ┄ %7  = @_4::Tuple{Symbol, Int64}
│         (k = Core.getfield(%7, 1))
│   %9  = Core.getfield(%7, 2)::Int64
│         (L = Base.getindex(ps, k))
│   %11 = Base.getproperty(L, :lower_weight)::Any
│   %12 = Base.getproperty(L, :upper_weight)::Any
│         (fwd = %11 * %12 * fwd)
│         (@_4 = Base.iterate(%2, %9))
│   %15 = (@_4 === nothing)::Bool
│   %16 = Base.not_int(%15)::Bool
└──       goto #4 if not %16
3 ─       goto #2
4 ┄       return fwd

How should I go about this? How can I create a function that knows (at compile time) what keys I’m going to iterate over without coding the function myself for each different layout of layers? I’ve seen people do this with metaprogramming but I don’t feel experienced enough to use @generated functions and the like.

I need this for a few custom Lux layers I’m constructing, which need to be easily translatable into ComponentArrays to be compatible with NeuralPDE.

ComponentArrays exports a function called valkeys that’s like keys but with each key wrapped in a Val, for this exact reason. Try:

function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in valkeys(ps)
        L = @view ps[k]  # k is a Val, so the key is known from its type
        fwd = (L.lower_weight*L.upper_weight*fwd)
    end
    return fwd
end
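
As a rough illustration of why wrapping the keys in Val helps, here is a sketch that uses a plain NamedTuple instead of a ComponentArray. The helper getval and the variable names are made up for this example and are not part of ComponentArrays; the only point is that a Val{:name} carries the key in its type, so the field lookup can be resolved during inference rather than at run time.

```julia
# Hypothetical helper (not from ComponentArrays): the key arrives as a type
# parameter, so the field lookup can be resolved at compile time.
getval(nt::NamedTuple, ::Val{k}) where {k} = getfield(nt, k)

# A Symbol key is only a runtime value; a Val key is part of the type.
nt = (a = 1, b = 2.0)
vks = map(Val, keys(nt))          # (Val{:a}(), Val{:b}())

# Each element of vks has its own concrete type, unlike a tuple of Symbols.
key_types = typeof(vks)           # Tuple{Val{:a}, Val{:b}}

# Ask inference what getval returns when the key is Val{:b}.
inferred = Base.return_types(getval, (typeof(nt), Val{:b}))
b_value = getval(nt, Val(:b))
```

With a Symbol argument the same lookup could only be inferred as a union of the field types, which is essentially the situation in the keys(ps) loop.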

Another option is to iterate with for L in NamedTuple(ps). This seems to make @code_warntype slightly happier than the valkeys solution, but it makes no difference to the @btime performance (time or allocations).
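
Written out in full, the NamedTuple variant might look like the sketch below. The function name feedforward_nt and the small 2×2 parameter set are made up for illustration, and this assumes a ComponentArrays version that provides the NamedTuple conversion for ComponentVectors.

```julia
using ComponentArrays

# Same loop as in the question, but iterating over the NamedTuple form of ps.
function feedforward_nt(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for L in NamedTuple(ps)   # each L holds one layer's parameters
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end

# Small hypothetical parameter set with two identical layers.
W = [1.0 0.0; 2.0 1.0]        # lower-triangular weights
U = [1.0 3.0; 0.0 1.0]        # upper-triangular weights
ps = ComponentArray(layer_1 = (lower_weight = W, upper_weight = U),
                    layer_2 = (lower_weight = W, upper_weight = U))

x = [1.0, 1.0]
y = feedforward_nt(x, ps)
```

Running @code_warntype feedforward_nt(x, ps) on your own layout is the quickest way to check whether inference succeeds for it.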