# Speed up parallel maximum across columns

Hi Julia community!

I need help speeding up a function with parallelization. Here’s an MWE.

``````
using Distributed; n_cores = 3; addprocs(n_cores)
@everywhere using SharedArrays, BenchmarkTools
M = 10000; N = 10000; k = 2; mu = rand(N, 1)
U1     = SharedArray{Float32}(rand(M, N))
U2     = SharedArray{Float32}(rand(M, N))
idv_d  = SharedArray{Float32}(N, M)
d      = SharedArray{Float32}(M, 1)
function get_d!(U1, U2, mu, k, d, idv_d)
    fill!(idv_d, 0.0)
    # For each column j, record the demand of j at the row with the highest utility
    @inbounds @sync @distributed for j in 1:N
        u_1, arg_1 = findmax(@view U1[:, j])
        u_2, arg_2 = findmax(@view U2[:, j])
        if u_1 >= u_2
            idv_d[j, arg_1] = mu[j]
        else
            idv_d[j, arg_2] = mu[j] ./ k
        end
    end
    # Aggregate: d[j] is the sum of column j of idv_d
    @inbounds @sync @distributed for j in 1:M
        d[j] = sum(idv_d[:, j])
    end
end
@btime get_d!(U1, U2, mu, k, d, idv_d)
``````

`M` and `N` will both be large (on the order of 10^6), and I intend to run this on clusters, so `n_cores` will be about 50-100. The function calculates aggregate demand given two utility matrices (`U1`, `U2`). The only output I’m interested in is `d`, not `idv_d`, but I don’t see a fast way to get `d` with parallelization without calculating `idv_d` first.
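To make the target concrete, here is a serial sketch of the quantity described above, accumulating `d` directly instead of going through `idv_d`. The name `get_d_serial` is hypothetical and not part of the MWE; it is meant only as a reference for checking results, not as a tuned or parallel implementation, and it has not been benchmarked.

```julia
# Serial reference (hypothetical helper, not from the original MWE):
# for each column j, find the best row in U1 and U2 and add the demand
# of j straight into d, so the N×M matrix idv_d is never materialized.
function get_d_serial(U1, U2, mu, k)
    M, N = size(U1)
    d = zeros(Float32, M)
    for j in 1:N
        u_1, arg_1 = findmax(@view U1[:, j])
        u_2, arg_2 = findmax(@view U2[:, j])
        if u_1 >= u_2
            d[arg_1] += mu[j]
        else
            d[arg_2] += mu[j] / k
        end
    end
    return d
end

# Tiny example with M = 3 rows and N = 2 columns
U1_small = Float32[1 5; 9 2; 3 4]
U2_small = Float32[2 8; 1 1; 0 0]
mu_small = [10.0, 20.0]
d_small  = get_d_serial(U1_small, U2_small, mu_small, 2)
# d_small == Float32[10, 10, 0]
```

In the small example, column 1 picks row 2 of `U1` (9 >= 2) and contributes `mu[1] = 10`; column 2 picks row 1 of `U2` (5 < 8) and contributes `mu[2] / k = 10`.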

I have tried to optimize the code for a while, and this is the best I can come up with. I noticed that looping over columns helps a lot (Julia arrays are stored column-major). However, I wonder whether I’m leaving some performance on the table. Are there any ways to speed this up? I have no concerns about readability. I would really appreciate any help!

Update: I spent a bit more time on the code. Running the above in Julia v1.3.1 on my MacBook gives:

``````
  308.112 ms (1440 allocations: 70.05 KiB)
``````

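I changed where the view macros go: `@views` now wraps each whole `findmax` call instead of `@view` on the slice, and the `sum` in the second loop is wrapped in `@views` too, so it no longer copies the column. This gave a ~20% reduction in runtime.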

``````
function get_d_v2!(U1, U2, mu, k, d, idv_d)
    fill!(idv_d, 0.0)
    @inbounds @sync @distributed for j in 1:N
        u_1, arg_1 = @views findmax(U1[:, j])
        u_2, arg_2 = @views findmax(U2[:, j])
        if u_1 >= u_2
            idv_d[j, arg_1] = mu[j]
        else
            idv_d[j, arg_2] = mu[j] / k
        end
    end
    @inbounds @sync @distributed for j in 1:M
        d[j] = @views sum(idv_d[:, j])
    end
end
@btime get_d_v2!(U1, U2, mu, k, d, idv_d)
``````
``````
  255.680 ms (1439 allocations: 70.06 KiB)
``````
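For readers less familiar with the two macros, a minimal sketch of the difference between a plain slice, `@view`, and `@views`. The matrix `A` and the variable names are made up for illustration. This only shows the copy-versus-view semantics; it does not by itself explain the size of the timing difference above.

```julia
A = Float32[1 2 3; 4 5 6; 7 8 9; 10 11 12]   # small 4×3 example matrix

col_copy = A[:, 2]           # plain slice: allocates a new Vector
col_view = @view A[:, 2]     # SubArray that shares memory with A
s_views  = @views sum(A[:, 2])   # every slice inside the expression becomes a view

# Writing through the view changes A; the earlier copy is unaffected
col_view[1] = 0
```

After the last line, `A[1, 2]` is `0` while `col_copy[1]` still holds the old value `2`, which is why an unannotated `sum(idv_d[:, j])` pays for a copy of the column on every iteration.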