Hi,
I’m pretty new to Julia, so please be nice :)
I wrote a simple script to identify MNIST digits.
So far it runs okay, and I like the language a lot.
But what I can’t get my head around are the type instabilities I’m trying to fix.
When I call this function:
```julia
using Flux

function create_A_model()
    Chain(
        Dense(784, 512, relu),
        Dense(512, 10),
        softmax)
end
```
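This is the output from `@code_warntype` (for my `create_model(config::TrainingConfig)` version, which reads the layer sizes from a config struct instead of hard-coding them):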
```
Arguments
  #self#::Core.Const(create_model)
  config::TrainingConfig
Body ::Chain
1 ─ %1 = Base.getproperty(config, :input_size)::Int64
│   %2 = Base.getproperty(config, :hidden_size)::Int64
│   %3 = Main.Dense(%1, %2, Main.relu) ::Dense{typeof(relu), Matrix{Float32}}
│   %4 = Base.getproperty(config, :hidden_size)::Int64
│   %5 = Base.getproperty(config, :output_size)::Int64
│   %6 = Main.Dense(%4, %5)::Dense{typeof(identity), Matrix{Float32}}
│   %7 = Main.Chain(%3, %6, Main.softmax)::Chain
└──      return %7
```
The parts shown in red are `Body ::Chain`, the `::Chain` on `%7`, and the two `::Dense{...}` types on `%3` and `%6`.
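For context, `code_warntype` prints an inferred type in red when it is not concrete. A minimal pure-Julia sketch of that distinction, using a hypothetical `Layers` struct as a stand-in for a parametric container such as `Chain` (it is not Flux's actual definition):

```julia
# Hypothetical stand-in for a parametric container like Flux's Chain.
struct Layers{T<:Tuple}
    layers::T
end

# The bare name leaves T unspecified, so it is not a concrete type
# (comparable to a plain `::Chain` in code_warntype output).
isconcretetype(Layers)                           # false

# With every parameter filled in, the type is concrete.
isconcretetype(Layers{Tuple{Int64, Float64}})    # true

# When inference knows the argument types, the constructor's
# return type is the fully parameterised, concrete type.
make_layers(n::Int64) = Layers((n, 1.0))
only(Base.return_types(make_layers, (Int64,)))   # Layers{Tuple{Int64, Float64}}
```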
What’s even more confusing to me is that if I reduce the function to
```julia
function create_model()
    Chain(Dense(1, 1))
end
```
I can only post one picture, so I left out the full output here, but `::Chain` and `::Dense{typeof(identity), Matrix{Float32}}` are still red.
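In the `train_network` signature further down, the fully spelled-out layer type has three parameters, `Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}`, while the red annotation shows only two. A sketch with a hypothetical three-parameter `Affine` struct (mirroring that layout, not Flux's actual code) shows that such a partially parameterised type is a `UnionAll` over the missing parameter and therefore not concrete:

```julia
# Hypothetical layer type with three parameters, mirroring the layout
# Dense{activation type, weight type, bias type} written out in the post.
struct Affine{F, M<:AbstractMatrix, B}
    σ::F
    weight::M
    bias::B
end

# Leaving out the last (bias) parameter gives a UnionAll:
# Affine{typeof(identity), Matrix{Float32}, B} where B
const Partial = Affine{typeof(identity), Matrix{Float32}}
Partial isa UnionAll          # true
isconcretetype(Partial)       # false

# Specifying all three parameters gives a concrete type.
const Full = Affine{typeof(identity), Matrix{Float32}, Vector{Float32}}
isconcretetype(Full)          # true

# A value built from concrete fields has the fully parameterised type.
a = Affine(identity, zeros(Float32, 2, 2), zeros(Float32, 2))
typeof(a) == Full             # true
```

This only illustrates what the two-parameter display means at the type level; it does not show why inference stops at that point inside Flux's constructor.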
I have tried using structs, but I can’t seem to find answers in the documentation.
The other type instabilities I can’t get rid of are in my training function.
Originally it had more functionality, but I stripped it down to pinpoint the problem.
```julia
function train_network(
        model::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}},
        train_data::DataLoader{Tuple{Matrix{Float32}, Flux.OneHotMatrix{UInt32, Vector{UInt32}}}},
        opt::Descent)
    epochstest::Int64 = 5
    for epoch::Int64 in 1:epochstest::Int64
        for (input::Matrix{Float32}, labels::Flux.OneHotMatrix{UInt32, Vector{UInt32}}) in train_data
            # loss_function(input, labels, model) is defined elsewhere in my script
            Flux.train!((input, labels) -> loss_function(input, labels, model), Flux.params(model), [(input, labels)], opt)
        end
    end
end
```
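On the long `model::Chain{...}` annotation in `train_network`: Julia compiles a method for the concrete types of the arguments it is actually called with, whether or not the signature spells them out, so argument annotations generally affect dispatch rather than inference. A small sketch with two hypothetical functions (`total_annotated` and `total_plain` are made-up names) shows both getting the same inferred return type:

```julia
# Two hypothetical functions that differ only in their argument annotation.
total_annotated(x::Vector{Float32}) = sum(x)
total_plain(x) = sum(x)

# Inference sees the same concrete argument type in both cases,
# so both return types are inferred as Float32.
only(Base.return_types(total_annotated, (Vector{Float32},)))  # Float32
only(Base.return_types(total_plain, (Vector{Float32},)))      # Float32
```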
All other functions in my code are type-stable according to `code_warntype`.
If I understand the output correctly, `_5` is partially unstable.
Are `_5`, `_7` and `_10` internal variables generated by the compiler? How can I make them type-stable?
Or are they unstable because the `Chain` from the earlier function is unstable?
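On `_5`, `_7` and `_10`: `code_warntype` lists compiler-generated slots alongside named variables, and one common source of them is the hidden state of a `for` loop. Following the iteration interface in the Julia manual, a loop `for x in itr` behaves like the `while` loop below, and `iterate` returns either `nothing` or a `(value, state)` tuple, which inference reports as a small `Union`. Whether that is exactly what these particular slots hold here is an assumption, and `manual_sum` is a hypothetical name:

```julia
# What `for x in itr; s += x; end` does, written out with the
# iteration protocol (integer elements assumed, so `s` stays an Int).
function manual_sum(itr)
    s = 0
    next = iterate(itr)            # nothing, or (first value, state)
    while next !== nothing
        (x, state) = next
        s += x
        next = iterate(itr, state) # nothing, or (next value, new state)
    end
    return s
end

manual_sum(1:4)   # 10

# iterate's result is a small Union of Nothing and a tuple; code_warntype
# typically shows such Unions in yellow rather than red.
only(Base.return_types(iterate, (UnitRange{Int64},)))
# Union{Nothing, Tuple{Int64, Int64}}
```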
I would be really glad for any helpful input. I’m not getting any closer to a solution no matter what I try, and I’ve been at this for too long.
If I can provide more information or more code, I’m happy to do so.
Thanks,
l4s