Type-stable function with Flux Chain and Dense

Dear All,

I’m trying to make the following function, which builds a Flux Chain of Dense layers, type-stable.

using Flux

@kwdef mutable struct Fnn{T <: Function} 
    neuron::Int = 5
    layer::Int = 3
    σ::T = Flux.relu
end

function model_build_1(nn::Fnn, n_in::Int, n_out::Int)

    inner_layer = []

    for i = 1:1:nn.layer
        push!(inner_layer, Flux.Dense(nn.neuron, nn.neuron, nn.σ))
    end

    y = Flux.Chain(
        fnn_input = Flux.Dense(n_in, nn.neuron, bias = false),
        fnn_inner = Flux.Chain(inner_layer...),
        fnn_output = Flux.Dense(nn.neuron, n_out, bias = false),
    )
    
    return y
end


network_def = Fnn()
@code_warntype model_build_1(network_def, 4, 1) # Christmas Garland


I did some trial and error with:

function model_build_2(nn::Fnn, n_in::Int, n_out::Int)

    inner_layer = []

    push!(inner_layer, Flux.Dense(n_in, nn.neuron, bias = false))
    for i = 1:1:nn.layer
        push!(inner_layer, Flux.Dense(nn.neuron, nn.neuron, nn.σ))
    end
    push!(inner_layer, Flux.Dense(nn.neuron, n_out, bias = false))

    y = Flux.Chain(inner_layer)
    
    return y
end

@code_warntype model_build_2(network_def, 4, 1) 

There are two problems, highlighted in red, that I cannot solve:

  %11 = Core.kwcall(%8, %9, n_in, %10)::Dense{typeof(identity), Matrix{Float32}}

and

  %30 = (%26)(%27, %28, %29)::Dense{typeof(relu), Matrix{Float32}}
│         (%24)(%25, %30)
│         (@_5 = Base.iterate(%15, %23))

Does anyone have any ideas?
Regards

I think this can’t work with the given struct. The type of the model you construct includes a Tuple of 3 Dense layers, and this 3 is an ordinary runtime value in the input Fnn(), so the compiler cannot deduce the tuple length from the argument types alone.

julia> network_def = Fnn()  # 3 is number of inner Dense layers...
Fnn{typeof(relu)}(5, 3, relu)

julia> model = model_build_1(network_def, 4, 1) 
Chain(
  fnn_input = Dense(4 => 5; bias=false),  # 20 parameters
  fnn_inner = Chain(
    Dense(5 => 5, relu),                # 30 parameters
    Dense(5 => 5, relu),                # 30 parameters
    Dense(5 => 5, relu),                # 30 parameters
  ),
  fnn_output = Dense(5 => 1; bias=false),  # 5 parameters
)                   # Total: 8 arrays, 115 parameters, 900 bytes.

julia> model[2].layers  # ... and that 3 gives Tuple{Dense, Dense, Dense}
(Dense(5 => 5, relu), Dense(5 => 5, relu), Dense(5 => 5, relu))

You could fix this by changing Fnn to store the 3 in its type, for instance by using Val(3), and then replacing the inner_layer = []; push!(... code with ntuple(i -> Dense(nn.neuron, nn.neuron, nn.σ), Val(3)).
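
As a rough sketch of that suggestion (not a definitive implementation): the struct name FnnN, the keyword constructor, and the function name model_build_val are hypothetical, and the layer count is stored as a type parameter N rather than as a Val-typed field. Whether @code_warntype comes out completely clean may also depend on the Flux version.

```julia
using Flux

# Hypothetical variant of Fnn: the number of inner layers N lives in the
# type, so the compiler knows it when specializing model_build_val.
struct FnnN{N, T <: Function}
    neuron::Int
    σ::T
end

# Convenience keyword constructor (an assumption, not from the original post).
# Turning the runtime value `layer` into a type parameter happens once, here.
FnnN(; neuron::Int = 5, layer::Int = 3, σ = Flux.relu) =
    FnnN{layer, typeof(σ)}(neuron, σ)

function model_build_val(nn::FnnN{N}, n_in::Int, n_out::Int) where {N}
    # ntuple with Val(N) returns a tuple whose length is known at compile time.
    inner = ntuple(_ -> Flux.Dense(nn.neuron => nn.neuron, nn.σ), Val(N))
    return Flux.Chain(
        fnn_input = Flux.Dense(n_in => nn.neuron; bias = false),
        fnn_inner = Flux.Chain(inner...),
        fnn_output = Flux.Dense(nn.neuron => n_out; bias = false),
    )
end

model = model_build_val(FnnN(), 4, 1)
```

Note that FnnN(layer = 4) and FnnN(layer = 3) are now different types, so each distinct layer count triggers a fresh compilation of model_build_val.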

However, I’d start by questioning whether you want to. It’s important that code which runs in a tight loop be type-stable. But for code that basically runs once, like setup and plotting, it doesn’t matter at all. If you really want to run model_build_1 many times in a loop, can you explain why? Could you instead re-use the same model?
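
To illustrate the point about setup code versus hot loops, here is a minimal sketch of the usual "function barrier" pattern. The names build_unstable and sum_outputs are made up for this example. The builder is type-unstable in the same way as model_build_1, but the object it returns still has a concrete runtime type, so the function that does the repeated work is compiled for that type.

```julia
using Flux

# Builder whose return type depends on the runtime value n_layers,
# so the compiler cannot infer it (type-unstable, like model_build_1).
function build_unstable(n_layers::Int)
    layers = []
    for _ in 1:n_layers
        push!(layers, Flux.Dense(5 => 5, Flux.relu))
    end
    return Flux.Chain(layers...)
end

# Function barrier: Julia compiles a specialized method for the concrete
# type of `model`, so the loop below does not suffer from the builder's
# instability.
function sum_outputs(model, x, n_calls::Int)
    total = 0.0f0
    for _ in 1:n_calls
        total += sum(model(x))
    end
    return total
end

model = build_unstable(3)            # runs once, instability is harmless here
x = rand(Float32, 5, 8)
result = sum_outputs(model, x, 100)  # the hot loop sits behind the barrier
```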

Thank you very much for your answer.

The solution removes the red flags in @code_warntype. However, when I try to measure the gain with @time, I get an error:

julia> @time model_build_1(network_def, 4, 1) 
ERROR: MethodError: no method matching (::Colon)(::Int64, ::Int64, ::Val{3})
The function `Colon()` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::Colon)(::T, ::T, ::T) where T<:Real
   @ Base range.jl:22
  (::Colon)(::T, ::Any, ::T) where T<:Real
   @ Base range.jl:50
  (::Colon)(::T, ::T) where T<:Real
   @ Base range.jl:5
  ...

In addition, the function can be called multiple times, for example, when performing a meta-heuristic optimization, where we modify the activation function, the number of layers, and the number of neurons.

Somewhere you are doing something like 1:2:Val(3) instead of 1:2:3, but we can’t see where.
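
The signature in the error, (::Colon)(::Int64, ::Int64, ::Val{3}), is what a range such as 1:1:nn.layer would produce if nn.layer now holds Val(3), so a leftover loop of that form is one possible culprit (a guess, since the modified code was not posted). A small sketch of the failure and two ways around it, where unwrap is a hypothetical helper:

```julia
# Hypothetical helper: recover the integer stored in a Val's type parameter.
unwrap(::Val{N}) where {N} = N

layer = Val(3)

# A range ending in a Val throws the MethodError shown above, since Colon
# has no method accepting a Val argument.
threw = try
    1:1:layer
    false
catch err
    err isa MethodError
end

r = 1:1:unwrap(layer)         # works: an ordinary range ending at 3
t = ntuple(i -> i^2, layer)   # ntuple accepts the Val directly, no loop needed
```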