Context:
Hi everyone, I’m using Turing.jl
with NUTS to infer the parameters of a chaotic ODE system. Because the system is chaotic, inference on the particle positions becomes increasingly difficult as the inference time window grows, since the simulated paths stray further and further from the data. Instead, I am doing inference on the density of particles over time.
The “natural” way of calculating densities (one that no reviewer will complain about) is via Voronoi diagrams: compute the area of the Voronoi cell around each particle and invert it to get the density at that particle. So each time Turing
evaluates the log probability, it solves the ODE, rasterises a Voronoi diagram on a grid at each dt, and then I perform the additional step of subsampling the density grid.
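The density step itself is simple once the rasterised diagram is available. A minimal sketch, where `pixel_area` (the physical area of one grid cell) and the name `densities_from_counts` are my own illustrative choices, not anything from a library:

```julia
# Given counts[c] = number of grid cells owned by particle c in the
# rasterised Voronoi diagram, the density around particle c is the
# inverse of its Voronoi cell area. `pixel_area` is assumed to be the
# physical area covered by one grid cell.
densities_from_counts(counts, pixel_area) = 1 ./ (counts .* pixel_area)
```

The `count_colours` vector produced by the function below supplies exactly these per-particle cell counts.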
Disclaimer: This is for my Master’s thesis; if I end up using any code written by anyone else, I will credit them in the acknowledgements section unless they would prefer not to be mentioned.
Problem:
To calculate rasterised Voronoi diagrams (grids in which each cell is marked with the colour of its closest seed) I am using the jump flood algorithm (JFA). While this is fast, it is designed to be run in parallel for each pixel. Since I am running on CPUs, I’m not sure how best to take advantage of that.
Additionally, my code (below) is written to work, not optimised for memory or speed. I’d like to reduce the memory allocated each iteration, as this is seriously slowing down the inference, and I would highly appreciate any contributions to making this code faster. An easy win might be changing the datatype I’m using to store seeds from Set
to something else, but I’m not sure.
# Euclidean distance between two points given as (x, y) tuples.
function distance(x, y)
    sqrt((x[1] - y[1])^2 + (x[2] - y[2])^2)
end

# Fill `grid` (initially zeros) with the colour of the nearest seed using
# the jump flood algorithm. `count_colours[c]` ends up holding the number
# of grid cells owned by colour c, i.e. its Voronoi cell area in pixels.
function jump_flood_voronoi!(grid, original_seeds; count_colours = zeros(Int64, length(original_seeds)))
    N, M = size(grid)
    colours = 1:length(original_seeds)
    # Per-colour frontiers: the pixels of each colour that still propagate.
    seeds = [Set{Tuple{Int64,Int64}}() for _ in colours]
    new_seeds = [Set{Tuple{Int64,Int64}}() for _ in colours]
    seeds_to_delete = [Set{Tuple{Int64,Int64}}() for _ in colours]
    for (colour, seed) in enumerate(original_seeds)
        grid[seed...] = colour
        push!(seeds[colour], seed)
        count_colours[colour] += 1
    end
    # For correctness, run extra passes at steps 2 and 1 (the "JFA+2" variant).
    base_nsteps = ceil(Int, log2(max(M, N))) - 1
    extra_steps = 2
    extra_stepsizes = (2, 1)
    step = 2^base_nsteps
    ncompleted_steps = 0
    while ncompleted_steps < base_nsteps + extra_steps
        for colour in colours
            original_seed_coords = original_seeds[colour]
            for seed in seeds[colour]
                for direction in ((1, 1), (1, 0), (1, -1), (0, 1), (0, -1), (-1, 1), (-1, 0), (-1, -1))
                    i, j = seed .+ step .* direction
                    if all(1 .<= (i, j) .<= (N, M))
                        if grid[i, j] == colour
                            continue
                        elseif grid[i, j] == 0
                            # Unclaimed pixel: take it.
                            grid[i, j] = colour
                            push!(new_seeds[colour], (i, j))
                            count_colours[colour] += 1
                        else
                            # Contested pixel: the closer original seed wins.
                            current_colour = grid[i, j]
                            current_dist = distance((i, j), original_seeds[current_colour])
                            new_dist = distance((i, j), original_seed_coords)
                            if new_dist < current_dist
                                grid[i, j] = colour
                                push!(new_seeds[colour], (i, j))
                                count_colours[colour] += 1
                                push!(seeds_to_delete[current_colour], (i, j))
                                count_colours[current_colour] -= 1
                            end
                        end
                    end
                end
            end
        end
        for colour in colours
            union!(seeds[colour], new_seeds[colour])
            setdiff!(seeds[colour], seeds_to_delete[colour])
        end
        new_seeds .= (Set{Tuple{Int64,Int64}}() for _ in colours)
        seeds_to_delete .= (Set{Tuple{Int64,Int64}}() for _ in colours)
        ncompleted_steps += 1
        if ncompleted_steps > base_nsteps
            step = extra_stepsizes[ncompleted_steps-base_nsteps]
        else
            step ÷= 2
        end
    end
    return nothing
end
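For reference, the per-pixel formulation that GPU implementations parallelise can be sketched on the CPU with Threads.@threads. This is not my current implementation, just a rough illustration; `jfa_pass!` and `dist` are made-up names. It is double-buffered so each thread writes only its own pixel of `dst` while reading from `src`:

```julia
dist(p, q) = hypot(p[1] - q[1], p[2] - q[2])

# One pixel-centric jump-flood pass: every pixel pulls the best colour
# from itself and its 8 neighbours at offset `step`, where "best" means
# closest original seed. `src`/`dst` hold seed colours (0 = empty) and
# `original_seeds[c]` is the (i, j) coordinate of seed c.
function jfa_pass!(dst, src, original_seeds, step)
    N, M = size(src)
    Threads.@threads for j in 1:M
        for i in 1:N
            best = src[i, j]
            bestd = best == 0 ? Inf : dist((i, j), original_seeds[best])
            for dj in (-step, 0, step), di in (-step, 0, step)
                ii, jj = i + di, j + dj
                (1 <= ii <= N && 1 <= jj <= M) || continue
                c = src[ii, jj]
                c == 0 && continue
                d = dist((i, j), original_seeds[c])
                if d < bestd
                    best, bestd = c, d
                end
            end
            dst[i, j] = best
        end
    end
    return dst
end
```

A full diagram would run this for step = 2^k, …, 2, 1, swapping `src` and `dst` between passes; whether the threading overhead pays off at this grid size is something I haven’t measured.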
Here’s an example of the code running:
using Plots
grid = zeros(Int64, 1000, 1000)
original_seeds = reinterpret(reshape, Tuple{Int64,Int64}, rand(1:1000, 2, 100))
jump_flood_voronoi!(grid, original_seeds)
heatmap(grid, aspect_ratio = 1.0, colour = palette(:viridis), legend = :none, size = (500, 500))
Currently, my benchmark says:
julia> @benchmark jump_flood_voronoi!(grid, original_seeds)
BenchmarkTools.Trial: 4204 samples with 1 evaluation.
Range (min … max): 784.421 μs … 16.627 ms ┊ GC (min … max): 0.00% … 91.92%
Time (median): 898.941 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.174 ms ± 1.133 ms ┊ GC (mean ± σ): 16.61% ± 14.92%
█▇▄▅▄▂ ▁
███████▇▄▆▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▄▇██▇▅▅▃▁▁▃ █
784 μs Histogram: log(frequency) by time 7.33 ms <
Memory estimate: 1.59 MiB, allocs estimate: 10828.
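To make the Set question concrete, this is the kind of swap I had in mind (a sketch only; `expand!` is a made-up helper, not code I’m running): keep each colour’s frontier in a Vector and drop the seeds_to_delete machinery, filtering stolen pixels lazily the next time the frontier is walked.

```julia
# Per-colour frontier as a Vector. A pixel that another colour has
# overwritten in `grid` since it was pushed is simply skipped here,
# so no eager deletion (and no Set) is needed.
function expand!(next, frontier, grid, colour)
    empty!(next)
    for p in frontier
        grid[p...] == colour && push!(next, p)  # real code would also push neighbours
    end
    return next
end

grid = [1 2; 1 2]
next = Tuple{Int,Int}[]
expand!(next, [(1, 1), (1, 2), (2, 1)], grid, 1)  # (1, 2) was stolen by colour 2
```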