I’m struggling with iterating over the named fields of a ComponentArray in a compiler-friendly way.
Take the following MWE, emulating some Neural Network:
- I create the nested structure `layers`, containing some weight matrices. Each is itself a `NamedTuple` with a lower-triangular and an upper-triangular component.
- I feed it to the function `feedforward`, which applies each layer to the input in sequence.
When `layers` is a `NamedTuple`, I can simply iterate over its elements and it works relatively well:

**Code iterating over a NamedTuple**
```julia
using NamedTupleTools, LinearAlgebra, ComponentArrays

function create_layers(size, number)
    l = (lower_weight = LowerTriangular(rand(size, size)),
         upper_weight = UpperTriangular(rand(size, size)))
    layer_keys = Symbol.(["layer_$i" for i in 1:number])
    layer_data = [l for _ in 1:number]
    layers = namedtuple(layer_keys, layer_data)
    return layers
end

function feedforward(x::AbstractArray, layers)
    fwd = x
    for L in layers
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end

layers = create_layers(3, 3)
input = rand(3)
feedforward(input, layers) # works OK
```
However, when I turn `layers` into a `ComponentArray`, iterating over it now yields its individual scalar elements instead of the named layers, breaking everything:

**Incorrect attempt with ComponentArrays**
```julia
layers_CA = ComponentArray(layers)
feedforward(input, layers_CA) # ERROR: type Float64 has no field lower_weight
```
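Just to confirm what's happening, a small standalone check shows that iterating a `ComponentArray` walks the flat underlying data, not the named components:

```julia
using ComponentArrays

# Minimal example: iteration sees the raw Float64 entries,
# not the named sub-components.
ca = ComponentArray(a = [1.0, 2.0], b = 3.0)
first(ca)   # 1.0 — a plain Float64, not the `a` component
keys(ca)    # (:a, :b) — the components are still there, just not what iteration yields
```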
A possible workaround is to collect the keys of `layers_CA` and iterate over those, but then the compiler can't really prepare for what's coming and the result is type instability:

**Code iterating over keys**
```julia
function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in keys(ps)
        L = ps[k]
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end

feedforward(input, layers_CA) # works
@code_warntype feedforward(input, layers_CA)
```
This results in a return type of `Any`:

**Result of @code_warntype**
```
MethodInstance for feedforward(::Vector{Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}})
  from feedforward(x::AbstractArray, ps::ComponentArray) @ Main Untitled-4:44
Arguments
  #self#::Core.Const(feedforward)
  x::Vector{Float64}
  ps::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:18, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_2 = ViewAxis(19:36, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))), layer_3 = ViewAxis(37:54, Axis(lower_weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())), upper_weight = ViewAxis(10:18, ShapedAxis((3, 3), NamedTuple())))))}}}
Locals
  @_4::Union{Nothing, Tuple{Symbol, Int64}}
  fwd::Any
  k::Symbol
  L::Any
Body::Any
1 ─ (fwd = x)
│   %2 = Main.keys(ps)::Core.Const((:layer_1, :layer_2, :layer_3))
│   (@_4 = Base.iterate(%2))
│   %4 = (@_4::Core.Const((:layer_1, 2)) === nothing)::Core.Const(false)
│   %5 = Base.not_int(%4)::Core.Const(true)
└── goto #4 if not %5
2 ┄ %7 = @_4::Tuple{Symbol, Int64}
│   (k = Core.getfield(%7, 1))
│   %9 = Core.getfield(%7, 2)::Int64
│   (L = Base.getindex(ps, k))
│   %11 = Base.getproperty(L, :lower_weight)::Any
│   %12 = Base.getproperty(L, :upper_weight)::Any
│   (fwd = %11 * %12 * fwd)
│   (@_4 = Base.iterate(%2, %9))
│   %15 = (@_4 === nothing)::Bool
│   %16 = Base.not_int(%15)::Bool
└── goto #4 if not %16
3 ─ goto #2
4 ┄ return fwd
```
How should I go about this? How can I write a function that knows, at compile time, which keys it will iterate over, without hand-coding a method for each different layout of `layers`? I've seen people do this with metaprogramming, but I don't feel experienced enough to use `@generated` functions and the like.
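One direction I came across in the ComponentArrays documentation is `valkeys`, which returns the keys wrapped in `Val`s so that the key set lives in the type domain and indexing can be inferred. A sketch of what I mean (I'm not sure this is the idiomatic pattern, hence the question):

```julia
using ComponentArrays, LinearAlgebra

# Same feedforward as above, but iterating over Val-wrapped keys.
# valkeys(ps) returns a tuple like (Val(:layer_1), Val(:layer_2), ...),
# which is known at compile time, so ps[k] can be type-inferred.
function feedforward(x::AbstractArray, ps::ComponentArray)
    fwd = x
    for k in valkeys(ps)
        L = ps[k]
        fwd = L.lower_weight * L.upper_weight * fwd
    end
    return fwd
end
```

Is something like this the intended pattern here, or is there a cleaner way?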
I need this for a few custom Lux layers I'm constructing, which need to translate easily into `ComponentArray`s to be compatible with NeuralPDE.