Hard to beat Numba / JAX loop for generating Mandelbrot

Here is some verbose and ugly code that is agnostic with respect to `N`:

using VectorizationBase, LayoutPointers

function run5_unroll!(fractal, y, x, ::Val{U}=Val(4), ::Val{N} = Val(20)) where {U,N}
  # Base.require_one_based_indexing(fractal, y, x)
  VW = pick_vector_width(Base.promote_eltype(fractal,y,x))
  _fp, fractalpres = LayoutPointers.stridedpointer_preserve(fractal)
  _yp, ypres = LayoutPointers.stridedpointer_preserve(y)
  _xp, xpres = LayoutPointers.stridedpointer_preserve(x)
  fp = VectorizationBase.zero_offsets(_fp)
  yp = VectorizationBase.zero_offsets(_yp)
  xp = VectorizationBase.zero_offsets(_xp)
  W = length(x)
  H = length(y)
  I = eltype(fractal)
  # LW = LX & (U-1)
  # LH = LY & (VW-1)
  GC.@preserve fractalpres ypres xpres begin
    w = 0
    while w < (W+1-U)
      # load U consecutive x (real-axis) values and broadcast each across a SIMD vector
      _c_re = vload(xp, VectorizationBase.Unroll{1,1,U,0,1,zero(UInt64)}((w,)))
      c_re4 = vbroadcast(VW, _c_re)
      h = 0
      while h < H
        _c_im = vload(yp, (MM(VW, h),))
        z_im4 = c_im4 = VecUnroll(ntuple(Returns(_c_im), Val(U)))
        z_re4 = c_re4 # reset the real part of z for this block of pixels
        _m = VectorizationBase.Mask(VectorizationBase.mask(VW, h, H))
        m4 = VecUnroll(ntuple(Returns(_m), Val(U)))
        i = one(I)
        fhw4 = VecUnroll(ntuple(Returns(vbroadcast(VW,N%I)), Val(U)))
        # fhw4 = vload(xp, VectorizationBase.Unroll{2,1,U,1,Int(VW),0xf%UInt64}((h,w)), _m)
        while true
          z_re4,z_im4 = c_re4 + z_re4*z_re4 - z_im4*z_im4, c_im4 + 2f0*z_re4*z_im4
          az4 = (z_re4*z_re4 + z_im4*z_im4) > 4f0

          fhw4 = VectorizationBase.ifelse(m4 & az4, i, fhw4)
          m4 &= (!az4)
          (any(VectorizationBase.data(VectorizationBase.vany(m4))) && ((i+=one(i)) <= N)) || break
        end
        # masked store of the iteration counts for the VW×U tile (0xf masks all four unrolled stores, matching U = 4)
        vstore!(fp, fhw4, VectorizationBase.Unroll{2,1,U,1,Int(VW),0xf%UInt64}((h,w)), _m)
        h += VW
      end
      w += U
    end
    while w < W
      c_re = vbroadcast(VW, vload(xp, (w,)))
      h = 0
      while h < H
        z_im = c_im = vload(yp, (MM(VW, h),))
        z_re = c_re # reset the real part of z for this block of pixels
        m = mi = VectorizationBase.Mask(VectorizationBase.mask(VW, h, H))
        i = one(I)
        fhw = vbroadcast(VW,N%I)
        while true
          z_re,z_im = c_re + z_re*z_re - z_im*z_im, c_im + 2f0*z_re*z_im
          az4 = (z_re*z_re + z_im*z_im) > 4f0

          fhw = VectorizationBase.ifelse(mi & az4, i, fhw)
          mi &= (!az4)
          (VectorizationBase.vany(mi) && ((i+=one(i)) <= N)) || break
        end
        vstore!(fp, fhw, (MM(VW,h), w), m)
        h += VW
      end      
      w += 1
    end
  end
  fractal
end
function run5(height, width)
  y = range(-1.0f0, 0.0f0; length = height)
  x = range(-1.5f0, 0.0f0; length = width)
  fractal = Matrix{Int32}(undef, height, width)
  return run5_unroll!(fractal, y, x, Val(4), Val(20))
end
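
Here `N` is the maximum iteration count, so the same kernel can be reused for a deeper escape-time limit without rewriting anything. A small (hypothetical) usage sketch, keeping the unroll factor at 4 but raising the iteration budget:

fractal = Matrix{Int32}(undef, 2000, 3000)
y = range(-1.0f0, 0.0f0; length = 2000)
x = range(-1.5f0, 0.0f0; length = 3000)
run5_unroll!(fractal, y, x, Val(4), Val(100))  # same kernel, 100 iterations instead of 20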

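The `run_julia` baseline used in the timings below is not reproduced in this post; for reference, a plain scalar Julia version would look roughly like the following (a sketch with made-up names `mandel_scalar` / `run_julia_sketch`, not the exact code being benchmarked):

function mandel_scalar(c_re, c_im, maxiter)
  # escape-time iteration starting from z = c, counting up to maxiter
  z_re, z_im = c_re, c_im
  for i in 1:maxiter
    z_re, z_im = c_re + z_re*z_re - z_im*z_im, c_im + 2*z_re*z_im
    (z_re*z_re + z_im*z_im > 4) && return i
  end
  return maxiter
end

function run_julia_sketch(height, width, maxiter = 20)
  y = range(-1.0f0, 0.0f0; length = height)
  x = range(-1.5f0, 0.0f0; length = width)
  fractal = Matrix{Int32}(undef, height, width)
  for (w, c_re) in enumerate(x), (h, c_im) in enumerate(y)
    fractal[h, w] = mandel_scalar(c_re, c_im, maxiter)
  end
  return fractal
end
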
Here are prettier benchmarks:

julia> @benchmark run5(10,10)
BenchmarkTools.Trial: 10000 samples with 196 evaluations.
 Range (min … max):  465.469 ns …   5.591 μs  ┊ GC (min … max): 0.00% … 84.27%
 Time  (median):     472.441 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   488.324 ns ± 166.302 ns  ┊ GC (mean ± σ):  1.46% ±  3.88%

  ▃▃█▆▄▂▂▃▃▃▂▁▁                                   ▁ ▁           ▁
  ██████████████▅▅▅▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▃▅▇▇████▇▆▆▆▇▆▇▇▇ █
  465 ns        Histogram: log(frequency) by time        626 ns <

 Memory estimate: 496 bytes, allocs estimate: 1.

julia> @benchmark run_julia(10,10)
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
 Range (min … max):  3.318 μs … 172.239 μs  ┊ GC (min … max): 0.00% … 95.47%
 Time  (median):     3.368 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.539 μs ±   2.293 μs  ┊ GC (mean ± σ):  0.89% ±  1.36%

  ▅█▇▄▃▂      ▃▄▂▁  ▃▆▅▃▂▁▁▃▃▂▁▁                              ▂
  ████████▅▅▂▅█████▇██████████████▇▆▆▅▆▆▅▅▅▅▅▄▄▃▄▄▃▄▅▃▄▄▄▅▄▄▆ █
  3.32 μs      Histogram: log(frequency) by time       4.4 μs <

 Memory estimate: 1.36 KiB, allocs estimate: 2.

julia> @benchmark run5(100,100)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.768 μs …  1.219 ms  ┊ GC (min … max): 0.00% … 95.06%
 Time  (median):     10.293 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.807 μs ± 26.271 μs  ┊ GC (mean ± σ):  5.94% ±  2.67%

  █▇▇▇▅▄▃▁                           ▁▂▁ ▂▁▁                  ▂
  █████████▇▇▇███▇▆▇▇▆▅▆▅▃▄▁▃▄▁▁▃▃▇▇▆███████▇▇▆▆▅▄▄▄▅▃▄▁▅▄▄▅▆ █
  9.77 μs      Histogram: log(frequency) by time      24.5 μs <

 Memory estimate: 39.11 KiB, allocs estimate: 2.

julia> @benchmark run_julia(100,100)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  322.725 μs …  1.427 ms  ┊ GC (min … max): 0.00% … 71.91%
 Time  (median):     330.053 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   333.490 μs ± 41.986 μs  ┊ GC (mean ± σ):  0.59% ±  3.40%

     ▁█▅▁▁▃▁
  ▅▇▆███████▆▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂ ▃
  323 μs          Histogram: frequency by time          398 μs <

 Memory estimate: 117.28 KiB, allocs estimate: 4.

julia> @benchmark run5(2000,3000)
BenchmarkTools.Trial: 752 samples with 1 evaluation.
 Range (min … max):  4.192 ms … 10.854 ms  ┊ GC (min … max): 0.00% …  7.84%
 Time  (median):     7.090 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.641 ms ±  2.534 ms  ┊ GC (mean ± σ):  8.75% ± 14.03%

  █▁
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂█▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂█▃▁▁▂▂▁▁▁▆▅ ▂
  4.19 ms        Histogram: frequency by time        10.5 ms <

 Memory estimate: 22.89 MiB, allocs estimate: 2.

julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 19 samples with 1 evaluation.
 Range (min … max):  273.657 ms … 284.974 ms  ┊ GC (min … max): 1.34% … 0.27%
 Time  (median):     275.478 ms               ┊ GC (median):    1.33%
 Time  (mean ± σ):   276.150 ms ±   3.258 ms  ┊ GC (mean ± σ):  1.29% ± 0.25%

  ▃ ▃          █                                              ▃
  █▇█▇▇▇▁▇▁▇▇▁▇█▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  274 ms           Histogram: frequency by time          285 ms <

 Memory estimate: 68.66 MiB, allocs estimate: 4.

For the 2000×3000 case, this is 40x or more faster than the original version while still being single threaded.
The answers are pretty close, but they do not match exactly.
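Using the hypothetical run_julia_sketch above as a stand-in for run_julia, one way to quantify the mismatch:

a = run5(100, 100)
b = run_julia_sketch(100, 100)
count(a .!= b), maximum(abs.(a .- b))  # number of differing pixels and largest difference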

I think this can be improved further. I didn't optimize or tweak anything; I just wrote down my first idea of how the code should be evaluated (roughly what I'd intend a loop optimizer to do).
