Accessing a ComponentArray's fields in a type-stable manner

I’m struggling with iterating over the named fields of a ComponentArray in a compiler-friendly way.

Take the following MWE, which emulates a simple neural network:

  1. I create the nested structure `layers`, which holds the weight matrices. Each layer is itself a NamedTuple with a lower-triangular and an upper-triangular component.

  2. I feed it to the function feedforward, which applies each layer to the input in sequence.

When `layers` is a NamedTuple, I can simply iterate over its elements, and it works reasonably well:

Code iterating over a NamedTuple
```julia
using NamedTupleTools, LinearAlgebra, ComponentArrays

# Build `number` layers, each a NamedTuple holding an n×n lower- and upper-triangular weight.
function create_layers(n, number)
    l = (lower_weight = LowerTriangular(rand(n, n)),
         upper_weight = UpperTriangular(rand(n, n)))

    layer_keys = Symbol.(["layer_$i" for i in 1:number])
    layer_data = [l for _ in 1:number]   # every layer shares the same matrices in this MWE
    layers = namedtuple(layer_keys, layer_data)
    return layers
end

# Apply each layer to the input in sequence.
function feedforward(x::AbstractArray, layers)
    fwd = x
    for L in layers
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end

layers = create_layers(3, 3)
input = rand(3)
feedforward(input, layers) # Works OK
```
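A short Base-only check of why the NamedTuple version is compiler-friendly: the keys and element types of a NamedTuple are parameters of its type, not runtime data. The names `ks`, `vs` and `nt` below are just for illustration. As an aside, and assuming `layer_keys`/`layer_data` are the vectors from `create_layers`, something like `NamedTuple{Tuple(layer_keys)}(Tuple(layer_data))` should build the same kind of NamedTuple without NamedTupleTools.

```julia
# Build a NamedTuple from a tuple of keys and a tuple of values, using only Base.
ks = (:layer_1, :layer_2)
vs = (1.0, "two")
nt = NamedTuple{ks}(vs)

# The keys (and the element types) are parameters of the type itself...
println(typeof(nt))
# ...so they can be recovered from the type alone, without inspecting any values.
println(fieldnames(typeof(nt)))   # (:layer_1, :layer_2)
```

Because the compiler sees `typeof(nt)`, a loop over `nt` knows every key and element type ahead of time.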

However, when I turn `layers` into a ComponentArray, iterating over it now yields its individual scalar entries (under the hood it is a flat array of Float64), which breaks everything:

Incorrect attempt with ComponentArrays
```julia
layers_CA = ComponentArray(layers)
feedforward(input, layers_CA) # ERROR: type Float64 has no field lower_weight
```
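To see why plain iteration yields `Float64`s, here is a toy type that mimics the idea of a flat vector with named blocks. `FlatParams` is hypothetical and is NOT how ComponentArrays is actually implemented; it only shows that any `AbstractVector{Float64}` iterates over its scalar entries, even if named property access returns whole blocks.

```julia
# Hypothetical toy type (NOT ComponentArrays' actual implementation): a flat
# vector of Float64 whose named blocks are exposed as properties.
struct FlatParams <: AbstractVector{Float64}
    data::Vector{Float64}
    blocks::NamedTuple  # maps each name to a UnitRange into `data`
end

# AbstractArray interface: these make FlatParams behave like a plain vector.
Base.size(p::FlatParams) = size(getfield(p, :data))
Base.IndexStyle(::Type{FlatParams}) = IndexLinear()
Base.getindex(p::FlatParams, i::Int) = getfield(p, :data)[i]

# Named access: p.name returns a view of the corresponding block.
Base.getproperty(p::FlatParams, name::Symbol) =
    view(getfield(p, :data), getfield(p, :blocks)[name])

p = FlatParams(collect(1.0:6.0), (a = 1:3, b = 4:6))

p.b        # named access gives the block [4.0, 5.0, 6.0]
first(p)   # but iteration, like for any AbstractVector, yields scalars: 1.0
```

So `for L in p` binds `L` to a `Float64`, and `L.lower_weight` fails just like in the error above.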

A possible workaround is to collect the keys of `layers_CA` and iterate over them. However, each key `k` is then just a runtime `Symbol`, so the compiler cannot know which sub-array `ps[k]` returns, and the result is type-unstable:

Code iterating over keys
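```julia
function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in keys(ps)
        L = ps[k]   # k is a runtime Symbol, so the type of L can't be inferred
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end

feedforward(input, layers_CA) # Works
@code_warntype feedforward(input, layers_CA)   # needs `using InteractiveUtils` outside the REPL
```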

This gives a return type of `Any`:

Result of @code_warntype
```
MethodInstance for feedforward(::Vector{Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}})
  from feedforward(x::AbstractArray, ps::ComponentArray) @ Main Untitled-4:44
Arguments
  #self#::Core.Const(feedforward)
  x::Vector{Float64}
  ps::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}}
Locals
  @_4::Union{Nothing, Tuple{Symbol, Int64}}
  fwd::Any
  k::Symbol
  L::Any
Body::Any
1 ─       (fwd = x)
│   %2  = Main.keys(ps)::Core.Const((:layer_1, :layer_2, :layer_3))
│         (@_4 = Base.iterate(%2))
│   %4  = (@_4::Core.Const((:layer_1, 2)) === nothing)::Core.Const(false)
│   %5  = Base.not_int(%4)::Core.Const(true)
└──       goto #4 if not %5
2 ┄ %7  = @_4::Tuple{Symbol, Int64}
│         (k = Core.getfield(%7, 1))
│   %9  = Core.getfield(%7, 2)::Int64
│         (L = Base.getindex(ps, k))
│   %11 = Base.getproperty(L, :lower_weight)::Any
│   %12 = Base.getproperty(L, :upper_weight)::Any
│         (fwd = %11 * %12 * fwd)
│         (@_4 = Base.iterate(%2, %9))
│   %15 = (@_4 === nothing)::Bool
│   %16 = Base.not_int(%15)::Bool
└──       goto #4 if not %16
3 ─       goto #2
4 ┄       return fwd
```
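The key line above is `k::Symbol` followed by `L::Any`: the tuple of keys is a constant, but once a key is pulled out as a plain `Symbol`, the compiler no longer knows which one it is. The same effect can be reproduced with only Base on a heterogeneous NamedTuple (a simplified stand-in for what happens inside ComponentArrays' indexing). `get_sym` and `get_val` are hypothetical helpers for illustration.

```julia
# A heterogeneous NamedTuple: the type of each field differs.
nt = (a = 1, b = "two")

# Key passed as a runtime Symbol: the compiler only knows k::Symbol, so it
# cannot tell which field (and therefore which type) will be returned.
get_sym(x, k::Symbol) = getfield(x, k)

# Key carried in the type via Val: each Val{k} gets its own specialization,
# so the field, and its type, are known at compile time.
get_val(x, ::Val{k}) where {k} = getfield(x, k)

# Ask inference what each version returns for these argument types.
T_sym = only(Base.return_types(get_sym, (typeof(nt), Symbol)))
T_val = only(Base.return_types(get_val, (typeof(nt), Val{:b})))

println(isconcretetype(T_sym))  # false: a Union (or Any), not a single type
println(T_val)                  # String
```

Moving the key from a value into a type parameter is the general trick for getting compile-time knowledge without writing `@generated` code.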

How should I go about this? How can I write a function that knows at compile time which keys it will iterate over, without hand-writing a separate version for every layout of `layers`? I’ve seen people do this with metaprogramming, but I don’t feel experienced enough to use `@generated` functions and the like.

I need this for a few custom Lux layers I’m constructing, which need to be easily translatable into ComponentArrays to be compatible with NeuralPDE.

ComponentArrays exports a function called `valkeys` for exactly this reason: it works like `keys`, but wraps each key in a `Val`, so the key becomes part of the type. Try:

```julia
function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in valkeys(ps)
        L = @view ps[k]   # k is a Val{key}, so the key is known from its type
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end
```
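Another way to give the compiler the full layer structure, without `@generated` functions, is to recurse over a tuple of layers, peeling off the first one on each call. This is a minimal sketch on a plain NamedTuple using only Base and LinearAlgebra; `apply_layers` and `feedforward_rec` are hypothetical names, and for a ComponentArray it assumes you first get the layers out as a NamedTuple or tuple.

```julia
using LinearAlgebra

# Base case: no layers left, so return the activation unchanged.
apply_layers(x, ::Tuple{}) = x

# Recursive case: apply the first layer, then recurse on the rest.
# Each call is specialized on the concrete (shrinking) tuple type, so the
# compiler can follow the whole chain without a runtime key lookup.
function apply_layers(x, layers::Tuple)
    L = first(layers)
    return apply_layers(L.lower_weight * L.upper_weight * x, Base.tail(layers))
end

# Hypothetical entry point for a NamedTuple of layers: drop names, keep values.
feedforward_rec(x::AbstractVector, layers::NamedTuple) = apply_layers(x, values(layers))

l = (lower_weight = LowerTriangular(rand(3, 3)),
     upper_weight = UpperTriangular(rand(3, 3)))
layers = (layer_1 = l, layer_2 = l, layer_3 = l)
input = rand(3)

feedforward_rec(input, layers)
```

Tuple recursion like this works well for a handful of layers; for very long heterogeneous tuples, inference may eventually give up, so it is worth checking with `@code_warntype`.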

Another option is to iterate with `for L in NamedTuple(ps)`. This seems to make `@code_warntype` slightly happier than the `valkeys` solution, but it doesn't change the `@btime` performance (time or allocations) either way.
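Once the layers are available as a NamedTuple, the loop can also be written as a fold over the values, which makes the data flow explicit. A minimal sketch using only Base and LinearAlgebra; `feedforward_fold` is a hypothetical name, and whether `foldl` infers better or worse than the plain `for` loop on a given Julia version is something to check with `@code_warntype` rather than assume.

```julia
using LinearAlgebra

# Thread the running activation `fwd` through every layer, in key order.
feedforward_fold(x::AbstractVector, layers::NamedTuple) =
    foldl((fwd, L) -> L.lower_weight * L.upper_weight * fwd, values(layers); init = x)

# Small deterministic example: lower * upper == [1 1; 1 2].
lt = (lower_weight = LowerTriangular([1.0 0.0; 1.0 1.0]),
      upper_weight = UpperTriangular([1.0 1.0; 0.0 1.0]))

feedforward_fold([1.0, 0.0], (layer_1 = lt, layer_2 = lt))  # [2.0, 3.0]
```

With a ComponentVector, this would become `feedforward_fold(input, NamedTuple(layers_CA))`, assuming `NamedTuple(ps)` returns the layers as described in the reply above.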