Solving type instabilities in simple classification model

Hi,
I’m pretty new to Julia, so please be nice:)

I wrote a simple script to classify MNIST digits.
So far it runs okay, and I like the language a lot.
But what I can’t wrap my head around are the type instabilities I’m trying to solve.

When I call this function:

```julia
using Flux

function create_A_model()
    Chain(
        Dense(784, 512, relu),
        Dense(512, 10),
        softmax)
end
```

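This is the output from `@code_warntype`. It is for a variant, `create_model(config::TrainingConfig)`, that takes the layer sizes from a config struct: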

```
Arguments
  #self#::Core.Const(create_model)
  config::TrainingConfig
Body::Chain
1 ─ %1 = Base.getproperty(config, :input_size)::Int64
│   %2 = Base.getproperty(config, :hidden_size)::Int64
│   %3 = Main.Dense(%1, %2, Main.relu)::Dense{typeof(relu), Matrix{Float32}}
│   %4 = Base.getproperty(config, :hidden_size)::Int64
│   %5 = Base.getproperty(config, :output_size)::Int64
│   %6 = Main.Dense(%4, %5)::Dense{typeof(identity), Matrix{Float32}}
│   %7 = Main.Chain(%3, %6, Main.softmax)::Chain
└──      return %7
```
I marked the red parts in bold

What’s even more confusing to me is that if I reduce the function to

```julia
function create_model()
    Chain(Dense(1, 1))
end
```

I can only post one picture, so I left that output out here, but `::Chain` and `::Dense{typeof(identity), Matrix{Float32}}` are still red.
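`@code_warntype` prints non-concrete types in red. A parametric type written without all of its parameters, such as a bare `Chain`, is a `UnionAll` and therefore not concrete. The type of an actual value, by contrast, always has every parameter filled in. The printed `Dense{typeof(identity), Matrix{Float32}}` has one parameter fewer than the three-parameter `Dense{…, Matrix{Float32}, Vector{Float32}}` used in the training function’s signature further down, which is presumably why it shows up red too. The sketch uses hypothetical stand-in structs `MyDense` and `MyChain` (not Flux’s real definitions) so that it runs without Flux:

```julia
# Hypothetical stand-ins that only mimic the shape of Flux's Dense and Chain.
struct MyDense{F, M<:AbstractMatrix, B}
    act::F   # activation function
    W::M     # weight matrix
    b::B     # bias vector
end

struct MyChain{T<:Tuple}
    layers::T
end
MyChain(xs...) = MyChain(xs)   # wrap the given layers in a Tuple

layer = MyDense(identity, rand(Float32, 2, 3), zeros(Float32, 2))
model = MyChain(layer, identity)

# Types with open parameters are UnionAlls, i.e. not concrete:
println(isconcretetype(MyChain))                                    # false
println(isconcretetype(MyDense{typeof(identity), Matrix{Float32}})) # false (bias type B left open)

# The runtime types of actual values have every parameter filled in:
println(isconcretetype(typeof(layer)))   # true
println(isconcretetype(typeof(model)))   # true
```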

I have tried using structs, and I can’t find any answers in the documentation.

The other type instabilities I can’t get rid of are in my training function.
Originally it had more functionality, but I stripped it down to pinpoint the problem.

```julia
function train_network(
    model::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}},
    train_data::DataLoader{Tuple{Matrix{Float32}, Flux.OneHotMatrix{UInt32, Vector{UInt32}}}},
    opt::Descent)
    epochstest::Int64 = 5
    for epoch::Int64 in 1:epochstest::Int64
        for (input::Matrix{Float32}, labels::Flux.OneHotMatrix{UInt32, Vector{UInt32}}) in train_data
            Flux.train!((input, labels) -> loss_function(input, labels, model), Flux.params(model), [(input, labels)], opt)
        end
    end
end
```

All other functions in my code are type-stable according to `@code_warntype`.
If I understand the output correctly, `_5` is partially unstable.
Are `_5`, `_7` and `_10` internal variables? How could I make them type-stable?
Or are they unstable because the `Chain` in the earlier function is unstable?
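The `@inferred` macro from the `Test` standard library offers a way to check inferrability without reading `@code_warntype` output by eye. It returns the call’s value if the inferred return type matches the actual one and throws an error otherwise. It only checks the return type, not every intermediate variable. Here is a minimal sketch with two hypothetical functions:

```julia
using Test

# Two hypothetical functions for illustration.
stable(x::Vector{Float32}) = sum(abs2, x)            # always returns a Float32
unstable(x::Vector{Float32}) = sum(x) > 0 ? 1 : 1.0  # Int64 or Float64, depending on the data

x = Float32[1, 2, 3]

# Passes and returns the value, because the inferred type (Float32) matches.
v = @inferred stable(x)

# Fails: the inferred type is a Union of Float64 and Int64, the actual one is Int64.
caught = try
    @inferred unstable(x)
    false
catch e
    e isa ErrorException
end
println(caught)   # true
```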

I would be really glad for any helpful input, because I’m not getting any closer to a solution no matter what I try, and I’ve been at it for too long :sweat_smile:

If I can provide any more information or code, I’m happy to do so.

Thanks
l4s


First things first: could you try removing all the type annotations from your code? It looks way overtyped and is hard to read :sweat_smile:

Julia can infer types. Usually it does a better job than you can by annotating them, so we typically add only the bare minimum of annotations needed to make sure the inputs are correct. Adding more annotations doesn’t make the code any faster, and it is a frequent source of bugs. Here’s one way to rewrite this:

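```julia
using Flux
using Flux: DataLoader

# `opt` is left unannotated: `Descent` is not a subtype of `Optimiser`
# (Flux's wrapper for chaining optimisers), so `opt::Optimiser` would reject it.
function train_network(model::Chain, train_data::DataLoader, opt)
    epochstest = 5
    for epoch in 1:epochstest
        for (input, labels) in train_data
            Flux.train!(
                (input, labels) -> loss_function(input, labels, model),
                Flux.params(model),
                [(input, labels)],
                opt
            )
        end
    end
end
```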

Now that the code is cleaner and more readable, can you spot any mistakes, and is there still a type instability?
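The following minimal sketch illustrates the point about annotations with two hypothetical functions. For a `Float64` argument, inference reaches the same concrete return type whether or not the argument is annotated, because Julia specializes the unannotated method on `Float64` anyway. The over-annotated version, however, rejects the `Float32` values that Flux uses by default:

```julia
# Two hypothetical versions of the same function.
add_one(x) = x + one(x)             # unannotated, generic
add_one_f64(x::Float64) = x + 1.0   # over-annotated

# For a Float64 argument, both infer the same concrete return type.
println(Base.return_types(add_one, (Float64,)))
println(Base.return_types(add_one_f64, (Float64,)))

# The generic version also accepts Float32...
y = add_one(1.5f0)   # 2.5f0, still a Float32

# ...while the over-annotated one throws a MethodError.
err = try
    add_one_f64(1.5f0)
catch e
    e
end
println(err isa MethodError)   # true
```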

Welcome to the discourse!

Is there any particular reason you strive for type stability in these functions? Usually you want type stability when you need maximal performance in a piece of code, but neither of your functions seems to be a bottleneck. You probably create a model only once in a while, and `Flux.train!` likely does the heavy lifting for you.

Note that type instability (non-inferrability) is not contagious across function calls. If your `train_network` function is not fully inferrable, Julia has to do a dynamic dispatch for every call whose argument types it could not determine precisely. But once Julia has dispatched to the method for the concrete runtime types, type inference starts anew inside that method! The authors of Flux.jl have presumably taken care that things are inferrable on their side, so you still get good performance.
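Here is a minimal sketch of inference restarting inside a callee, using hypothetical names. `caller` pulls a value out of a `Vector{Any}`, so its own return type cannot be inferred. Yet `kernel` is fully inferrable for the concrete type it receives at run time:

```julia
# Hypothetical inner function: generic, and fully inferrable for concrete inputs.
kernel(x) = sum(abs2, x)

# Hypothetical outer function: an element of a Vector{Any} is only known as `Any`,
# so the call to `kernel` is resolved by dynamic dispatch at run time.
function caller(boxes::Vector{Any})
    data = boxes[1]
    return kernel(data)
end

boxes = Any[Float32[1, 2, 3]]

# The outer function's return type is not inferred concretely...
println(Base.return_types(caller, (Vector{Any},)))
# ...but inside `kernel`, inference works with the concrete argument type.
println(Base.return_types(kernel, (Vector{Float32},)))

println(caller(boxes))   # 14.0
```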

In fact, it is quite common for the “entrance layer” of a library to contain type-unstable code (DifferentialEquations.jl, for example, uses it to figure out which solver to use). Beyond that type-unstable setup layer, everything is inferrable, so you get both maximal flexibility and maximal performance. The Performance Tips section of the Julia manual calls this pattern a “function barrier”.
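The function-barrier pattern can be sketched as follows, with hypothetical names. The setup function picks the element type only at run time. The kernel behind the barrier is compiled separately for each concrete argument type, so its loop is type-stable:

```julia
# Hypothetical kernel behind the barrier: Julia compiles a separate
# specialization for each concrete element type T, so `acc` has a fixed type.
function sum_squares(xs::AbstractVector{T}) where {T<:Real}
    acc = zero(T)
    for x in xs
        acc += x * x
    end
    return acc
end

# Hypothetical setup layer: the element type is chosen at run time,
# so the type of `data` is not known concretely here.
function setup_and_run(precision::Symbol, n::Int)
    T = precision === :f32 ? Float32 : Float64
    data = ones(T, n)
    return sum_squares(data)   # the function barrier: the kernel sees a concrete type
end

println(setup_and_run(:f32, 4))   # 4.0 (a Float32)
println(setup_and_run(:f64, 4))   # 4.0 (a Float64)
```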


Thank you for your answers!

The overtyping was my attempt to solve the type-instabilities that the @code_warntype macro warned me about.
In the meantime I have read a lot about type instabilities in Julia and that their impact on performance is not always that big.

Furthermore, I found other ways to improve the performance of my code, and I’m quite happy now: training the network takes only about 1/3 of the time an equivalent PyTorch implementation needs.

That is why I had hoped to get a performance boost from achieving perfect type stability in the first place.

I marked your answer as the solution @abraemer


Curious: are you comparing against basic PyTorch or `@torch.compile`?

I am comparing basic PyTorch, Flux, and basic Julia (no advanced packages).

I played around with `torch.compile`, but it didn’t improve performance. As I understand it, the effect is most noticeable when using a GPU, with complex models, and when exporting the finished model to a different programming environment.

I’m using CPU only, the model is very basic…

Interesting, did you try asking for help on the PyTorch forums?

No, I haven’t, since it’s outside the scope of the study at hand, but it definitely is an interesting development for Python.
