Orders of magnitude runtime difference in row-wise norm

I’m trying to divide every row of a matrix by its norm using the approach suggested by this comment. However, I’m finding that two functionally equivalent implementations differ in runtime by more than 100x (the first runs in ~1.3 s, while the second runs in ~0.002 s). Any idea why this is happening?

I’m using Julia 1.8.5 on an Intel Mac.

using LinearAlgebra, Random, Profile

function divide_by_row_norms(x)
    x ./ norm.(eachrow(x))
end

function divide_by_row_norms_2(x)
    norms = norm.(eachrow(x))
    x ./ norms
end

x = randn(MersenneTwister(1234), 100, 3000)
println(sum(divide_by_row_norms(x))) # 6.874556570313833
@time divide_by_row_norms(x)
@time divide_by_row_norms(x) # 1.280820 seconds
@profile divide_by_row_norms(x)

println(sum(divide_by_row_norms_2(x))) # 6.874556570313833
@time divide_by_row_norms_2(x)
@time divide_by_row_norms_2(x) # 0.002097 seconds
@profile divide_by_row_norms_2(x)

Profile.print()
Full Output
6.874556570313833
  1.351277 seconds (3 allocations: 2.293 MiB)
  1.280820 seconds (3 allocations: 2.293 MiB)
6.874556570313833
  0.002182 seconds (4 allocations: 2.294 MiB)
  0.002097 seconds (4 allocations: 2.294 MiB)
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
    ╎56   @Base/task.jl:484; (::VSCodeServer.var"#61#62")()
    ╎ 56   @VSCodeServer/src/eval.jl:34; macro expansion
    ╎  56   @Base/essentials.jl:726; invokelatest(::Any)
    ╎   56   @Base/essentials.jl:729; #invokelatest#2
    ╎    56   @VSCodeServer/src/eval.jl:225; (::VSCodeServer.var"#63#67"{VSCodeServer.ReplRunCodeReques...
    ╎     56   @Base/logging.jl:623; with_logger
    ╎    ╎ 56   @Base/logging.jl:511; with_logstate(f::Function, logstate::Any)
    ╎    ╎  56   @VSCodeServer/src/eval.jl:126; (::VSCodeServer.var"#64#68"{Bool, Bool, Bool, Module, St...
    ╎    ╎   56   @VSCodeServer/src/repl.jl:38; hideprompt(f::VSCodeServer.var"#65#69"{Bool, Bool, Bool,...
    ╎    ╎    56   @VSCodeServer/src/eval.jl:155; (::VSCodeServer.var"#65#69"{Bool, Bool, Bool, Module, S...
    ╎    ╎     56   @VSCodeServer/src/repl.jl:249; withpath(f::VSCodeServer.var"#66#70"{Bool, Bool, Bool,...
    ╎    ╎    ╎ 56   @VSCodeServer/src/eval.jl:157; (::VSCodeServer.var"#66#70"{Bool, Bool, Bool, Module, ...
    ╎    ╎    ╎  56   @VSCodeServer/src/eval.jl:230; inlineeval##kw
    ╎    ╎    ╎   56   @VSCodeServer/src/eval.jl:233; inlineeval(m::Module, code::String, code_line::Int64...
    ╎    ╎    ╎    56   @Base/essentials.jl:726; invokelatest(::Any, ::Any, ::Vararg{Any})
    ╎    ╎    ╎     56   @Base/essentials.jl:729; invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::B...
    ╎    ╎    ╎    ╎ 56   @Base/loading.jl:1428; include_string(mapexpr::typeof(REPL.softscope), mo...
   2╎    ╎    ╎    ╎  56   @Base/boot.jl:368; eval
    ╎    ╎    ╎    ╎   54   ...P23/6.8410 (6.838)/hw3/minimal.jl:4; divide_by_row_norms(x::Matrix{Float64})
    ╎    ╎    ╎    ╎    54   @Base/broadcast.jl:860; materialize
    ╎    ╎    ╎    ╎     54   @Base/broadcast.jl:885; copy
    ╎    ╎    ╎    ╎    ╎ 54   @Base/broadcast.jl:913; copyto!
    ╎    ╎    ╎    ╎    ╎  54   @Base/broadcast.jl:960; copyto!
    ╎    ╎    ╎    ╎    ╎   2    @Base/simdloop.jl:75; macro expansion
   2╎    ╎    ╎    ╎    ╎    2    @Base/int.jl:83; <
    ╎    ╎    ╎    ╎    ╎   52   @Base/simdloop.jl:77; macro expansion
    ╎    ╎    ╎    ╎    ╎    52   @Base/broadcast.jl:961; macro expansion
    ╎    ╎    ╎    ╎    ╎     42   @Base/broadcast.jl:597; getindex
    ╎    ╎    ╎    ╎    ╎    ╎ 42   @Base/broadcast.jl:642; _broadcast_getindex
    ╎    ╎    ╎    ╎    ╎    ╎  42   @Base/broadcast.jl:666; _getindex
    ╎    ╎    ╎    ╎    ╎    ╎   12   @Base/broadcast.jl:636; _broadcast_getindex
    ╎    ╎    ╎    ╎    ╎    ╎    12   @Base/multidimensional.jl:672; getindex
  12╎    ╎    ╎    ╎    ╎    ╎     12   @Base/array.jl:925; getindex
    ╎    ╎    ╎    ╎    ╎    ╎   30   @Base/broadcast.jl:667; _getindex
    ╎    ╎    ╎    ╎    ╎    ╎    21   @Base/broadcast.jl:642; _broadcast_getindex
    ╎    ╎    ╎    ╎    ╎    ╎     21   @Base/broadcast.jl:667; _getindex
    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 21   @Base/broadcast.jl:636; _broadcast_getindex
    ╎    ╎    ╎    ╎    ╎    ╎    ╎  21   @Base/multidimensional.jl:672; getindex
  21╎    ╎    ╎    ╎    ╎    ╎    ╎   21   @Base/array.jl:924; getindex
    ╎    ╎    ╎    ╎    ╎    ╎    9    @Base/broadcast.jl:643; _broadcast_getindex
    ╎    ╎    ╎    ╎    ╎    ╎     9    @Base/broadcast.jl:670; _broadcast_getindex_evalf
   4╎    ╎    ╎    ╎    ╎    ╎    ╎ 9    ...inearAlgebra/src/generic.jl:591; norm
    ╎    ╎    ╎    ╎    ╎    ╎    ╎  5    ...nearAlgebra/src/generic.jl:593; norm(itr::SubArray{Float64, 1, Matrix{Flo...
    ╎    ╎    ╎    ╎    ╎    ╎    ╎   5    ...LinearAlgebra/src/dense.jl:106; norm2
   3╎    ╎    ╎    ╎    ╎    ╎    ╎    3    ...LinearAlgebra/src/blas.jl:427; nrm2(x::SubArray{Float64, 1, Matrix{Floa...
    ╎    ╎    ╎    ╎    ╎    ╎    ╎    2    ...LinearAlgebra/src/blas.jl:429; nrm2(x::SubArray{Float64, 1, Matrix{Floa...
   1╎    ╎    ╎    ╎    ╎    ╎    ╎     2    ...LinearAlgebra/src/blas.jl:420; nrm2
    ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1    @Base/essentials.jl:412; cconvert
    ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  1    @Base/refpointer.jl:104; convert
   1╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎   1    @Base/refvalue.jl:8; RefValue
    ╎    ╎    ╎    ╎    ╎     10   @Base/multidimensional.jl:674; setindex!
  10╎    ╎    ╎    ╎    ╎    ╎ 10   @Base/array.jl:968; setindex!
   1╎1    ...a/stdlib/v1.8/LinearAlgebra/src/blas.jl:429; nrm2(x::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, B...
   1╎1    ...tdlib/v1.8/LinearAlgebra/src/generic.jl:0; norm(itr::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64,...
Total snapshots: 2179. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task


I think the big thing here is that the broadcast in the first function is fused, and thus calls norm once per element of the output (N^2 times for an N×N matrix, 100 × 3000 times here), while the second one calls it once per row (N times) and makes a temporary array to store the results. A simpler example of this might be:

julia> using BenchmarkTools

julia> let N = 100 
         x = rand(N)
         y = rand(N)
         a = @btime exp.($x) ./ $y'  # fused, calls exp N^2 times
         b = @btime begin z = exp.($x);  z ./ $y' end  # just N calls
         a ≈ b
       end
  min 68.917 μs, mean 81.129 μs (2 allocations, 78.17 KiB)
  min 3.351 μs, mean 11.817 μs (3 allocations, 79.05 KiB)  # extra allocation, length N
true
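The call counts can also be checked directly. This sketch wraps norm in a hypothetical counting function, counted_norm (not part of LinearAlgebra), and compares the fused and unfused forms on a small 4×5 matrix:

```julia
using LinearAlgebra

# Hypothetical counting wrapper, used only to show how often norm runs.
const NORM_CALLS = Ref(0)
counted_norm(v) = (NORM_CALLS[] += 1; norm(v))

x = randn(4, 5)

# Fused: one loop over the 4×5 output, recomputing the row norm per element.
NORM_CALLS[] = 0
y_fused = x ./ counted_norm.(eachrow(x))
fused_calls = NORM_CALLS[]

# Unfused: the 4 row norms are materialized into a vector first.
NORM_CALLS[] = 0
norms = counted_norm.(eachrow(x))
y_unfused = x ./ norms
unfused_calls = NORM_CALLS[]

println((fused_calls, unfused_calls))   # expected: (20, 4)
```

Both results are the same matrix; only the number of norm evaluations (m*n versus m) differs.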

However, norm.(eachrow(...)) still isn’t ideal, as it has to access memory in the wrong order: Julia arrays are column-major, so neighbouring elements lie along columns, not rows. Functions which take dims (instead of working on slices) work around this by re-organising the calculation. PR #43459 wants to add such a method as norm(x; dims=2). Without that, sqrt.(sum(abs2, x; dims=2)) is one option, but it is less careful than norm about underflow and overflow.
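A minimal sketch of that underflow caveat, using made-up inputs (every entry is 1e-200, so each square underflows to 0.0 in Float64). The comment about how norm avoids this is an assumption about LinearAlgebra's generic 2-norm for short vectors, which rescales by the largest absolute entry:

```julia
using LinearAlgebra

# Entries this small underflow when squared: abs2(1e-200) == 0.0 in Float64.
x = fill(1e-200, 2, 3)

# dims-based version: the squares underflow, so every row "norm" is 0.0.
naive = sqrt.(sum(abs2, x; dims=2))    # a 2×1 matrix of zeros

# norm avoids the underflow (for short rows, by rescaling internally).
careful = norm.(eachrow(x))            # each entry ≈ sqrt(3) * 1e-200

println(x ./ naive)     # every entry is Inf
println(x ./ careful)   # every entry ≈ 1/sqrt(3)
```

Inputs this extreme are unusual, so this is a robustness tradeoff rather than a reason to avoid the faster version.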

julia> function divide_by_row_norms_3(x)
           z = 1 ./ sqrt.(sum(abs2, x; dims=2))
           x .* z
       end;

julia> using Revise; Revise.track(LinearAlgebra)  # to load PR#43459

julia> function divide_by_row_norms_4(x)
           z = norm(x; dims=2)
           x ./ z
       end;

julia> let x = randn(100, 3000)
         a = @btime divide_by_row_norms($x)
         b = @btime divide_by_row_norms_2($x)  # as above
         println()
         c = @btime divide_by_row_norms_3($x)  # with sqrt.(sum(abs2, x; dims=2))
         d = @btime divide_by_row_norms_4($x) 
         e = @btime normalize($x; dims=2)  # also from PR#43459, seemingly not ideal
         a ≈ b ≈ c ≈ d ≈ e
       end
  min 1.858 s, mean 1.859 s (2 allocations, 2.29 MiB)
  min 698.958 μs, mean 1.031 ms (3 allocations, 2.29 MiB)

  min 156.541 μs, mean 413.161 μs (8 allocations, 2.29 MiB)
  min 159.583 μs, mean 405.405 μs (7 allocations, 2.29 MiB)
  min 341.875 μs, mean 567.787 μs (8 allocations, 2.29 MiB)
true
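Since norm(x; dims=2) above relies on loading PR #43459, here is a hypothetical loop-based alternative (divide_by_row_norms_loop is a made-up name, not the PR's implementation) that does the re-organising by hand: it walks each column contiguously and accumulates one sum of squares per row. It is only a sketch, not benchmarked here, and like the sqrt.(sum(abs2, x; dims=2)) version it does no rescaling:

```julia
using LinearAlgebra

# Hypothetical hand-written reorganisation. No rescaling is done, so extreme
# inputs can underflow or overflow, as with sqrt.(sum(abs2, x; dims=2)).
function divide_by_row_norms_loop(x::Matrix{Float64})
    m, n = size(x)
    ss = zeros(Float64, m)            # running sum of squares for each row
    @inbounds for j in 1:n            # outer loop over columns
        for i in 1:m                  # inner loop walks down a column (contiguous)
            ss[i] += abs2(x[i, j])
        end
    end
    inv_norms = 1 ./ sqrt.(ss)        # one sqrt and one division per row
    return x .* inv_norms             # scale row i by inv_norms[i]
end

x = randn(100, 3000)
y = divide_by_row_norms_loop(x)
println(y ≈ x ./ norm.(eachrow(x)))   # expected: true
```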

The thread "Preventing broadcast fusing" contains a more detailed discussion of this broadcast-fusion slowdown. (Is it a failure of LICM, loop-invariant code motion, for a theoretically pure norm function? EDIT: maybe not, because I suppose LICM can’t create temporary arrays.)
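For reference, a sketch of a few standard Julia idioms (not necessarily the ones proposed in that thread) that stop the row norms from being fused into the outer broadcast; each computes the norms once into a temporary vector:

```julia
using LinearAlgebra

x = randn(100, 3000)

# map is not a dot call, so nothing fuses with the outer ./
y1 = x ./ map(norm, eachrow(x))

# A plain (non-dot) call such as identity materializes its argument first.
y2 = x ./ identity(norm.(eachrow(x)))

# The temporary-variable version from the question.
norms = norm.(eachrow(x))
y3 = x ./ norms

println(y1 ≈ y2 && y2 ≈ y3)   # expected: true
```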


There is an issue (29285) about LICM for broadcasted pure functions, but that’s for scalar inputs. The loop in broadcast.jl is a single, unnested @simd loop in which _broadcast_getindex_evalf (well, it really starts with a getindex(bc::Broadcasted...) call) evaluates the fused operations at each index of the output array. So I’m not sure pure function calls can be lifted out of the loop when those calls still need to be iterated over some of the axes.
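To make that concrete, here is a simplified, hypothetical hand-lowering of the fused x ./ norm.(eachrow(x)) (fused_by_hand is a made-up name, and this is not Base's actual code): one flat loop over the output indices, with the norm call inside it.

```julia
using LinearAlgebra

# Hypothetical sketch mirroring the shape of broadcast's flat copyto! loop
# (Base also uses @simd there; it is omitted here for simplicity).
function fused_by_hand(x::Matrix{Float64})
    dest = similar(x)
    @inbounds for I in CartesianIndices(dest)
        i = I[1]    # row index of this output element
        # Pure, but it varies with I and there is no enclosing per-row loop
        # to hoist it into, so it runs once per output element.
        dest[I] = x[I] / norm(view(x, i, :))
    end
    return dest
end

x = randn(5, 7)
println(fused_by_hand(x) ≈ x ./ norm.(eachrow(x)))   # expected: true
```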

Right, I think the comment "Preventing broadcast fusing - #18 by Chris_Foster" expresses a similar sentiment.


That particular comment is actually even more damning than the possibility that LICM can’t be done here. Even if you could save pure function calls along some axes by lifting them to higher levels of a nested loop (maybe the Broadcasted struct could hold those intermediate values), some function call along an axis stays stuck in the innermost loop, so the number of operations remains of the same order, albeit scaled down. It’s actually a really neat example of a time-memory tradeoff.
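The tradeoff can be sketched with two hypothetical hand-written loops (made-up names, not Base code). Hoisting the pure norm call without a temporary forces a row-outer loop order that strides through memory, while keeping the column-major order requires storing the m row norms:

```julia
using LinearAlgebra

# Row-outer order: norm is hoisted out of the inner loop (m calls, no
# temporary), but the inner loop strides across a row, the slow direction
# for Julia's column-major arrays.
function rowouter_hoisted(x::Matrix{Float64})
    dest = similar(x)
    for i in axes(x, 1)
        r = norm(view(x, i, :))          # once per row
        for j in axes(x, 2)
            dest[i, j] = x[i, j] / r     # consecutive j are size(x, 1) elements apart
        end
    end
    return dest
end

# Column-outer order: the division loop reads x contiguously, but the norm
# depends on the inner index i, so avoiding m*n calls means storing m results.
function colouter_with_temp(x::Matrix{Float64})
    norms = [norm(view(x, i, :)) for i in axes(x, 1)]   # O(m) extra memory
    dest = similar(x)
    for j in axes(x, 2), i in axes(x, 1)                # i is the inner loop
        dest[i, j] = x[i, j] / norms[i]
    end
    return dest
end

x = randn(50, 200)
fused = x ./ norm.(eachrow(x))   # fused broadcast: m*n norm calls
println(rowouter_hoisted(x) ≈ fused && colouter_with_temp(x) ≈ fused)   # expected: true
```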
