Memory kills this code

This code produces an accurate plot, but it needs too much memory and is slow. I have been working for a few months to get a 3D volume plot of the astrophysical data linked here. You can download the data by clicking the Download icon; no sign-in is needed. The parallel code works, but it consumes too much RAM and the process gets killed.

julia> @time fig = plothdf_volume_parallel("sane00.prim.01800.athdf"; samples_per_axis = 20, grid_res = 100)
Processing 872 blocks with 14 threads...
Killed julia -t auto

The data contains many mesh blocks (872). My 32 GB of RAM fills up before about 60 mesh blocks have been processed in the @threads for b in 1:num_blocks iteration. To get a working plot, you can replace this with @threads for b in 5 (line 28 of the code), putting in the number of whichever single mesh block you want to plot.

  • What should I do to reduce memory usage and also improve performance? I want to make it run as fast as possible, since I have to produce many such plots.

using GLMakie, Random, Interpolations
using HDF5, Meshes, CoordRefSystems, Unitful, Statistics
using Base.Threads

function plothdf_volume_parallel(h5file::String; samples_per_axis::Int = 6, grid_res::Int = 100)
    # --- Load HDF5 data ---
    h5 = h5open(h5file, "r")
    num_blocks::Int32 = read(HDF5.attributes(h5), "NumMeshBlocks")
    MeshBlockSize::Vector{Int32} = read(HDF5.attributes(h5), "MeshBlockSize")
    B = read(h5["B"])
    Br, Bθ, Bϕ = @views B[:,:,:,:,1], B[:,:,:,:,2], B[:,:,:,:,3]
    Bp = @. hypot(Br, Bθ) / Bϕ
    x_all::Array{Float32, 2} = read(h5["x1f"])
    y_all::Array{Float32, 2} = read(h5["x2f"])
    z_all::Array{Float32, 2} = read(h5["x3f"])
    close(h5)

    # --- Visualization setup ---
    fig = Figure(size = (1000, 800))
    ax = Axis3(fig[1, 1], title = "Interpolated Volume", aspect = :data, viewmode = :free)

    println("Processing $num_blocks blocks with $(nthreads()) threads...")

    # --- Thread-local accumulators (stored in a Dict) ---
    accumulators = Dict{Int, NamedTuple{(:X, :Y, :Z, :C), NTuple{4, Vector{Float32}}}}()

    # --- Parallel over mesh blocks ---
    @threads for b in 1:num_blocks
        tid = threadid()
        acc = get!(accumulators, tid) do
            (X = Float32[], Y = Float32[], Z = Float32[], C = Float32[])
        end

        @views Bp_block = Bp[:, :, :, b]
        r = x_all[:, b]
        θ = y_all[:, b]
        ϕ = z_all[:, b]
        g = RectilinearGrid{𝔼, typeof(Spherical(0,0,0))}((r, θ, ϕ))

        for i in 1:nelements(g)
            el = element(g, i)

            # ---- Corner values ----
            clr = Vector{Float32}(undef, 8)
            bd = Boundary{3,0}(topology(g))(i)
            for (ii, id) in enumerate(bd)
                adels = Coboundary{0,3}(topology(g))(id)
                s = 0.0f0
                for idx in adels
                    inds = elem2cart(topology(g), idx)
                    s += Bp_block[inds...]
                end
                clr[ii] = s / length(adels)
            end

            # ---- Vertex coordinates ----
            verts = [Float64.(ustrip.((c.r, c.θ, c.ϕ))) for c in coords.(vertices(el))]
            rgrid = LinRange(extrema(getindex.(verts, 1))..., 2)
            θgrid = LinRange(extrema(getindex.(verts, 2))..., 2)
            ϕgrid = LinRange(extrema(getindex.(verts, 3))..., 2)

            # ---- Build 2×2×2 data cube ----
            data = Array{Float32}(undef, 2, 2, 2)
            data[1,1,1] = clr[1]; data[2,1,1] = clr[2]
            data[2,2,1] = clr[3]; data[1,2,1] = clr[4]
            data[1,1,2] = clr[5]; data[2,1,2] = clr[6]
            data[2,2,2] = clr[7]; data[1,2,2] = clr[8]

            itp = interpolate((rgrid, θgrid, ϕgrid), data, Gridded(Linear()))

            r_eval = LinRange(rgrid[1], rgrid[end], samples_per_axis)
            θ_eval = LinRange(θgrid[1], θgrid[end], samples_per_axis)
            ϕ_eval = LinRange(ϕgrid[1], ϕgrid[end], samples_per_axis)

            for rr in r_eval, th in θ_eval, ph in ϕ_eval
                val = itp(rr, th, ph)
                x = rr * sin(th) * cos(ph)
                y = rr * sin(th) * sin(ph)
                z = rr * cos(th)
                push!(acc.X, x)
                push!(acc.Y, y)
                push!(acc.Z, z)
                push!(acc.C, val)
            end
        end
    end

    println("\nInterpolation complete. Merging results...")

    # --- Merge all thread accumulators ---
    Xacc_all = reduce(vcat, [v.X for v in values(accumulators)])
    Yacc_all = reduce(vcat, [v.Y for v in values(accumulators)])
    Zacc_all = reduce(vcat, [v.Z for v in values(accumulators)])
    Cacc_all = reduce(vcat, [v.C for v in values(accumulators)])

    # --- Compute bounding box ---
    x_min, x_max = extrema(Xacc_all)
    y_min, y_max = extrema(Yacc_all)
    z_min, z_max = extrema(Zacc_all)

    nx = ny = nz = grid_res
    x_range = LinRange(x_min, x_max, nx)
    y_range = LinRange(y_min, y_max, ny)
    z_range = LinRange(z_min, z_max, nz)

    # --- Fill 3D grid with nearest-point assignment ---
    field = fill(NaN32, nx, ny, nz)
    for (xi, yi, zi, ci) in zip(Xacc_all, Yacc_all, Zacc_all, Cacc_all)
        ix = clamp(Int(round((xi - x_min)/(x_max - x_min) * (nx - 1))) + 1, 1, nx)
        iy = clamp(Int(round((yi - y_min)/(y_max - y_min) * (ny - 1))) + 1, 1, ny)
        iz = clamp(Int(round((zi - z_min)/(z_max - z_min) * (nz - 1))) + 1, 1, nz)
        field[ix, iy, iz] = ci
    end

    # --- Compute color normalization ---
    finite_vals = filter(!isnan, vec(field))
    q_low  = quantile(finite_vals, 0.05)
    q_high = quantile(finite_vals, 0.95)
    cr = (q_low, q_high)


    println("Rendering volume...")
    volume!(ax, x_min..x_max, y_min..y_max, z_min..z_max, field; colorrange=cr)
    fig
end

@time fig = plothdf_volume_parallel("sane00.prim.01800.athdf"; samples_per_axis = 20, grid_res = 100)
display(fig)

This looks like rather complex code with many dependencies. If you want people to help, could you try profiling first (with @profview and @profview_allocs) to reduce the MWE and focus it on the parts that truly impact performance?
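For example, here is a minimal sketch of how that could be run (@profview and @profview_allocs are provided by the Julia VS Code extension; ProfileView.jl also offers @profview). The reduced samples_per_axis and grid_res values below are only chosen so the run finishes without being killed:

# First call is only to compile; profile the second, smaller run.
plothdf_volume_parallel("sane00.prim.01800.athdf"; samples_per_axis = 4, grid_res = 50)
@profview plothdf_volume_parallel("sane00.prim.01800.athdf"; samples_per_axis = 4, grid_res = 50)
@profview_allocs plothdf_volume_parallel("sane00.prim.01800.athdf"; samples_per_axis = 4, grid_res = 50)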


Just a few general comments that would make it easier to work with such complex code:

  1. Right now you have tangled up loading the data, processing the data, and plotting. Loading in particular is inherently type unstable (although you mitigate that to a large extent with type annotations; B, however, is missing one). So to improve modularity and avoid type-inference issues, I propose splitting this up into at least three functions: loading, processing, and plotting (a rough skeleton is sketched at the end of this post). Note that you can still have a wrapper that combines all three, so you keep the interface you have now. After the split, we can focus on making the processing function as fast as possible without any other part interfering.
  2. Your accumulator pattern (a Dict indexed by threadid()) is not thread-safe.

For an explanation please refer to
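As a minimal sketch, one thread-safe alternative (not necessarily the best one) is to give each chunk of blocks its own buffer, indexed by the chunk number rather than by threadid(), so that no two tasks ever write to the same accumulator:

# Chunk-owned buffers instead of threadid()-indexed state (sketch).
nchunks = Threads.nthreads()
block_chunks = collect(Iterators.partition(1:num_blocks, cld(num_blocks, nchunks)))
buffers = [(X = Float32[], Y = Float32[], Z = Float32[], C = Float32[]) for _ in block_chunks]

Threads.@threads for ichunk in eachindex(block_chunks)
    acc = buffers[ichunk]              # owned exclusively by this chunk's task
    for b in block_chunks[ichunk]
        # ... process block b exactly as before, push!-ing into acc ...
    end
end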

I wonder whether you could instead preallocate all the workspace you need. That would also bring some performance improvements. It seems that you know how many elements each block adds to an accumulator, and you know how many blocks there are. So instead of using this attempted thread-local state, push!-ing elements and then using vcat, you could just preallocate some vectors and write into them. Then you avoid allocations and don't need the merge step.
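A sketch of what that could look like, assuming every block contributes the same number of points (prod(MeshBlockSize) cells times samples_per_axis^3 samples per cell; MeshBlockSize is already read from the file but currently unused). Each block writes into its own disjoint slice, so this is also thread-safe without any Dict or merge step:

pts_per_block = prod(Int, MeshBlockSize) * samples_per_axis^3
Xacc_all = Vector{Float32}(undef, num_blocks * pts_per_block)
Yacc_all = similar(Xacc_all); Zacc_all = similar(Xacc_all); Cacc_all = similar(Xacc_all)

Threads.@threads for b in 1:num_blocks
    offset = (b - 1) * pts_per_block   # start of this block's slice
    k = 0
    # ... same per-element loop as before, but instead of push!:
    #     k += 1
    #     Xacc_all[offset + k] = x; Yacc_all[offset + k] = y
    #     Zacc_all[offset + k] = z; Cacc_all[offset + k] = val
end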

  3. Small observation:

Here (in the 2×2×2 data cube assignment) you almost assign the values in order, i.e. it is almost a reshape; only data[1,2,:] and data[2,2,:] are swapped. Is that intended? I have no idea what you are doing, so I can't judge whether that is correct; it may well be correct :slight_smile: It just stood out to me while reading through the code.
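A rough skeleton of the split from point 1 could look like the following (the function names are placeholders, and the heavy bodies are elided; they would simply be moved over from the current implementation):

function load_data(h5file::AbstractString)
    h5open(h5file, "r") do h5
        (num_blocks = Int(read(HDF5.attributes(h5), "NumMeshBlocks")),
         B     = read(h5["B"]),       # consider asserting a concrete type here
         x_all = read(h5["x1f"]),
         y_all = read(h5["x2f"]),
         z_all = read(h5["x3f"]))
    end
end

# Function barrier: inside this call all argument types are concrete,
# regardless of what inference could prove during loading.
function build_field(B, x_all, y_all, z_all, num_blocks; samples_per_axis, grid_res)
    # ... the @threads interpolation loop and nearest-point gridding ...
end

function plot_field(field, xr, yr, zr)
    # ... Figure / Axis3 / volume! ...
end

function plothdf_volume_parallel(h5file; samples_per_axis = 6, grid_res = 100)
    d = load_data(h5file)
    field, xr, yr, zr = build_field(d.B, d.x_all, d.y_all, d.z_all, d.num_blocks;
                                    samples_per_axis, grid_res)
    plot_field(field, xr, yr, zr)
end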
