Orders of magnitude runtime difference in row-wise norm

I think the big thing here is that the broadcast in the first function is fused, so it calls `norm` N^2 times (once per element of the result). The second one calls it only N times (once per row) and stores those results in a temporary array. A simpler example of this might be:

```julia
julia> using BenchmarkTools

julia> let N = 100
         x = rand(N)
         y = rand(N)
         a = @btime exp.($x) ./ $y'  # fused, calls exp N^2 times
         b = @btime begin z = exp.($x);  z ./ $y' end  # just N calls
         a ≈ b
       end
  min 68.917 μs, mean 81.129 μs (2 allocations, 78.17 KiB)
  min 3.351 μs, mean 11.817 μs (3 allocations, 79.05 KiB)  # extra allocation, length N
true
```
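To see the call counts directly, here is a small sketch that wraps `exp` in a counting function. `calls` and `counted_exp` are hypothetical helpers for illustration only. The counts follow from how a fused broadcast evaluates its inner call once for every element of the output.

```julia
# Hypothetical counter, used only to observe how often the inner function runs.
const calls = Ref(0)
counted_exp(t) = (calls[] += 1; exp(t))

N = 100
x = rand(N)
y = rand(N)

# Fused: the whole right-hand side is one broadcast over an N×N result,
# so counted_exp is re-evaluated for every output element.
calls[] = 0
a = counted_exp.(x) ./ y'
fused_calls = calls[]

# Unfused: materialise the N results first, then broadcast the division.
calls[] = 0
z = counted_exp.(x)
b = z ./ y'
unfused_calls = calls[]

println((fused_calls, unfused_calls, a ≈ b))  # prints (10000, 100, true)
```

Julia does not hoist `counted_exp.(x)` out of the fused loop by itself. With a side-effecting function like this one, doing so would change the program's behaviour.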

However, `norm.(eachrow(...))` still isn't ideal, because it has to access memory in the wrong order. Julia arrays are column-major, so neighbouring elements in memory run down a column, not along a row. Functions which take `dims` (instead of working on slices) work around this by re-organising the calculation to follow memory order. #43459 wants to add such a method as `norm(x; dims=2)`. Without that, `sqrt.(sum(abs2, x; dims=2))` is one option, though it is less careful than `norm` about underflow and overflow.
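As a rough illustration of what re-organising the calculation means, here is a hand-written sketch that computes all row norms in a single pass over the matrix in memory order. The helper `row_norms_colmajor` is hypothetical and not from Base or the PR. Its inner loop runs down a column, adding each entry's square into a per-row accumulator. Like `sqrt.(sum(abs2, x; dims=2))`, it does no rescaling.

```julia
using LinearAlgebra  # for norm

# Hypothetical helper: row norms of a real matrix, reading x in column-major
# (memory) order instead of slicing it row by row.
function row_norms_colmajor(x::Matrix{<:Real})
    acc = zeros(float(eltype(x)), size(x, 1))  # one running sum of squares per row
    for j in 1:size(x, 2)                       # outer loop: one column at a time
        for i in 1:size(x, 1)                   # inner loop: contiguous entries of column j
            acc[i] += abs2(x[i, j])
        end
    end
    return sqrt.(acc)                           # no rescaling, unlike norm
end

x = randn(100, 3000)
r = row_norms_colmajor(x)
println(r ≈ norm.(eachrow(x)), " ", r ≈ vec(sqrt.(sum(abs2, x; dims=2))))  # prints "true true"
```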

```julia
julia> function divide_by_row_norms_3(x)
           z = 1 ./ sqrt.(sum(abs2, x; dims=2))
           x .* z
       end;

julia> using LinearAlgebra, Revise; Revise.track(LinearAlgebra)  # to load the changes from PR #43459

julia> function divide_by_row_norms_4(x)
           z = norm(x; dims=2)
           x ./ z
       end;

julia> let x = randn(100, 3000)
         a = @btime divide_by_row_norms($x)
         b = @btime divide_by_row_norms_2($x)  # as above
         println()
         c = @btime divide_by_row_norms_3($x)  # with sqrt.(sum(abs2, x; dims=2))
         d = @btime divide_by_row_norms_4($x)
         e = @btime normalize($x; dims=2)  # also from PR #43459, seemingly not ideal
         a ≈ b ≈ c ≈ d ≈ e
       end
  min 1.858 s, mean 1.859 s (2 allocations, 2.29 MiB)
  min 698.958 μs, mean 1.031 ms (3 allocations, 2.29 MiB)

  min 156.541 μs, mean 413.161 μs (8 allocations, 2.29 MiB)
  min 159.583 μs, mean 405.405 μs (7 allocations, 2.29 MiB)
  min 341.875 μs, mean 567.787 μs (8 allocations, 2.29 MiB)
true
```
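Here is the underflow caveat in concrete terms. Squaring an entry of order 1e-200 gives about 1e-400, which is below the smallest positive Float64. The unscaled sum of squares therefore becomes zero, while `norm` rescales by the largest entry first.

The sketch below first checks this on a single vector. It then shows one hypothetical way to keep the fast `dims`-based reductions while dividing each row by its largest absolute entry before squaring. `divide_by_row_norms_scaled` is an illustrative name, not something from the PR. `divide_by_row_norms_3` is re-declared from the post so the block runs on its own.

```julia
using LinearAlgebra  # for norm

# Single vector: each square is ~1e-400, below the smallest positive Float64.
v = [1e-200, 1e-200]
naive = sqrt(sum(abs2, v))   # 0.0, because abs2(1e-200) underflows
careful = norm(v)            # rescales by the largest entry, ≈ sqrt(2) * 1e-200

# The post's version 3, re-declared so this block is self-contained.
divide_by_row_norms_3(x) = x .* (1 ./ sqrt.(sum(abs2, x; dims=2)))

# Hypothetical variant: still dims-based (memory-order) reductions, but each row
# is first divided by its largest |entry|, so its sum of squares is at least 1
# and cannot underflow to zero. All-zero rows still give NaN, as in the other versions.
function divide_by_row_norms_scaled(x)
    m = maximum(abs, x; dims=2)         # per-row scale factor, size (rows, 1)
    y = x ./ m                          # each row now has max |entry| == 1
    y ./ sqrt.(sum(abs2, y; dims=2))
end

x = randn(4, 5)
x[1, :] .*= 1e-200                      # a row whose squares all underflow
println(any(isinf, divide_by_row_norms_3(x)[1, :]))         # version 3 breaks on this row
println(divide_by_row_norms_scaled(x) ≈ x ./ norm.(eachrow(x)))
```

The extra `maximum` pass and the temporary `y` cost some speed. Whether that is worth paying probably depends on whether rows this tiny (or huge) can occur in practice.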