Hi. I am stumped as to why the forward pass is type unstable. It’s effectively a varying length chain similar to Flux.Chain on an alternating sequence of dense and dropout layers:
struct MLP
dense::Vector{Flux.Dense}
drop::Vector{Flux.Dropout}
end
function MLP(layer_dims::Vector{Int}, dropout::Bool=true, activation=tanh)
dense = Flux.Dense[]
drop = Flux.Dropout[]
for i=1:length(layer_dims)-1
if dropout && i > 1
push!(drop, Flux.Dropout(0.5))
end
if i < length(layer_dims)
push!(dense, Flux.Dense(layer_dims[i]=>layer_dims[i+1], activation))
else
push!(dense, Flux.Dense(layer[i] => layer[i+1]))
end
end
end
function (mlp::MLP)(x::Matrix{Float32}) # Forward pass
temp = x
for i=1:length(mlp.drop)
temp = mlp.dense[i](temp)
temp = mlp.drop[i](temp)
end
mlp.dense[i](temp)
mlp.dense[end](temp)
temp
end
mlp = SR.MLP([2,2,])
x_fake = rand(Float32, 2, 100)
@code_warntype mlp(x_fake)
The lowered rep shows that temp is unstable:
MethodInstance for ()
Arguments
mlp::MLP
x::Matrix{Float32}
Locals
@_3::Union{Nothing, Tuple{Int64, Int64}}
temp::Any
i::Int64
Body::Any
1 ─ (temp = x)
│ %2 = Base.getproperty(mlp, :drop)::Vector{Flux.Dropout}
│ %3 = length(%2)::Int64
│ %4 = (1:%3)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│ (@_3 = Base.iterate(%4))
│ %6 = (@_3 === nothing)::Bool
│ %7 = Base.not_int(%6)::Bool
└── goto #4 if not %7
2 ┄ %9 = @_3::Tuple{Int64, Int64}
│ (i = Core.getfield(%9, 1))
│ %11 = Core.getfield(%9, 2)::Int64
│ %12 = Base.getproperty(mlp, :dense)::Vector{Flux.Dense}
│ %13 = Base.getindex(%12, i)::Flux.Dense
│ (temp = (%13)(temp))
│ %15 = Base.getproperty(mlp, :drop)::Vector{Flux.Dropout}
│ %16 = Base.getindex(%15, i)::Flux.Dropout
│ (temp = (%16)(temp))
│ (@_3 = Base.iterate(%4, %11))
│ %19 = (@_3 === nothing)::Bool
│ %20 = Base.not_int(%19)::Bool
└── goto #4 if not %20
3 ─ goto #2
4 ┄ %23 = Base.getproperty(mlp, :dense)::Vector{Flux.Dense}
│ %24 = Base.getindex(%23, i)::Any
│ (%24)(temp)
│ %26 = Base.getproperty(mlp, :dense)::Vector{Flux.Dense}
│ %27 = Base.lastindex(%26)::Int64
│ %28 = Base.getindex(%26, %27)::Flux.Dense
│ (%28)(temp)
└── return temp
I even tried evaluating with one Dense layer and the output it still Any. Is there anything I can do about this or is it not a problem?