Speed up parallel maximum across columns

Hi Julia community!

I need help speeding up a parallelized function. Here’s an MWE.

```julia
using Distributed; n_cores = 3; addprocs(n_cores)
@everywhere using SharedArrays, BenchmarkTools
M = 10000; N = 10000; k = 2; mu = rand(N, 1)
U1     = SharedArray{Float32}(rand(M, N))   # utility matrix 1, M×N
U2     = SharedArray{Float32}(rand(M, N))   # utility matrix 2, M×N
idv_d  = SharedArray{Float32}(N, M)         # individual demand, N×M
d      = SharedArray{Float32}(M, 1)         # aggregate demand, M×1
function get_d!(U1, U2, mu, k, d, idv_d)
    fill!(idv_d, 0.0)
    # For each column j, find the largest entry of U1 and of U2 and store
    # mu[j] (or mu[j] / k) in row j of idv_d at the winning index
    @inbounds @sync @distributed for j in 1:N
        u_1, arg_1 = findmax(@view U1[:, j])
        u_2, arg_2 = findmax(@view U2[:, j])
        if u_1 >= u_2
            idv_d[j, arg_1] = mu[j]
        else
            idv_d[j, arg_2] = mu[j] ./ k
        end
    end
    # Aggregate demand: d[j] is the sum of column j of idv_d
    @inbounds @sync @distributed for j in 1:M
        d[j] = sum(idv_d[:, j])
    end
end
@btime get_d!(U1, U2, mu, k, d, idv_d)
```

M and N will both be large (on the order of 10^6), and I intend to run this on a cluster, so n_cores will be about 50-100. The function calculates aggregate demand given two utility matrices, U1 and U2. The only output I’m interested in is d, not idv_d, but I don’t see a fast way to compute d in parallel without calculating idv_d first.
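
Each column j writes exactly one nonzero into row j of idv_d, so column i of idv_d sums to the total weight of the columns whose winning index is i. That suggests d can be accumulated directly with `d[arg] += weight`, which needs O(M) memory instead of the N×M idv_d (at M = N = 10^6 that matrix would hold 10^12 Float32 entries, roughly 4 TB). A minimal serial sketch, assuming the same shapes as in the MWE (U1 and U2 are M×N, mu has N entries); `get_d_direct!` is a hypothetical name, not an existing function:

```julia
# Hypothetical helper: accumulates aggregate demand d directly, without idv_d.
# Assumes U1 and U2 are M×N, mu has N entries, and d has M entries
# (an M×1 matrix also works, since d is only indexed linearly here).
function get_d_direct!(d, U1, U2, mu, k)
    fill!(d, 0)
    @inbounds for j in axes(U1, 2)              # one column per j, as in get_d!
        u_1, arg_1 = findmax(view(U1, :, j))
        u_2, arg_2 = findmax(view(U2, :, j))
        if u_1 >= u_2
            d[arg_1] += mu[j]                    # weight get_d! stores in idv_d[j, arg_1]
        else
            d[arg_2] += mu[j] / k                # weight get_d! stores in idv_d[j, arg_2]
        end
    end
    return d
end

# Small usage example on random data
M, N, k = 5, 8, 2
U1 = rand(Float32, M, N); U2 = rand(Float32, M, N); mu = rand(N)
d = get_d_direct!(zeros(Float32, M), U1, U2, mu, k)
```

The additions happen in a different order than `sum(idv_d[:, j])`, so the result should match get_d! up to floating-point rounding rather than bit for bit.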

I’ve spent a while optimizing the code, and this is the best I’ve come up with. I noticed that looping over columns helps a lot, which fits with Julia storing arrays in column-major order. However, I wonder whether I’m leaving some performance on the table. Are there any ways to speed this up? Readability is not a concern. Any help would be much appreciated!
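
One possible restructuring that stays within Distributed: split the N columns into one contiguous chunk per worker, let each chunk fill its own length-M partial demand vector, and let `@distributed (+)` add those vectors on the master. This removes idv_d and the second loop entirely. This is a sketch under assumptions: `get_d_distributed` and `nchunks` are hypothetical names, and U1/U2 are captured by the closure, so SharedArrays are sent to workers as handles while plain Arrays would be copied to every worker on each call. Note also that SharedArray relies on shared memory and only works for processes on the same machine; on a multi-node cluster each worker would presumably need its own columns, for example generated or loaded locally.

```julia
using Distributed

# Hypothetical sketch: per-chunk partial sums of d, combined with (+).
# Each @distributed iteration handles one contiguous range of columns and
# returns a length-M Float32 vector; the (+) reducer adds these vectors.
function get_d_distributed(U1, U2, mu, k; nchunks = nworkers())
    M, N = size(U1)
    csize = cld(N, nchunks)                          # columns per chunk
    chunks = [s:min(s + csize - 1, N) for s in 1:csize:N]
    return @distributed (+) for cols in chunks
        local_d = zeros(Float32, M)                  # private to this chunk
        for j in cols
            u_1, arg_1 = findmax(view(U1, :, j))
            u_2, arg_2 = findmax(view(U2, :, j))
            if u_1 >= u_2
                local_d[arg_1] += mu[j]
            else
                local_d[arg_2] += mu[j] / k
            end
        end
        local_d                                      # value handed to the (+) reducer
    end
end

# Small usage example; with no workers added, @distributed runs on the master process
M, N, k = 6, 11, 2
U1 = rand(Float32, M, N); U2 = rand(Float32, M, N); mu = rand(N)
d = get_d_distributed(U1, U2, mu, k)
```

Each worker only allocates one length-M buffer per chunk, so the memory cost scales with the number of workers times M instead of with N×M.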

Update: I spent a bit more time on the code. Running the above with Julia v1.3.1 on my MacBook gives

  308.112 ms (1440 allocations: 70.05 KiB)

Changing where I put @views gave a ~20% reduction in runtime:

```julia
function get_d_v2!(U1, U2, mu, k, d, idv_d)
    fill!(idv_d, 0.0)
    @inbounds @sync @distributed for j in 1:N
        u_1, arg_1 = @views findmax(U1[:, j])
        u_2, arg_2 = @views findmax(U2[:, j])
        if u_1 >= u_2
            idv_d[j, arg_1] = mu[j]
        else
            idv_d[j, arg_2] = mu[j] / k
        end
    end
    @inbounds @sync @distributed for j in 1:M
        d[j] = @views sum(idv_d[:, j])   # sums a view instead of a copied column
    end
end
@btime get_d_v2!(U1, U2, mu, k, d, idv_d)
```

  255.680 ms (1439 allocations: 70.06 KiB)
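
A likely explanation for the gain (an assumption, since only total runtime was measured): in get_d!, the second loop’s `sum(idv_d[:, j])` has no view, so each iteration copies column j of idv_d (N elements) before summing it, while `@views sum(idv_d[:, j])` sums through a SubArray that points into idv_d’s memory. Those copies would happen on the worker processes, so they would not show up in the allocation count @btime reports for the master process. The difference between a slice and a view, illustrated on a small plain Array:

```julia
A = rand(Float32, 4, 3)

col_copy = A[:, 2]           # indexing with a slice allocates a new Vector (a copy)
col_view = @view A[:, 2]     # SubArray that refers to A's memory, no copy

col_view[1] = -1f0           # writes through to A[1, 2]
col_copy[2] = -2f0           # changes only the copy; A[2, 2] is untouched

# @views turns every slice in the expression into a view, as in get_d_v2!
s = @views sum(A[:, 1]) + sum(A[:, 3])
```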