Looping in multiple dimensions extremely slow

Hi all, I am quite new to julia and I have been rewriting some old code in Julia as practice. I am currently writing a code to compute the power spectrum of a 3D field. The code consists basically of these two functions:

function mode_count_fundamental(delta_k::Array{<:Complex{T}, 3}, dims::Tuple{Int,Int,Int}, box_size::SVector{3,T}) where T <:AbstractFloat
    middle = [Int32(d / 2)  for d in dims]
    k_fund = 2. * pi / maximum(box_size)
    k_ny = middle .* k_fund
    prefactor = [pi / d for d in dims]
    k_max = Int32.(floor.(sqrt.(middle.^2 .+ middle.^2 .+ middle.^2)))
    k_edges = zeros(Float32, k_max[1] + 1)
    pk = [zeros(Float32, k_max[1] + 1) for _ in 1:3]
    pk_phase = zeros(Float32, k_max[1] + 1)
    n_modes = zeros(Int32, k_max[1] + 1)

    R = CartesianIndices(delta_k)

    for I in ProgressBar(R)
        kxx, kyy, kzz = Tuple(I) 
        kx = kxx > middle[1] ? kxx - dims[1] : kxx
        ky = kyy > middle[2] ? kyy - dims[2] : kyy
        kz = kzz > middle[3] ? kzz - dims[3] : kzz
        
        #if (kx == 0.) || ((kx == middle[1] ) && (dims[2] % 2 == 0) )
        #    if ky < 0
        #        continue
        #    elseif (ky == 0.) || ((ky == middle[2]) && (dims[2] % 2 == 0))
        #        if kz < 0
        #            continue
        #        end
        #    end
        #end

    
        cic_corr = *([cic_correction(prefactor[1] * k_) for k_ in (kx, ky, kz)]...)
        k_norm = sqrt(kx^2 + ky^2 + kz^2)
        
        k_index = Int32(floor(k_norm))
        
        k_par = kz
        #k_per = Int32(round(sqrt(kx*kx + ky*ky)))

        mu = k_norm == 0. ? 0. : k_par / k_norm
        musq = mu^2
        k_par = k_par < 0. ? -k_par : k_par
        #delta_k[kxx,kyy,kzz] *= cic_correction_x * cic_correction_y * cic_correction_z
        delta_k[I] *= cic_corr
        delta_k_sq = abs2(delta_k[kxx,kyy,kzz])
        phase = angle(delta_k[kxx,kyy,kzz])

        k_edges[k_index + 1] += k_norm
        pk[1][k_index + 1] += delta_k_sq
        pk[2][k_index + 1] += (delta_k_sq * (3. * musq - 1.) / 2.)
        pk[3][k_index + 1] += (delta_k_sq * (35. * musq^2 - 30. * musq + 3.) / 8.)
        pk_phase[k_index + 1] += phase^2
        n_modes[k_index + 1] += 1
    end
    println("Done")
    units_factor = (box_size[1] / dims[1]^2)^3
    for i in 1:length(k_edges)
        k_edges[i] *= k_fund / n_modes[i]
        pk[1][i] *= 1. / n_modes[i] * units_factor
        pk[2][i] *= 5. / n_modes[i] * units_factor
        pk[3][i] *= 9. / n_modes[i] * units_factor
        pk_phase[i] *= units_factor / n_modes[i]
    end

    return k_edges, pk, pk_phase, n_modes
end

    

function powspec_fundamental(delta::Array{<:T, 3}, box_size::SVector{3,T}, k_lim::T) where T<:AbstractFloat
    dims = size(delta)
    
    
    
    
    println("Computing FFT.")
    @time delta_k = rfft(delta)
    #delta_k[1,1,1] = 0.
    println("Done")
    println("Computing Pk from FFT")
    @time k_edges, pk, pk_phase, n_modes = mode_count_fundamental(delta_k, dims, box_size)

    return [k for k in k_edges if k<k_lim], [[pk[j][i] for i in 1:length(k_edges) if k_edges[i]<k_lim] for j in 1:3], [pk_phase[i] for i in 1:length(k_edges) if k_edges[i]<k_lim], [n_modes[i] for i in 1:length(k_edges) if k_edges[i]<k_lim] 

end

I have added many @time macros to try to find where the issue is. I see now that it comes from the multidimensional loop for I in R. I am confused because CartesianIndices is supposed to allow you to loop in a way that is memory-layout-efficient and fast.
The powspec_fundamental function takes more than 5k seconds to finish (for a field of 1024^3 grid cells) but a similar version in python+numba takes ~70s.

My code aims at reproducing part of the code here, starting from line 263.

Finally, I also noticed that the allocations for the mode-counting loop in mode_count_fundamental are too large, surpassing 500G for a 1024^3 field.

Any help with this matter is appreciated. Thanks in advance.

Here you are creating a new array at every iteration. Do really need that? Can you allocate outside the loop and mutate?

Have you tried running @code_warntype on your functions (with smaller matrices to save time)? It looks like you want everything 32-bit, but the functions specify delta_k as AbstractFloat. There may be unwanted conversions/promotions happening and slowing the code down. If you really only want to run 32-bit, then it might help to specify T<:Float32.

1 Like

use: for I in R

Julia has actual multidimensional arrays, BTW, you don’t have to do lists of lists like in Python

I did but it didn’t improve the performance noticeably.
Edit: I was doing it by preallocating like this before the loop

cic_correction_buffer = [0. for _ in 1:3]

and then within the loop I assigned with

cic_correction_buffer .= [cic_correction(prefactor[1]...

thinking the .= sintax would avoid allocating a new array. It seems though it doesn’t. I replaced this by the simpler approach

cic_corr = cic_correction(prefactor[1] * kx)
cic_corr *= cic_correction(prefactor[2] * ky)
cic_corr *= cic_correction(prefactor[3] * kz)

which DRASTICALLY decreases the loop time (20 times with a 256^3 field). A quick test with a 256^3 field seems to put the Julia function roughly as fast as the python+numba implementation. A 1024^3 field is still slower in Julia. Here the updated code.

function mode_count_fundamental(delta_k::Array{<:Complex{T}, 3}, dims::Tuple{Int,Int,Int}, box_size::SVector{3,T}) where T <:AbstractFloat
    middle = [Int32(d / 2)  for d in dims]
    k_fund = 2. * pi / maximum(box_size)
    k_ny = middle .* k_fund
    prefactor::Vector{T} = [(pi / d) for d in dims]
    k_max = Int32.(floor.(sqrt.(middle.^2 .+ middle.^2 .+ middle.^2)))
    k_edges = zeros(T, k_max[1] + 1)
    pk = zeros(T, (3, k_max[1] + 1))
    pk_phase = zeros(T, k_max[1] + 1)
    n_modes = zeros(Int32, k_max[1] + 1)

    R = CartesianIndices(delta_k)

    for I in R
        kxx, kyy, kzz = Tuple(I) 
        kx = (kxx - 1.) > middle[1] ? (kxx - dims[1] - 1) : (kxx - 1.)
        ky = (kyy - 1.) > middle[2] ? (kyy - dims[2] - 1) : (kyy - 1.)
        kz = (kzz - 1.) > middle[3] ? (kzz - dims[3] - 1) : (kzz - 1.)
        
        if (kx == 0.) || ((kx == middle[1] ) && (dims[2] % 2 == 0) )
            if ky < 0
                continue
            elseif (ky == 0.) || ((ky == middle[2]) && (dims[2] % 2 == 0))
                if kz < 0
                    continue
                end
            end
        end

        cic_corr = cic_correction(prefactor[1] * kx)
        cic_corr *= cic_correction(prefactor[2] * ky)
        cic_corr *= cic_correction(prefactor[3] * kz)
        
        
        k_norm::T = sqrt(kx^2 + ky^2 + kz^2)
        
        k_index = Int32(floor(k_norm))
        
        k_par = kz
        #k_per = Int32(round(sqrt(kx*kx + ky*ky)))

        mu::T = k_norm == 0. ? 0. : k_par / k_norm
        musq = mu^2
        k_par = k_par < 0. ? -k_par : k_par
        #delta_k[kxx,kyy,kzz] *= cic_correction_x * cic_correction_y * cic_correction_z
        delta_k[I] *= cic_corr
        delta_k_sq = abs2(delta_k[I])
        phase = angle(delta_k[I])

        k_edges[k_index + 1] += k_norm
        pk[1,k_index + 1] += delta_k_sq
        pk[2,k_index + 1] += (delta_k_sq * (3. * musq - 1.) / 2.)
        pk[3,k_index + 1] += (delta_k_sq * (35. * musq^2 - 30. * musq + 3.) / 8.)
        pk_phase[k_index + 1] += phase^2
        n_modes[k_index + 1] += 1
    end
    println("Done")
    units_factor = (box_size[1] / dims[1]^2)^3
    for i in 1:length(k_edges)
        k_edges[i] *= k_fund / n_modes[i]
        pk[1,i] *= 1. / n_modes[i] * units_factor
        pk[2,i] *= 5. / n_modes[i] * units_factor
        pk[3,i] *= 9. / n_modes[i] * units_factor
        pk_phase[i] *= units_factor / n_modes[i]
    end

    return k_edges, pk, pk_phase, n_modes
end

    

function powspec_fundamental(delta::Array{<:T, 3}, box_size::SVector{3,T}, k_lim::T) where T<:AbstractFloat
    dims = size(delta)
    
    
    
    
    println("Computing FFT.")
    @time delta_k = rfft(delta)
    #delta_k[1,1,1] = 0.
    println("Done")
    println("Computing Pk from FFT")
    @time k_edges, pk, pk_phase, n_modes = mode_count_fundamental(delta_k, dims, box_size)
    kmask = k_edges .< k_lim
    return k_edges[kmask], pk[:,kmask], pk_phase[kmask], n_modes[kmask]

end

This doesn’t seem to affect performance that much, I added it to see how long it would take but in the end any overhead from progressbar seems negligible.