Dear All,
I'm trying to make the following function, which builds a model from Flux `Chain` and `Dense` layers, type-stable.
using Flux
"""
    Fnn(; neuron = 5, layer = 3, σ = Flux.relu)

Mutable container for the hyper-parameters of a fully connected network.

- `neuron`: width used for every hidden `Dense` layer.
- `layer`:  number of hidden `neuron × neuron` layers.
- `σ`:      activation applied in the hidden layers.

The activation's type is a type parameter, so `σ` is stored with a concrete type.
"""
Base.@kwdef mutable struct Fnn{T<:Function}
    neuron::Int = 5   # hidden-layer width
    layer::Int = 3    # number of hidden layers
    σ::T = Flux.relu  # hidden-layer activation
end
"""
    model_build_1(nn::Fnn, n_in::Int, n_out::Int)

Build a `Flux.Chain` with three named stages:

- `fnn_input`:  `Dense(n_in, nn.neuron)` with `bias = false`,
- `fnn_inner`:  a `Chain` of `nn.layer` layers `Dense(nn.neuron, nn.neuron, nn.σ)`,
- `fnn_output`: `Dense(nn.neuron, n_out)` with `bias = false`.

All hidden layers have the same concrete type. They are therefore collected with a
comprehension into a concretely typed `Vector` and wrapped in a vector-backed `Chain`.
Splatting an untyped `[]` (a `Vector{Any}`) into `Chain(...)` would produce a tuple
whose length, and hence type, depends on the runtime value `nn.layer`. Inference
cannot determine that from the argument types alone.

NOTE(review): `@code_warntype` may still report a non-concrete `Dense{...}` for the
individual layer constructors. The pasted warntype output shows this for the
`bias = false` call. This concerns model construction only. Code that consumes the
returned model specializes on its concrete runtime type once the model is passed
as a function argument (a function barrier).
"""
function model_build_1(nn::Fnn, n_in::Int, n_out::Int)
    # Hidden layers are built first, keeping the original construction order
    # (hidden → input → output) in case initialisation relies on a seeded RNG.
    hidden = [Flux.Dense(nn.neuron, nn.neuron, nn.σ) for _ in 1:nn.layer]
    return Flux.Chain(
        fnn_input = Flux.Dense(n_in, nn.neuron, bias = false),
        fnn_inner = Flux.Chain(hidden),
        fnn_output = Flux.Dense(nn.neuron, n_out, bias = false),
    )
end
# Default hyper-parameters, then inspect what inference sees for the nested-Chain builder.
network_def = Fnn()
@code_warntype(model_build_1(network_def, 4, 1))  # red entries mark non-concrete inferred types
After some trial and error, I also tried:
"""
    model_build_2(nn::Fnn, n_in::Int, n_out::Int)

Build a flat, vector-backed `Flux.Chain` with these layers in order:

1. `Dense(n_in, nn.neuron)` with `bias = false`,
2. `nn.layer` layers `Dense(nn.neuron, nn.neuron, nn.σ)`,
3. `Dense(nn.neuron, n_out)` with `bias = false`.

NOTE(review): the first and last layers (no bias, default activation) and the hidden
layers (activation `nn.σ`) have different types. That is why they are kept in a
`Vector{Any}`. As a consequence, each layer call inside the returned `Chain` is
dynamically dispatched, and `@code_warntype` reports non-concrete types for this
builder.
"""
function model_build_2(nn::Fnn, n_in::Int, n_out::Int)
    layers = Any[Flux.Dense(n_in, nn.neuron, bias = false)]
    for _ in 1:nn.layer
        push!(layers, Flux.Dense(nn.neuron, nn.neuron, nn.σ))
    end
    push!(layers, Flux.Dense(nn.neuron, n_out, bias = false))
    return Flux.Chain(layers)
end
@code_warntype(model_build_2(network_def, 4, 1))  # inspect inference for the flat-vector builder
There are two red-highlighted problems that I cannot solve:
%11 = Core.kwcall(%8, %9, n_in, %10)::Dense{typeof(identity), Matrix{Float32}}
and
%30 = (%26)(%27, %28, %29)::Dense{typeof(relu), Matrix{Float32}}
│ (%24)(%25, %30)
│ (@_5 = Base.iterate(%15, %23))
Does anyone have any ideas?
Regards