Speeding up my logsumexp function

Thanks, that’s why I thought it was curious and worth pointing out. And that seems to be the explanation: Threads.nthreads() was 1. Now, using 18 threads:

julia> @btime lsexp_mat($A);
  11.372 ms (13 allocations: 15.41 MiB)

julia> @btime lsexp_mat2($A);
  11.212 ms (6 allocations: 7.65 MiB)

julia> @btime lsexp_mat3($A);
  2.105 ms (332 allocations: 7.68 MiB)

julia> @btime lsexp_mat4($A);
  2.385 ms (17 allocations: 7.65 MiB)

julia> @btime lsexp_mat5($A);
  1.613 ms (328 allocations: 7.68 MiB)
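The lsexp_mat* definitions aren't shown in this post. For reference, a minimal threaded column-wise logsumexp that uses only Base.Threads might look like the following. lsexp_threaded is a hypothetical name, and it has not been benchmarked against the versions above:

```julia
using Base.Threads: @threads

# Hypothetical sketch, not one of the lsexp_mat* functions above:
# numerically stable logsumexp of each column of A, with columns split across threads.
function lsexp_threaded(A::Matrix{T}) where {T<:AbstractFloat}
    out = Matrix{T}(undef, 1, size(A, 2))   # 1×n, same shape as maximum(A, dims=1)
    @threads for j in 1:size(A, 2)
        col = view(A, :, j)
        m = maximum(col)                    # shift by the column max so exp cannot overflow
        if isinf(m)
            out[1, j] = m                   # all -Inf, or some +Inf: the result is m itself
        else
            s = zero(T)
            @inbounds for i in eachindex(col)
                s += exp(col[i] - m)
            end
            out[1, j] = m + log(s)
        end
    end
    return out
end
```

Splitting the work by column keeps each task's memory access contiguous, because Julia arrays are column-major. With very few or very short columns, the task overhead may outweigh the gain.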

julia> using Tracker, Zygote #, ForwardDiff

julia> Zygote.@nograd safeeq

julia> gA = Tracker.gradient(sum∘lsexp_mat, A)[1];

julia> Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA
true
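As an independent check on these gradients: for each column, the derivative of logsumexp(x) is softmax(x). The max used for numerical stability contributes nothing. With lse(x) = m + log(sum(exp.(x .- m))), the partial derivative with respect to the shift m is 1 - sum(exp.(x .- m)) / sum(exp.(x .- m)) = 0. So a max that is excluded from differentiation (for example, one marked with Zygote.@nograd, as safeeq is above) should still give correct gradients. Below is a Base-only sketch with a finite-difference check; lsexp_grad is a hypothetical name:

```julia
# Hypothetical sketch: hand-written gradient of the sum over columns of logsumexp.
# The derivative with respect to the shift m is zero, so m can be treated as a
# constant and the gradient is the column-wise softmax.
function lsexp_grad(A::AbstractMatrix{<:AbstractFloat})
    M = maximum(A, dims=1)           # stability shift only; no gradient flows through it
    E = exp.(A .- M)
    return E ./ sum(E, dims=1)       # ∂ sum(lse(A)) / ∂A; each column sums to 1
end

# Naive reference, fine for small, well-scaled inputs
lse_naive(A) = sum(log.(sum(exp.(A), dims=1)))
```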

julia> @btime Zygote.gradient(sum∘lsexp_mat1, $A);
  65.850 ms (3003130 allocations: 137.61 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat3, $A);
  8.496 ms (788 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat5, $A);
  7.639 ms (778 allocations: 38.28 MiB)

Trying different values in the threads arguments:

julia> @btime lsexp_mat_25_000_threads($A);
  1.669 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_50_000_threads($A);
  1.681 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_100_000_threads($A);
  1.679 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_200_000_threads($A);
  1.792 ms (160 allocations: 7.67 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_25_000_threads, $A);
  7.700 ms (779 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_50_000_threads, $A);
  7.689 ms (779 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_100_000_threads, $A);
  7.699 ms (780 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_200_000_threads, $A);
  7.902 ms (424 allocations: 38.25 MiB)

It’s definitely an improvement, but lags well behind the 18 threads. With avx=false:

julia> @btime lsexp_mat_25_000_threads_noavx($A);
  2.124 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_50_000_threads_noavx($A);
  2.114 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_100_000_threads_noavx($A);
  2.123 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_200_000_threads_noavx($A);
  2.602 ms (161 allocations: 7.67 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_25_000_threads_noavx, $A);
  8.451 ms (784 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_50_000_threads_noavx, $A);
  8.450 ms (782 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_100_000_threads_noavx, $A);
  8.502 ms (785 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_200_000_threads_noavx, $A);
  9.605 ms (429 allocations: 38.25 MiB)
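If, as the names suggest, the threads values above set a minimum amount of work before a loop is split across threads, the idea can be sketched in plain Base Julia. foreach_col and its threshold keyword are hypothetical. They are not the API of whichever package produced the timings above, and the default of 50_000 is arbitrary:

```julia
using Base.Threads: @threads

# Hypothetical helper: call f(j) for each column index j of A, splitting the
# columns across threads only when A has at least `threshold` elements.
# Below the threshold the loop runs serially, which avoids task-spawning overhead.
function foreach_col(f, A::AbstractMatrix; threshold::Integer = 50_000)
    if length(A) >= threshold
        @threads for j in axes(A, 2)
            f(j)
        end
    else
        for j in axes(A, 2)
            f(j)
        end
    end
    return nothing
end
```

For example, foreach_col(j -> (out[j] = sum(view(A, :, j))), A; threshold = 10^3) fills a preallocated out with column sums. It runs serially for a 100×8 matrix and threaded for anything with at least 1000 elements.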

Scaling is a lot better than it looks, because the serial portions take a long time. For example, around half of the total time is spent in maximum:

julia> @btime maximum($A, dims=1);
  903.528 μs (3 allocations: 8.00 KiB)

This could of course also be optimized by making it threaded and/or SIMD. We can try a SIMD version via LoopVectorization:

julia> @btime(vreduce(max, $A, dims=1)) == maximum(A, dims=1)
  333.883 μs (1 allocation: 7.94 KiB)
true

But I didn’t try this within lsexp_mat, because vreduce lacks Zygote support.
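Because vreduce couldn't be used there, a Base-only alternative for the max step could be a threaded column maximum with an @simd inner loop. colmax_threaded is a hypothetical name, and this version has not been timed against the 903 μs and 334 μs figures above:

```julia
using Base.Threads: @threads

# Hypothetical Base-only alternative to vreduce(max, A, dims=1): column-wise
# maximum with columns split across threads and an @simd inner loop.
# @simd only permits reordering; whether max actually vectorizes is up to the compiler.
function colmax_threaded(A::Matrix{T}) where {T<:Real}
    m, n = size(A)
    m > 0 || throw(ArgumentError("colmax_threaded needs non-empty columns"))
    out = Matrix{T}(undef, 1, n)            # 1×n, same shape as maximum(A, dims=1)
    @threads for j in 1:n
        acc = A[1, j]
        @simd for i in 2:m
            @inbounds acc = max(acc, A[i, j])
        end
        out[1, j] = acc
    end
    return out
end
```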

So the problem is particular to trivial dimensions which have stride 1, regardless of the other strides involved?

Yes. The first two examples should work, while the latter two will be broken.
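The examples referred to aren't reproduced in this post. To see which case a given array falls into, a small Base-only helper (hypothetical name trivial_unit_stride_dims) can list the dimensions that are trivial (size 1) and also have stride 1. For a plain Array, the first dimension always has stride 1, so a leading size-1 dimension qualifies. A later size-1 dimension only qualifies if every earlier dimension also has size 1:

```julia
# Hypothetical helper: indices of dimensions that have size 1 and stride 1.
function trivial_unit_stride_dims(A::AbstractArray)
    return [d for d in 1:ndims(A) if size(A, d) == 1 && stride(A, d) == 1]
end

trivial_unit_stride_dims(rand(1, 10))                 # [1]: strides are (1, 1)
trivial_unit_stride_dims(rand(10, 1))                 # Int[]: strides are (1, 10)
trivial_unit_stride_dims(view(rand(5, 10), 2:2, :))   # [1]: strides are (1, 5)
```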
