Flux slows down by 10x when moving from a local system to a high-performance cluster

Hi everyone,

I am using Flux to solve a nonstandard optimisation problem with a custom loss.
The solution is part of a larger simulation that I run many times (100,000+), so speed is key. When I run my script on the high-performance cluster, the code slows down by more than 10 times.

I am using Julia v1.10.4 and Flux v0.14.19.
I have included a minimal working example with some magic numbers, since loading the actual parameters would bloat the script.

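On my local machine (OSX 14.5) I get the following for the first call (2 epochs, almost entirely compilation) and the second call (1000 epochs):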
8.243264 seconds (31.65 M allocations: 1.984 GiB, 7.36% gc time, 99.85% compilation time)
2.075706 seconds (1.22 M allocations: 6.835 GiB, 6.76% gc time)

On the HPC cluster (Linux 5.14.0-427.31.1.el9_4.x86_64) I get the following for the same two calls:
23.756881 seconds (30.78 M allocations: 1.924 GiB, 5.56% gc time, 99.32% compilation time)
20.209329 seconds (1.22 M allocations: 6.810 GiB, 0.66% gc time)

I wanted to upload the file as an attachment, but as a new user I have to paste the code into this text box.

I would appreciate any help or suggestions!

```julia
using Flux

network_width = 32

# Fully connected network: 6 state variables in, 3 outputs (c, k1, b1) out
perceptron = Chain(
    Dense(6, network_width, leakyrelu),
    Dense(network_width, network_width, leakyrelu),
    Dense(network_width, network_width, leakyrelu),
    Dense(network_width, network_width, leakyrelu),
    Dense(network_width, 3, relu)
  )

# Adam, then ClipValue applied to the resulting update
opt_state = Flux.Optimiser(Flux.Adam(1e-6), ClipValue(1e-5))
state = Flux.setup(opt_state, perceptron)

# Placeholder inputs ("magic numbers") standing in for the real parameters
k = Float32.(Vector(range(1,100,1000)))
b = Float32.(Vector(range(1,100,1000)))
w = Float32.(Vector(range(1,100,1000)))
r_k = fill(Float32(0.1), length(k))
r_b = fill(Float32(0.01), length(k))
p = fill(Float32(1), length(k))
pi_ = fill(Float32(0.01), length(k))

# Smooth approximation of abs(x) that is differentiable at 0
function abs_appr(x)
    y = sqrt.(x.^2 .+ Float32(1e-6))
    return y
end

function Residuals(perceptron, r_k, r_b, k, b, w, p, pi_, weights)
    n = size(w, 1)

    # Current inputs as a 6 × n matrix, one column per sample
    s = hcat(r_k, r_b, k, b, w, p)'
    x = perceptron(s)

    c  = x[1, :]
    k1 = x[2, :]
    b1 = x[3, :]

    d = k1 .- (1 .+ r_k) .* k

    # Next values of r_k, r_b, pi_ and w, each with random noise
    rknext = Float32.(max.(exp.(log.(1 .+ r_k) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1,0))
    rbnext = Float32.(exp.(log.(1 .+ r_b) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1)
    pinext = Float32.(exp.(log.(1 .+ pi_) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1)
    wnext = Float32.(w .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n))
    p1 = Float32.(p .* (1 .+ pinext))

    # Second forward pass on the next inputs
    s = hcat(rknext, rbnext, k1, b1, wnext, p1)'

    x = perceptron(s)
    c1  = x[1, :]
    k2  = x[2, :]

    d1 = k2 .- (1 .+ rknext) .* k1

    # The three residuals, combined below into a weighted mean squared loss
    R1 =  Float32.(1 .- 0.95 .* (1 .+ rbnext) .* (c1  ./ c ).^(-1.5) .* (p ./ p1))
    R2 =  Float32.(w .+ (1 .+ r_b) .* b .+ (1 .+ r_k) .* k .- c .* p .- b1 .- 0.01 .* abs_appr.(d).^1.5 .- k1)
    R3 =  Float32.(1 .+ d .* 0.01 .* 1.5 .* abs_appr.(d).^(1.5 - 2) .- 0.95 .* (1 .+ rknext) .* (c1 ./ c ).^(-1.5) .* (p ./ p1) .* (1 .+ d1 .* 0.01 .* 1.5 .* abs_appr.(d1).^(1.5 - 2)))

    R_squared = sum(weights[1] * R1.^2 + weights[2] *R2.^2 + weights[3] *R3.^2)/n

    return R_squared
end


function train_me!(epochs, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1,1,1])

    for epoch in 1:epochs
        # Compute the value and gradients of the loss function
        val, grads = Flux.withgradient(perceptron) do m
            Residuals(m, r_k, r_b, k, b, w, p, pi_, weights)
        end

        Flux.update!(state, perceptron, grads[1])
    end
end

# First call: 2 epochs, dominated by compilation
@time train_me!(2, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1, 0.1, 1]);

# Timed run: 1000 epochs
@time train_me!(1000, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1, 0.1, 1]);
```

I’m not an expert in Flux performance, but it seems you’re doing basically all the work in Float64 before converting everything to Float32 at the very end. Does anything change if you stick to Float32 throughout?
You could start by replacing every occurrence of 0.1 with 1f-1 and so on.
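A minimal illustration of the promotion being described, using only Base Julia: broadcasting a Float32 array against a Float64 literal such as `0.9` produces a Float64 array, whereas a Float32 literal such as `0.9f0` (or `1f-1` for 0.1) keeps the result in Float32. Integer literals, like the `1` in `1 .+ r_k`, do not promote. The array `x` is just a stand-in for one of the input vectors.

```julia
# Stand-in for one of the Float32 input vectors (e.g. r_k or w)
x = Float32[0.1, 0.2, 0.3]

y64 = x .* 0.9                  # Float64 literal: result is promoted to Float64
y32 = x .* 0.9f0                # Float32 literal: result stays Float32
z64 = 0.1 * randn(Float32, 3)   # same promotion for scalar * array
z32 = 1f-1 * randn(Float32, 3)  # 1f-1 is the Float32 literal for 0.1
i32 = 1 .+ x                    # Int literal: result stays Float32

println(eltype(y64), " ", eltype(y32))   # Float64 Float32
println(eltype(z64), " ", eltype(z32))   # Float64 Float32
```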

Hi! I don’t think it is due to the Float32 casting, as my original code passes the parameters as Float32. I have also implemented the same code with your suggested way of writing the magic numbers, and this did not affect performance.
That said, your suggestion taught me how to write a Float32 literal directly, so thank you for that!


Hi @SGHoekstra, I don’t know anything about Flux, but how are you running your code on the cluster? Are you running with just 1 CPU?

The difference may come from your laptop multithreading certain operations. For example, matrix multiplication with 1 vs 12 threads on my laptop:

julia> using LinearAlgebra

julia> a = rand(5000,5000);

julia> BLAS.set_num_threads(1)

julia> @elapsed a*a
4.6207308

julia> BLAS.set_num_threads(12)

julia> @elapsed a*a
1.5673504
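To see what a given process is actually using, a small report like the following can help. It is a sketch assuming Julia 1.6 or newer (for `BLAS.get_num_threads`); `report_threads` is a hypothetical helper, and `SLURM_CPUS_PER_TASK` is only set by SLURM inside a job that requests `--cpus-per-task`, so elsewhere it reads "unset". With the default OpenBLAS backend, exporting `OPENBLAS_NUM_THREADS=1` before Julia starts should have a similar effect to the `BLAS.set_num_threads(1)` call.

```julia
using LinearAlgebra

# Hypothetical helper: collect and print the thread-related settings
# of the current process.
function report_threads()
    info = (
        node = gethostname(),
        slurm_cpus = get(ENV, "SLURM_CPUS_PER_TASK", "unset"),
        cpu_threads = Sys.CPU_THREADS,        # logical CPUs Julia detected
        julia_threads = Threads.nthreads(),   # set with `julia -t N`
        blas_threads = BLAS.get_num_threads(),
    )
    println(info)
    return info
end

before = report_threads()

# Pin BLAS to a single thread, e.g. when one Julia process runs per core
BLAS.set_num_threads(1)
after = report_threads()
```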

May I ask why you use Flux instead of PyTorch?

Hi @p_f,

I am doing test runs now. The code will run in parallel on CPU nodes, as I will be doing Monte Carlo simulations (many independent runs).

The idea is to spawn independent Julia workers with the Distributed package and run the code in parallel. On my Mac I would run 12 processes at the same time. On the high-performance cluster I would run 128 processes at the same time, one per core, since each node has 128 cores.
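A minimal sketch of that setup with local workers from the Distributed standard library. Assumptions: workers are started with plain `addprocs` (on a SLURM cluster they might instead be launched through a cluster manager package, not shown), and `run_simulation` is a hypothetical placeholder for one independent run. The line that matters is pinning BLAS to one thread on every process: with one process per core, each process starting its own pool of BLAS threads would oversubscribe the cores.

```julia
using Distributed

# Start local worker processes. A 128-core node might use addprocs(128);
# 2 keeps the example small.
addprocs(2)

# Run on every process (master and workers): load the packages and
# pin BLAS to a single thread.
@everywhere using LinearAlgebra, Random
@everywhere BLAS.set_num_threads(1)

# Hypothetical stand-in for one independent Monte Carlo run; the real run
# would build and train the Flux model instead.
@everywhere function run_simulation(seed)
    rng = MersenneTwister(seed)
    W = randn(rng, Float32, 32, 32)
    X = randn(rng, Float32, 32, 1000)
    return sum(W * X)
end

# Distribute the independent runs over the workers
results = pmap(run_simulation, 1:16)

# Check the BLAS setting on each worker
blas_threads = [@fetchfrom w BLAS.get_num_threads() for w in workers()]
println("Runs: ", length(results), ", BLAS threads per worker: ", blas_threads)
```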

I experimented a bit with different numbers of threads and got the following results:

Running on node: fcn1
Number of processors allocated by SLURM: 16
Number of threads: 1
21.707765 seconds (31.16 M allocations: 1.916 GiB, 6.08% gc time, 99.87% compilation time)
2.532860 seconds (1.01 M allocations: 6.150 GiB, 4.12% gc time)
Number of threads: 2
3.861321 seconds (1.01 M allocations: 6.150 GiB, 3.04% gc time)
Number of threads: 4
4.540994 seconds (1.01 M allocations: 6.150 GiB, 2.26% gc time)
Number of threads: 8
5.932519 seconds (1.01 M allocations: 6.150 GiB, 2.47% gc time)
Number of threads: 16
7.079593 seconds (1.01 M allocations: 6.150 GiB, 1.41% gc time)
Number of threads: 32
11.259781 seconds (1.01 M allocations: 6.150 GiB, 0.86% gc time)

Apparently fewer threads are better in this case? Anyway, setting the number of BLAS threads to one makes the HPC about as fast as my local system. So thank you for your suggestion!
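One plausible reading of the sweep above, offered as a hypothesis rather than a diagnosis: the matrix products in this network are small (32 × 32 weights against a batch of 1000 columns), so the cost of starting and synchronising BLAS threads can outweigh any parallel speed-up. A rough way to check this on a given machine, assuming the Dense layers reduce to ordinary BLAS matrix products, is to time a product of that size at several thread counts. `time_small_matmul` is a hypothetical helper, and no particular timings are implied.

```julia
using LinearAlgebra

# Time `reps` in-place products of a width × width matrix with a
# width × batch matrix, using `nthreads` BLAS threads.
function time_small_matmul(nthreads; width = 32, batch = 1000, reps = 2_000)
    BLAS.set_num_threads(nthreads)
    W = randn(Float32, width, width)
    X = randn(Float32, width, batch)
    Y = similar(X)
    mul!(Y, W, X)              # warm-up call, excluded from the timing
    return @elapsed for _ in 1:reps
        mul!(Y, W, X)          # in-place product, no allocation
    end
end

original = BLAS.get_num_threads()
for n in (1, 2, 4, 8)
    println("BLAS threads: ", n, "  time: ", round(time_small_matmul(n), digits = 4), " s")
end
BLAS.set_num_threads(original)  # restore the previous setting
```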


The neural net is part of a larger simulation that is written entirely in Julia. I tried using Python and another programming language before and found them too unstable to be usable in this case. Hence I wanted a natively developed ML package and chose Flux.

For more explanation, see:
https://docs.julialang.org/en/v1/manual/performance-tips/#man-multithreading-linear-algebra
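The linked section concerns combining Julia's own threads with multithreaded BLAS. If the independent runs were parallelised with Julia threads rather than separate processes, a sketch along these lines would apply the same idea. Assumptions: Julia is started with `julia -t N`, and `one_run` is a hypothetical placeholder for one Monte Carlo run. BLAS is pinned to one thread so the two levels of threading do not compete for the same cores.

```julia
using LinearAlgebra
using Random

# One BLAS thread per Julia thread avoids oversubscribing the cores
BLAS.set_num_threads(1)

# Hypothetical stand-in for one independent Monte Carlo run
function one_run(seed)
    rng = MersenneTwister(seed)
    W = randn(rng, Float32, 32, 32)
    X = randn(rng, Float32, 32, 1000)
    return sum(W * X)
end

nruns = 64
results = Vector{Float32}(undef, nruns)
Threads.@threads for i in 1:nruns
    results[i] = one_run(i)    # each iteration writes only its own slot
end
println("Completed ", nruns, " runs on ", Threads.nthreads(), " Julia threads")
```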


Thanks. Makes sense.

Thanks! That explains a lot.

For a similar finding, please see this thread.
