Hard to beat Numba / JAX loop for generating Mandelbrot

Not sure I agree. What is the point of the benchmark? Figure out how to make the fastest Mandelbrot set? Probably not - then we could just copy-paste the solution posted by Kristoffer.

To me the interesting thing is to figure out why Numba can create better code than Julia, while Numba's compiler presumably has even less information to work with.

What is the Julia compiler missing? How can it be improved? This benchmark shows there are some gains to be had in a use case that is probably hit quite often.

20 Likes

100% agree

1 Like

What about the GPU versions? Would CUDA versions be comparable to the ones there?

I don’t have a CUDA device right now.

Sort of unrelated, I think abomination and inexcusable are a bit strong of words to use in this thread. The point was to get to the bottom of why the Julia version is slower, and I think your rhetoric is unnecessary, distracting, and unhelpful.

5 Likes

No, ‘abomination’ isn’t nearly strong enough for unnecessary allocations. (Just read it as somewhat tongue-in-cheek.)

Here is some verbose and ugly code that is agnostic with respect to `N`:
using VectorizationBase, LayoutPointers

# Unroll U adjacent columns of `x` and vectorize (width VW) along `y`; N is the iteration cap.
function run5_unroll!(fractal, y, x, ::Val{U}=Val(4), ::Val{N} = Val(20)) where {U,N}
  # Base.require_one_based_indexing(fractal, y, x)
  VW = pick_vector_width(Base.promote_eltype(fractal,y,x))
  _fp, fractalpres = LayoutPointers.stridedpointer_preserve(fractal)
  _yp, ypres = LayoutPointers.stridedpointer_preserve(y)
  _xp, xpres = LayoutPointers.stridedpointer_preserve(x)
  fp = VectorizationBase.zero_offsets(_fp)
  yp = VectorizationBase.zero_offsets(_yp)
  xp = VectorizationBase.zero_offsets(_xp)
  W = length(x)
  H = length(y)
  I = eltype(fractal)
  # LW = LX & (U-1)
  # LH = LY & (VW-1)
  GC.@preserve fractalpres ypres xpres begin
    w = 0
    # Main blocked loop: U columns of x at a time, SIMD of width VW down y.
    while w < (W+1-U)
      _c_re = vload(xp, VectorizationBase.Unroll{1,1,U,0,1,zero(UInt64)}((w,)))
      z_re4 = c_re4 = vbroadcast(VW, _c_re)
      h = 0
      while h < H
        _c_im = vload(yp, (MM(VW, h),))
        z_im4 = c_im4 = VecUnroll(ntuple(Returns(_c_im), Val(U)))
        _m = VectorizationBase.Mask(VectorizationBase.mask(VW, h, H))
        m4 = VecUnroll(ntuple(Returns(_m), Val(U)))
        i = one(I)
        fhw4 = VecUnroll(ntuple(Returns(vbroadcast(VW,N%I)), Val(U)))
        # fhw4 = vload(xp, VectorizationBase.Unroll{2,1,U,1,Int(VW),0xf%UInt64}((h,w)), _m)
        while true
          z_re4,z_im4 = c_re4 + z_re4*z_re4 - z_im4*z_im4, c_im4 + 2f0*z_re4*z_im4
          az4 = (z_re4*z_re4 + z_im4*z_im4) > 4f0

          fhw4 = VectorizationBase.ifelse(m4 & az4, i, fhw4)
          m4 &= (!az4)
          (any(VectorizationBase.data(VectorizationBase.vany(m4))) && ((i+=one(i)) <= N)) || break
        end
        vstore!(fp, fhw4, VectorizationBase.Unroll{2,1,U,1,Int(VW),0xf%UInt64}((h,w)), _m)
        h += VW
      end
      w += U
    end
    # Remainder loop: any leftover columns of x, handled one at a time.
    while w < W
      z_re = c_re = vbroadcast(VW, vload(xp, (w,)))
      h = 0
      while h < H
        z_im = c_im = vload(yp, (MM(VW, h),))
        m = mi = VectorizationBase.Mask(VectorizationBase.mask(VW, h, H))
        i = one(I) # same counting convention as the unrolled loop above
        fhw = vbroadcast(VW,N%I)
        while true
          z_re,z_im = c_re + z_re*z_re - z_im*z_im, c_im + 2f0*z_re*z_im
          az4 = (z_re*z_re + z_im*z_im) > 4f0

          fhw = VectorizationBase.ifelse(mi & az4, i, fhw)
          mi &= (!az4)
          (VectorizationBase.vany(mi) && ((i+=one(i)) <= N)) || break
        end
        vstore!(fp, fhw, (MM(VW,h), w), m)
        h += VW
      end      
      w += 1
    end
  end
  fractal
end
function run5(height, width)
  y = range(-1.0f0, 0.0f0; length = height)
  x = range(-1.5f0, 0.0f0; length = width)
  fractal = Matrix{Int32}(undef, height, width)
  return run5_unroll!(fractal, y, x, Val(4), Val(20))
end

Here are prettier benchmarks:

julia> @benchmark run5(10,10)
BenchmarkTools.Trial: 10000 samples with 196 evaluations.
 Range (min … max):  465.469 ns …   5.591 μs  ┊ GC (min … max): 0.00% … 84.27%
 Time  (median):     472.441 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   488.324 ns ± 166.302 ns  ┊ GC (mean ± σ):  1.46% ±  3.88%

  ▃▃█▆▄▂▂▃▃▃▂▁▁                                   ▁ ▁           ▁
  ██████████████▅▅▅▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▃▅▇▇████▇▆▆▆▇▆▇▇▇ █
  465 ns        Histogram: log(frequency) by time        626 ns <

 Memory estimate: 496 bytes, allocs estimate: 1.

julia> @benchmark run_julia(10,10)
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
 Range (min … max):  3.318 μs … 172.239 μs  ┊ GC (min … max): 0.00% … 95.47%
 Time  (median):     3.368 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.539 μs ±   2.293 μs  ┊ GC (mean ± σ):  0.89% ±  1.36%

  ▅█▇▄▃▂      ▃▄▂▁  ▃▆▅▃▂▁▁▃▃▂▁▁                              ▂
  ████████▅▅▂▅█████▇██████████████▇▆▆▅▆▆▅▅▅▅▅▄▄▃▄▄▃▄▅▃▄▄▄▅▄▄▆ █
  3.32 ΞΌs      Histogram: log(frequency) by time       4.4 ΞΌs <

 Memory estimate: 1.36 KiB, allocs estimate: 2.

julia> @benchmark run5(100,100)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.768 μs …  1.219 ms  ┊ GC (min … max): 0.00% … 95.06%
 Time  (median):     10.293 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.807 μs ± 26.271 μs  ┊ GC (mean ± σ):  5.94% ±  2.67%

  █▇▇▇▅▄▃▁                           ▁▂▁ ▂▁▁                  ▂
  █████████▇▇▇███▇▆▇▇▆▅▆▅▃▄▁▃▄▁▁▃▃▇▇▆███████▇▇▆▆▅▄▄▄▅▃▄▁▅▄▄▅▆ █
  9.77 ΞΌs      Histogram: log(frequency) by time      24.5 ΞΌs <

 Memory estimate: 39.11 KiB, allocs estimate: 2.

julia> @benchmark run_julia(100,100)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  322.725 μs …  1.427 ms  ┊ GC (min … max): 0.00% … 71.91%
 Time  (median):     330.053 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   333.490 μs ± 41.986 μs  ┊ GC (mean ± σ):  0.59% ±  3.40%

     β–β–ˆβ–…β–β–β–ƒβ–
  β–…β–‡β–†β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–†β–…β–„β–ƒβ–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–β–‚β–β–β–β–‚β–β–β–β–β–β–β–β–β–β–β–‚β–β–‚β–‚β–‚β–‚ β–ƒ
  323 ΞΌs          Histogram: frequency by time          398 ΞΌs <

 Memory estimate: 117.28 KiB, allocs estimate: 4.

julia> @benchmark run5(2000,3000)
BenchmarkTools.Trial: 752 samples with 1 evaluation.
 Range (min … max):  4.192 ms … 10.854 ms  ┊ GC (min … max): 0.00% …  7.84%
 Time  (median):     7.090 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.641 ms ±  2.534 ms  ┊ GC (mean ± σ):  8.75% ± 14.03%

  █▁
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂█▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂█▃▁▁▂▂▁▁▁▆▅ ▂
  4.19 ms        Histogram: frequency by time        10.5 ms <

 Memory estimate: 22.89 MiB, allocs estimate: 2.

julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 19 samples with 1 evaluation.
 Range (min … max):  273.657 ms … 284.974 ms  ┊ GC (min … max): 1.34% … 0.27%
 Time  (median):     275.478 ms               ┊ GC (median):    1.33%
 Time  (mean ± σ):   276.150 ms ±   3.258 ms  ┊ GC (mean ± σ):  1.29% ± 0.25%

  ▃ ▃          █                                              ▃
  █▇█▇▇▇▁▇▁▇▇▁▇█▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  274 ms           Histogram: frequency by time          285 ms <

 Memory estimate: 68.66 MiB, allocs estimate: 4.

For 2000x3000, it is 40x or more faster than the original version while still being single threaded.
The answers are pretty close, but do not match exactly.
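
A hedged way to quantify that mismatch (assuming `run_julia` as benchmarked above; the names `ref` and `new` are just for illustration):

ref = run_julia(100, 100)    # straightforward version
new = run5(100, 100)         # unrolled/vectorized version
ndiff = count(ref .!= new)   # pixels whose escape count differs
ndiff / length(ref)          # fraction of differing pixels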

I think this can be improved. I didn’t optimize/tweak it, but just wrote the first idea of how I thought the code should be evaluated (and roughly what I’d intend a loop optimizer to do).

13 Likes

I believe JAX operations are threaded by default? That, and jax.jit can sometimes elide intermediate allocations. If there’s a way to pin down those two factors, then a like-for-like comparison may be easier.

The notebook writes:

It’s hard to keep JAX from using all your CPU cores, so I ran this notebook with
taskset -c 0 jupyter lab
to bind it to exactly one core (the one numbered 0).

2 Likes

Thanks, but I would “disqualify” any complete unrolling like this. Again, unless we know JAX aggressively unrolls like this, it’s not useful for an apples-to-apples comparison, because we don’t know how deep we need to go: 20, 40, 200, 2000?

This is only unrolling by 4, regardless of whether the depth is 4, 20, 40, 200, or 2000.

Ah, oops, I’m not familiar with this flavor of Julia, sorry.

Unlike my earlier version, the newer one also checks and tries to break out of the loop when it can (so that you hopefully don’t need to do 2000 iterations).
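
For readers who don’t want to parse the VectorizationBase code: the strategy is to iterate a block of pixels together and bail out of the iteration loop as soon as every pixel in the block has escaped. Below is a hedged plain-Julia sketch of that idea (the helper name `mandel_block!` is mine, and a Bool vector stands in for the SIMD masks); it is an illustration, not the kernel above.

# Sketch only: iterate a block of pixels together, stop once the whole block escaped.
function mandel_block!(counts::AbstractVector{<:Integer},
                       cs::AbstractVector{<:Complex}, maxiter::Integer)
    zs = copy(cs)                 # current z for each pixel in the block
    alive = trues(length(cs))     # which pixels are still iterating
    fill!(counts, maxiter)        # default: never escaped within maxiter
    for i in 1:maxiter
        any(alive) || break       # the whole block has escaped: stop early
        for k in eachindex(cs)
            alive[k] || continue
            zs[k] = zs[k]^2 + cs[k]
            if abs2(zs[k]) > 4
                counts[k] = i     # record the escape iteration
                alive[k] = false
            end
        end
    end
    return counts
end

Called on blocks of U*VW pixels at a time, this is roughly what the unrolled kernel does, with hardware masks and vector registers in place of the Bool vector and the inner scalar loop.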

1 Like

I believe it does, at least for the static range(20) in the notebook (Edit: this is noted in 🔪 JAX - The Sharp Bits 🔪 - JAX documentation, among other places). Using the incantation in print optimized code · Discussion #7068 · google/jax · GitHub to look at the optimized HLO:

In [24]: print(module.to_string(xla_ext.HloPrintOptions.short_parsable()))
HloModule xla_computation_run_jax_kernel.7

fused_computation.clone {
  param_1.9 = c64[2000,2000]{1,0} parameter(1)
  multiply.20 = c64[2000,2000]{1,0} multiply(param_1.9, param_1.9)
  add.20 = c64[2000,2000]{1,0} add(multiply.20, param_1.9)
  multiply.21 = c64[2000,2000]{1,0} multiply(add.20, add.20)
  add.21 = c64[2000,2000]{1,0} add(multiply.21, param_1.9)
  multiply.22 = c64[2000,2000]{1,0} multiply(add.21, add.21)
  add.22 = c64[2000,2000]{1,0} add(multiply.22, param_1.9)
  multiply.23 = c64[2000,2000]{1,0} multiply(add.22, add.22)
  add.23 = c64[2000,2000]{1,0} add(multiply.23, param_1.9)
  multiply.24 = c64[2000,2000]{1,0} multiply(add.23, add.23)
  add.24 = c64[2000,2000]{1,0} add(multiply.24, param_1.9)
  multiply.25 = c64[2000,2000]{1,0} multiply(add.24, add.24)
  add.25 = c64[2000,2000]{1,0} add(multiply.25, param_1.9)
  multiply.26 = c64[2000,2000]{1,0} multiply(add.25, add.25)
  add.26 = c64[2000,2000]{1,0} add(multiply.26, param_1.9)
  multiply.28 = c64[2000,2000]{1,0} multiply(add.26, add.26)
  add.27 = c64[2000,2000]{1,0} add(multiply.28, param_1.9)
  multiply.29 = c64[2000,2000]{1,0} multiply(add.27, add.27)
  add.29 = c64[2000,2000]{1,0} add(multiply.29, param_1.9)
  multiply.30 = c64[2000,2000]{1,0} multiply(add.29, add.29)
  add.30 = c64[2000,2000]{1,0} add(multiply.30, param_1.9)
  multiply.31 = c64[2000,2000]{1,0} multiply(add.30, add.30)
  add.31 = c64[2000,2000]{1,0} add(multiply.31, param_1.9)
  multiply.32 = c64[2000,2000]{1,0} multiply(add.31, add.31)
  add.32 = c64[2000,2000]{1,0} add(multiply.32, param_1.9)
  multiply.33 = c64[2000,2000]{1,0} multiply(add.32, add.32)
  add.33 = c64[2000,2000]{1,0} add(multiply.33, param_1.9)
  multiply.34 = c64[2000,2000]{1,0} multiply(add.33, add.33)
  add.34 = c64[2000,2000]{1,0} add(multiply.34, param_1.9)
  multiply.35 = c64[2000,2000]{1,0} multiply(add.34, add.34)
  add.35 = c64[2000,2000]{1,0} add(multiply.35, param_1.9)
  multiply.36 = c64[2000,2000]{1,0} multiply(add.35, add.35)
  add.36 = c64[2000,2000]{1,0} add(multiply.36, param_1.9)
  multiply.37 = c64[2000,2000]{1,0} multiply(add.36, add.36)
  add.37 = c64[2000,2000]{1,0} add(multiply.37, param_1.9)
  multiply.38 = c64[2000,2000]{1,0} multiply(add.37, add.37)
  add.38 = c64[2000,2000]{1,0} add(multiply.38, param_1.9)
  multiply.39 = c64[2000,2000]{1,0} multiply(add.38, add.38)
  add.39 = c64[2000,2000]{1,0} add(multiply.39, param_1.9)
  multiply.41 = c64[2000,2000]{1,0} multiply(add.39, add.39)
  add.40 = c64[2000,2000]{1,0} add(multiply.41, param_1.9)
  abs.40 = f32[2000,2000]{1,0} abs(add.40)
  constant.46 = f32[] constant(2)
  broadcast.47 = f32[2000,2000]{1,0} broadcast(constant.46), dimensions={}
  compare.89 = pred[2000,2000]{1,0} compare(abs.40, broadcast.47), direction=GT
  abs.39 = f32[2000,2000]{1,0} abs(add.39)
  compare.87 = pred[2000,2000]{1,0} compare(abs.39, broadcast.47), direction=GT
  abs.38 = f32[2000,2000]{1,0} abs(add.38)
  compare.85 = pred[2000,2000]{1,0} compare(abs.38, broadcast.47), direction=GT
  abs.37 = f32[2000,2000]{1,0} abs(add.37)
  compare.81 = pred[2000,2000]{1,0} compare(abs.37, broadcast.47), direction=GT
  abs.36 = f32[2000,2000]{1,0} abs(add.36)
  compare.79 = pred[2000,2000]{1,0} compare(abs.36, broadcast.47), direction=GT
  abs.35 = f32[2000,2000]{1,0} abs(add.35)
  compare.77 = pred[2000,2000]{1,0} compare(abs.35, broadcast.47), direction=GT
  abs.34 = f32[2000,2000]{1,0} abs(add.34)
  compare.75 = pred[2000,2000]{1,0} compare(abs.34, broadcast.47), direction=GT
  abs.33 = f32[2000,2000]{1,0} abs(add.33)
  compare.73 = pred[2000,2000]{1,0} compare(abs.33, broadcast.47), direction=GT
  abs.32 = f32[2000,2000]{1,0} abs(add.32)
  compare.71 = pred[2000,2000]{1,0} compare(abs.32, broadcast.47), direction=GT
  abs.31 = f32[2000,2000]{1,0} abs(add.31)
  compare.67 = pred[2000,2000]{1,0} compare(abs.31, broadcast.47), direction=GT
  abs.30 = f32[2000,2000]{1,0} abs(add.30)
  compare.65 = pred[2000,2000]{1,0} compare(abs.30, broadcast.47), direction=GT
  abs.28 = f32[2000,2000]{1,0} abs(add.29)
  compare.63 = pred[2000,2000]{1,0} compare(abs.28, broadcast.47), direction=GT
  abs.27 = f32[2000,2000]{1,0} abs(add.27)
  compare.61 = pred[2000,2000]{1,0} compare(abs.27, broadcast.47), direction=GT
  abs.26 = f32[2000,2000]{1,0} abs(add.26)
  compare.59 = pred[2000,2000]{1,0} compare(abs.26, broadcast.47), direction=GT
  abs.25 = f32[2000,2000]{1,0} abs(add.25)
  compare.55 = pred[2000,2000]{1,0} compare(abs.25, broadcast.47), direction=GT
  abs.24 = f32[2000,2000]{1,0} abs(add.24)
  compare.53 = pred[2000,2000]{1,0} compare(abs.24, broadcast.47), direction=GT
  abs.23 = f32[2000,2000]{1,0} abs(add.23)
  compare.51 = pred[2000,2000]{1,0} compare(abs.23, broadcast.47), direction=GT
  abs.22 = f32[2000,2000]{1,0} abs(add.22)
  compare.49 = pred[2000,2000]{1,0} compare(abs.22, broadcast.47), direction=GT
  abs.21 = f32[2000,2000]{1,0} abs(add.21)
  compare.47 = pred[2000,2000]{1,0} compare(abs.21, broadcast.47), direction=GT
  abs.20 = f32[2000,2000]{1,0} abs(add.20)
  compare.45 = pred[2000,2000]{1,0} compare(abs.20, broadcast.47), direction=GT
  param_0.4 = s32[2000,2000]{1,0} parameter(0)
  constant.45 = s32[] constant(20)
  broadcast.46 = s32[2000,2000]{1,0} broadcast(constant.45), dimensions={}
  compare.42 = pred[2000,2000]{1,0} compare(param_0.4, broadcast.46), direction=EQ
  and.20 = pred[2000,2000]{1,0} and(compare.45, compare.42)
  constant.44 = s32[] constant(0)
  broadcast.45 = s32[2000,2000]{1,0} broadcast(constant.44), dimensions={}
  select.41 = s32[2000,2000]{1,0} select(and.20, broadcast.45, param_0.4)
  compare.46 = pred[2000,2000]{1,0} compare(select.41, broadcast.46), direction=EQ
  and.21 = pred[2000,2000]{1,0} and(compare.47, compare.46)
  constant.47 = s32[] constant(1)
  broadcast.48 = s32[2000,2000]{1,0} broadcast(constant.47), dimensions={}
  select.42 = s32[2000,2000]{1,0} select(and.21, broadcast.48, select.41)
  compare.48 = pred[2000,2000]{1,0} compare(select.42, broadcast.46), direction=EQ
  and.22 = pred[2000,2000]{1,0} and(compare.49, compare.48)
  constant.48 = s32[] constant(2)
  broadcast.49 = s32[2000,2000]{1,0} broadcast(constant.48), dimensions={}
  select.43 = s32[2000,2000]{1,0} select(and.22, broadcast.49, select.42)
  compare.50 = pred[2000,2000]{1,0} compare(select.43, broadcast.46), direction=EQ
  and.23 = pred[2000,2000]{1,0} and(compare.51, compare.50)
  constant.49 = s32[] constant(3)
  broadcast.51 = s32[2000,2000]{1,0} broadcast(constant.49), dimensions={}
  select.44 = s32[2000,2000]{1,0} select(and.23, broadcast.51, select.43)
  compare.52 = pred[2000,2000]{1,0} compare(select.44, broadcast.46), direction=EQ
  and.24 = pred[2000,2000]{1,0} and(compare.53, compare.52)
  constant.50 = s32[] constant(4)
  broadcast.52 = s32[2000,2000]{1,0} broadcast(constant.50), dimensions={}
  select.45 = s32[2000,2000]{1,0} select(and.24, broadcast.52, select.44)
  compare.54 = pred[2000,2000]{1,0} compare(select.45, broadcast.46), direction=EQ
  and.25 = pred[2000,2000]{1,0} and(compare.55, compare.54)
  constant.51 = s32[] constant(5)
  broadcast.53 = s32[2000,2000]{1,0} broadcast(constant.51), dimensions={}
  select.46 = s32[2000,2000]{1,0} select(and.25, broadcast.53, select.45)
  compare.58 = pred[2000,2000]{1,0} compare(select.46, broadcast.46), direction=EQ
  and.26 = pred[2000,2000]{1,0} and(compare.59, compare.58)
  constant.52 = s32[] constant(6)
  broadcast.54 = s32[2000,2000]{1,0} broadcast(constant.52), dimensions={}
  select.47 = s32[2000,2000]{1,0} select(and.26, broadcast.54, select.46)
  compare.60 = pred[2000,2000]{1,0} compare(select.47, broadcast.46), direction=EQ
  and.27 = pred[2000,2000]{1,0} and(compare.61, compare.60)
  constant.53 = s32[] constant(7)
  broadcast.55 = s32[2000,2000]{1,0} broadcast(constant.53), dimensions={}
  select.48 = s32[2000,2000]{1,0} select(and.27, broadcast.55, select.47)
  compare.62 = pred[2000,2000]{1,0} compare(select.48, broadcast.46), direction=EQ
  and.28 = pred[2000,2000]{1,0} and(compare.63, compare.62)
  constant.54 = s32[] constant(8)
  broadcast.56 = s32[2000,2000]{1,0} broadcast(constant.54), dimensions={}
  select.49 = s32[2000,2000]{1,0} select(and.28, broadcast.56, select.48)
  compare.64 = pred[2000,2000]{1,0} compare(select.49, broadcast.46), direction=EQ
  and.29 = pred[2000,2000]{1,0} and(compare.65, compare.64)
  constant.55 = s32[] constant(9)
  broadcast.57 = s32[2000,2000]{1,0} broadcast(constant.55), dimensions={}
  select.50 = s32[2000,2000]{1,0} select(and.29, broadcast.57, select.49)
  compare.66 = pred[2000,2000]{1,0} compare(select.50, broadcast.46), direction=EQ
  and.30 = pred[2000,2000]{1,0} and(compare.67, compare.66)
  constant.56 = s32[] constant(10)
  broadcast.58 = s32[2000,2000]{1,0} broadcast(constant.56), dimensions={}
  select.52 = s32[2000,2000]{1,0} select(and.30, broadcast.58, select.50)
  compare.68 = pred[2000,2000]{1,0} compare(select.52, broadcast.46), direction=EQ
  and.31 = pred[2000,2000]{1,0} and(compare.71, compare.68)
  constant.57 = s32[] constant(11)
  broadcast.59 = s32[2000,2000]{1,0} broadcast(constant.57), dimensions={}
  select.53 = s32[2000,2000]{1,0} select(and.31, broadcast.59, select.52)
  compare.72 = pred[2000,2000]{1,0} compare(select.53, broadcast.46), direction=EQ
  and.33 = pred[2000,2000]{1,0} and(compare.73, compare.72)
  constant.58 = s32[] constant(12)
  broadcast.60 = s32[2000,2000]{1,0} broadcast(constant.58), dimensions={}
  select.54 = s32[2000,2000]{1,0} select(and.33, broadcast.60, select.53)
  compare.74 = pred[2000,2000]{1,0} compare(select.54, broadcast.46), direction=EQ
  and.34 = pred[2000,2000]{1,0} and(compare.75, compare.74)
  constant.59 = s32[] constant(13)
  broadcast.61 = s32[2000,2000]{1,0} broadcast(constant.59), dimensions={}
  select.55 = s32[2000,2000]{1,0} select(and.34, broadcast.61, select.54)
  compare.76 = pred[2000,2000]{1,0} compare(select.55, broadcast.46), direction=EQ
  and.35 = pred[2000,2000]{1,0} and(compare.77, compare.76)
  constant.60 = s32[] constant(14)
  broadcast.62 = s32[2000,2000]{1,0} broadcast(constant.60), dimensions={}
  select.56 = s32[2000,2000]{1,0} select(and.35, broadcast.62, select.55)
  compare.78 = pred[2000,2000]{1,0} compare(select.56, broadcast.46), direction=EQ
  and.36 = pred[2000,2000]{1,0} and(compare.79, compare.78)
  constant.61 = s32[] constant(15)
  broadcast.64 = s32[2000,2000]{1,0} broadcast(constant.61), dimensions={}
  select.57 = s32[2000,2000]{1,0} select(and.36, broadcast.64, select.56)
  compare.80 = pred[2000,2000]{1,0} compare(select.57, broadcast.46), direction=EQ
  and.37 = pred[2000,2000]{1,0} and(compare.81, compare.80)
  constant.62 = s32[] constant(16)
  broadcast.65 = s32[2000,2000]{1,0} broadcast(constant.62), dimensions={}
  select.58 = s32[2000,2000]{1,0} select(and.37, broadcast.65, select.57)
  compare.84 = pred[2000,2000]{1,0} compare(select.58, broadcast.46), direction=EQ
  and.38 = pred[2000,2000]{1,0} and(compare.85, compare.84)
  constant.63 = s32[] constant(17)
  broadcast.66 = s32[2000,2000]{1,0} broadcast(constant.63), dimensions={}
  select.59 = s32[2000,2000]{1,0} select(and.38, broadcast.66, select.58)
  compare.86 = pred[2000,2000]{1,0} compare(select.59, broadcast.46), direction=EQ
  and.39 = pred[2000,2000]{1,0} and(compare.87, compare.86)
  constant.64 = s32[] constant(18)
  broadcast.67 = s32[2000,2000]{1,0} broadcast(constant.64), dimensions={}
  select.60 = s32[2000,2000]{1,0} select(and.39, broadcast.67, select.59)
  compare.88 = pred[2000,2000]{1,0} compare(select.60, broadcast.46), direction=EQ
  and.40 = pred[2000,2000]{1,0} and(compare.89, compare.88)
  constant.65 = s32[] constant(19)
  broadcast.68 = s32[2000,2000]{1,0} broadcast(constant.65), dimensions={}
  ROOT select.61 = s32[2000,2000]{1,0} select(and.40, broadcast.68, select.60)
}

parallel_fusion {
  p = s32[2000,2000]{1,0} parameter(0)
  p.1 = c64[2000,2000]{1,0} parameter(1)
  ROOT fusion.clone = s32[2000,2000]{1,0} fusion(p, p.1), kind=kLoop, calls=fused_computation.clone, outer_dimension_partitions={8}
}

ENTRY main.288 {
  Arg_1.2 = s32[2000,2000]{1,0} parameter(1)
  Arg_0.1 = c64[2000,2000]{1,0} parameter(0)
  call = s32[2000,2000]{1,0} call(Arg_1.2, Arg_0.1), to_apply=parallel_fusion
  ROOT tuple.287 = (s32[2000,2000]{1,0}) tuple(call)
}

The canonical form may have more low-level details, but it exceeds the Discourse character limit. I’ve dumped it in JAX Mandelbrot canonical HLO · GitHub.

Some more diagnostics using Add public API for computation cost analysis (flops, memory use, etc.) · Issue #10542 · google/jax · GitHub:

In [28]: module = comp.as_hlo_module()

In [29]: client = jax.lib.xla_bridge.get_backend()

In [30]: analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, module)

In [31]: analysis
Out[31]: 
{'bytes accessed': 7871990784.0,
 'bytes accessed operand 0 {}': 2720000000.0,
 'bytes accessed operand 1 {}': 2320000000.0,
 'bytes accessed operand 2 {}': 320000000.0,
 'bytes accessed output {}': 2512000000.0,
 'flops': 560000000.0,
 'optimal_seconds': 0.0}

I think it turns out that in JAX, if you have a loop like for i in range(20):, it always unrolls everything.

2 Likes

for posterity:

7 Likes

Small simplification: in Julia v1.9 I get identical performance for your function and this one:

# NOT using fastmath
function run_julia(height, width)
    y = range(-1.0f0, 0.0f0; length = height) # need Float32 because JAX defaults to it
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)
    for idx in eachindex(c, fractal)
        _c = c[idx]
        z = _c
        m = true
        Base.Cartesian.@nexprs 20 i -> begin
            z = z^2 + _c
            az4 = abs2(z) > 4f0
            fractal[idx] = ifelse(m&az4, Int32(i), fractal[idx]) # 32-bit Int, same reason as above
            m &= (!az4)
        end
    end
    return fractal
end

A single for loop over eachindex(c, fractal), and no need for @inbounds.
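
For anyone who hasn’t seen Base.Cartesian.@nexprs: it unrolls its body at macro-expansion time, which is essentially the Julia analogue of JAX tracing the static range(20). A hedged sketch (the helper names `unrolled_iter` and `looped_iter` are mine, not from the thread) showing that the unrolled block computes the same escape count as an ordinary loop with an early return:

using Base.Cartesian

function unrolled_iter(c::ComplexF32)
    z = c
    n = Int32(20)   # default: never escaped within 20 iterations
    m = true        # still iterating?
    @nexprs 20 i -> begin
        z = z^2 + c
        escaped = abs2(z) > 4f0
        n = ifelse(m & escaped, Int32(i), n)
        m &= !escaped
    end
    return n
end

function looped_iter(c::ComplexF32)
    z = c
    for i in Int32(1):Int32(20)
        z = z^2 + c
        abs2(z) > 4f0 && return i
    end
    return Int32(20)
end

# Spot check on a small grid:
all(unrolled_iter(complex(x, y)) == looped_iter(complex(x, y))
    for x in -1.5f0:0.05f0:0.0f0, y in -1.0f0:0.05f0:0.0f0)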

1 Like

In 1.8 it seems I need

    @inbounds for idx in eachindex(c, fractal)

because of regression on inferring inbounds for `eachindex()` · Issue #45507 · JuliaLang/julia · GitHub

2 Likes

Looping over eachindex(c, fractal) instead of hardcoding 1-based indices across all dimensions should still be nicer :slightly_smiling_face:

3 Likes

Yes, but I’m trying to keep the changes minimal relative to the Numba version. If we still need @inbounds, I think it would only confuse readers who don’t speak Julia; if we can drop @inbounds, I can at least say “well, this gives the compiler confidence that c and fractal have the same shape”.
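
As a hedged illustration of that last point (toy arrays, not the benchmark code): eachindex(c, fractal) checks up front that the arrays it is given have matching index ranges and throws a DimensionMismatch otherwise, so every index it yields is valid for both arrays and the per-element bounds checks can go away.

a = fill(Int32(20), 3, 3)
b = zeros(ComplexF32, 3, 4)   # deliberately mismatched size

try
    eachindex(a, b)           # the index ranges disagree
catch err
    @show typeof(err)         # expected: DimensionMismatch
end

b2 = zeros(ComplexF32, 3, 3)  # matching size: indexing both arrays is safe
for idx in eachindex(a, b2)
    a[idx] = ifelse(abs2(b2[idx]) > 4f0, Int32(0), a[idx])
end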