Speeding up my logsumexp function

I wrote a logsumexp function based on the code here; it takes a matrix as input and applies logsumexp along a dimension in a numerically stable way.

function lsexp_mat(mat, zero1_mat; dims=1)

    # To use the function, zero1_mat has to be created outside @model
    # with the following code:
    #   zero1_mat = zeros(size(input_matrix))
    #   zero1_mat[end, :] = zero1_mat[end, :] .+ 1

    sorted_mat = sort(mat, dims=dims)
    max_ = reshape(sorted_mat[end, :], 1, :)
    exp_mat = exp.(sorted_mat .- max_)
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    log1p.(sum_exp_) .+ max_
end
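For reference, calling it as described in the comment above (input_matrix and the sizes here are just placeholders):

# Example call, following the comment in the function; `input_matrix` is only a placeholder.
input_matrix = randn(5, 3)
zero1_mat = zeros(size(input_matrix))
zero1_mat[end, :] = zero1_mat[end, :] .+ 1   # mutation happens outside the model
lsexp_mat(input_matrix, zero1_mat)           # 1×3 row of column-wise logsumexp values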

I’m using the sort function to make it easy to subtract 1 from the max element along the given dimension, but sorting is O(n log n). I also wrote it this way because I don’t want any variables in my code to mutate, since I want to use this code as part of my Turing model with the Zygote backend.

Is there a way to create a mask matrix with 1 at all argmax indices and 0 elsewhere, without mutation?

This is the function from the discussion I’m trying to implement for a matrix.

function logsumexp!(p::WeightedParticles)
    N = length(p)
    w = p.logweights
    offset, maxind = findmax(w)
    w .= exp.(w .- offset)
    Σ = sum_all_but(w,maxind) # Σ = ∑wₑ-1
    log1p(Σ) + offset, Σ+1
end

function sum_all_but(w,i)
    w[i] -= 1
    s = sum(w)
    w[i] += 1
    s
end

Use maximum to find the maximum:

function lsexp_mat2(mat, zero1_mat; dims=1)

    # To use the function zero1_mat has to be created outside @model
    # with the following code
    # zero1_mat = zeros(size(input_matrix))
    # zero1_mat[end, :] = zero1_mat[end, :] .+ 1

    max_ = maximum(mat, dims=dims)
    exp_mat = exp.(mat .- max_)
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    log1p.(sum_exp_) .+ max_
end

I have to subtract 1 from each columnwise max value and use log1p for the function to be numerically stable. By sorting I know all the columnwise max values are in the last row, so I can easily subtract 1 from them. The change you made makes the function underflow.

Check it for mat = [1e-20 1e-20; log(1e-20) log(1e-20)].
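To spell out the failure mode with this matrix: the plain log∘sum∘exp rounds 1 + 1e-20 down to 1.0 and returns just max_, while the subtract-one/log1p form keeps the tiny remainder:

mat = [1e-20 1e-20; log(1e-20) log(1e-20)]

# naive logsumexp: 1 + 1e-20 rounds to 1.0, so the second row's contribution is lost
max_ = maximum(mat, dims=1)
log.(sum(exp.(mat .- max_), dims=1)) .+ max_   # ≈ [1.0e-20 1.0e-20]

# lsexp_mat above: subtract 1 at the max and use log1p, keeping the tiny remainder
zero1_mat = zeros(size(mat)); zero1_mat[end, :] .= 1
lsexp_mat(mat, zero1_mat)                      # ≈ [2.0e-20 2.0e-20]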


Oops, didn’t read the motivation carefully; my apologies.

remainder(dim::Int, ::NTuple{2}) = 3 - dim
function remainder(dim::Int, ::NTuple{N}) where {N}
    ntuple(n -> n < dim ? n : n + 1, Val(N-1))
end
using SparseArrays
function lsexp_mat(mat::AbstractMatrix; dims=1)
    @assert dims == 1
    remdim = remainder(dims,size(mat))
    maxinds_ = map(argmax, eachslice(mat, dims=remdim))
    max_ = getindex.(eachslice(mat, dims=remdim), maxinds_)
    m, n = size(mat)
    zero1_mat = sparse(maxinds_, axes(mat,2), ones(n), m, n)
    exp_mat = exp.(mat .- max_') - zero1_mat # TODO: generalize me
    log1p.(sum(exp_mat, dims=dims)) .+ max_' # TODO: generalize me
end

This isn’t a great solution, but I wanted to offer something that addresses your actual question:

Is there a way to create a mask matrix with 1 at all argmax indices and 0 elsewhere, without mutation?

by using sparse. You could almost certainly make that more efficient.

julia> mat = [1e-20 1e-20; log(1e-20) log(1e-20)];

julia> zero1_mat = zeros(size(mat)); zero1_mat[end, :] = zero1_mat[end, :] .+ 1;

julia> lsexp_mat(mat) # new
1×2 Array{Float64,2}:
 2.0e-20  2.0e-20

julia> lsexp_mat(mat, zero1_mat) # original
1×2 Array{Float64,2}:
 2.0e-20  2.0e-20

Worth pointing out that Zygote doesn’t currently support the SparseMatrixCSC constructor, so this answer doesn’t really help you.
It also doesn’t support sort. So you’ll have to do a little work defining your own adjoints. I’d recommend defining the rule for lsexp_mat directly (which would allow you to mutate internally).
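For instance, a hand-written Zygote rule might look roughly like this sketch (assuming zero1_mat marks exactly one column-wise maximum as intended, so the forward pass really computes logsumexp and its gradient is the softmax):

using Zygote

# Hedged sketch of a direct rule for lsexp_mat (dims=1 usage): since
# out = logsumexp(mat) column-wise, the pullback is Δ times the softmax.
Zygote.@adjoint function lsexp_mat(mat, zero1_mat)
    out = lsexp_mat(mat, zero1_mat)
    soft = exp.(mat .- out)            # softmax along dims=1
    out, Δ -> (soft .* Δ, nothing)     # no gradient for the mask matrix
end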

I found that zero1_mat = mat .== maximum(mat, dims=1) creates what I want.

Thank you for your solution, but I want to keep things as simple as possible, and Zygote may not support SparseArrays.

My new code is

function lsexp_mat(mat; dims=1)
    max_ = maximum(mat, dims=1)
    zero1_mat = mat .== max_
    exp_mat = exp.(mat .- max_)
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    log1p.(sum_exp_) .+ max_
end
n = 10000
A = rand(n,n) 

@benchmark lsexp_mat(A, dims=1)

@benchmark mapslices(lsexp_vector, A; dims=1)
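(lsexp_vector isn’t shown in the thread; presumably it is a single-column version in the spirit of the quoted logsumexp!, something like this hypothetical sketch.)

# Hypothetical reconstruction of lsexp_vector, modelled on the quoted logsumexp!
function lsexp_vector(w::AbstractVector)
    offset = maximum(w)
    s = sum(exp(x - offset) for x in w) - 1   # Σ exp(wᵢ - offset), minus the max term's 1
    log1p(s) + offset
end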

lsexp_mat benchmark

BenchmarkTools.Trial: 
  memory estimate:  1.50 GiB
  allocs estimate:  17
  --------------
  minimum time:     1.750 s (0.30% GC)
  median time:      1.784 s (3.12% GC)
  mean time:        1.829 s (5.38% GC)
  maximum time:     1.953 s (11.99% GC)
  --------------
  samples:          3
  evals/sample:     1

lsexp_vector benchmark

BenchmarkTools.Trial: 
  memory estimate:  2.57 MiB
  allocs estimate:  108509
  --------------
  minimum time:     1.127 s (0.00% GC)
  median time:      1.129 s (0.00% GC)
  mean time:        1.133 s (0.00% GC)
  maximum time:     1.143 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

We can see that lsexp_mat is slower than lsexp_vector and also uses orders of magnitude more memory. Why is this happening, and how can I make it better and faster? Also, the number of allocations of lsexp_vector is very high compared to lsexp_mat, yet lsexp_vector is faster. I expected lsexp_mat to be faster since everything is vectorised.

The problem with lsexp_mat is that it allocates a lot of unnecessary intermediate matrices. If you rewrite it as

function lsexp_mat1(mat; dims=1)
    max_ = maximum(mat, dims=1)
    sum_exp_ = sum(exp.(mat .- max_) .- (mat .== max_), dims=dims)
    log1p.(sum_exp_) .+ max_
end

it’s about as fast as the vector version.


Here are a few faster variants, using a package of mine and one of @Elrod’s. If you want this to work with Zygote, then the time spent working out the gradient will tend to dominate.

function lsexp_mat1(mat; dims=1)
    max_ = maximum(mat, dims=1)
    zero1_mat = safeeq(mat, max_) # working around a Zygote bug, today?
    exp_mat = exp.(mat .- max_)
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    log1p.(sum_exp_) .+ max_
end
safeeq(mat, max_) = (mat .== max_)

function lsexp_mat2(mat; dims=1) # less memory but not really faster?
    max_ = maximum(mat, dims=1)
    exp_mat = exp.(mat .- max_) .- (mat .== max_) # fuse this broadcast, @Oscar_Smith beat me to it!
    sum_exp_ = sum(exp_mat, dims=dims)
    sum_exp_ .= log1p.(sum_exp_) .+ max_ # re-use this array?
end

using Tullio # ] add Tullio#master -- I just fixed a bug about == & gradients

function lsexp_mat3(mat) # not generic over dims, but differentiable
    max_ = maximum(mat, dims=1)
    @tullio exp_mat[i,j] := exp(mat[i,j] - max_[1,j]) - (mat[i,j] == max_[1,j]) avx=false # grad=Dual # fixed on master
    sum_exp_ = sum(exp_mat, dims=1)
    @tullio out[i,j] := log1p(sum_exp_[i,j]) + max_[i,j] avx=false
end

using LoopVectorization

# function lsexp_mat4(mat; dims=1) # @avx broadcasting, is having a bad day
#     max_ = maximum(mat, dims=1)
#     # zero1_mat = (mat .== max_)
#     exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) # has lots of NaN & Inf in it?
#     sum_exp_ = sum(exp_mat, dims=dims)
#     @avx sum_exp_ .= log1p.(sum_exp_) .+ max_ # mostly NaN
# end

function lsexp_mat5(mat) # also using @avx
    max_ = maximum(mat, dims=1)
    @tullio exp_mat[i,j] := exp(mat[i,j] - max_[1,j]) - (mat[i,j] == max_[1,j])
    sum_exp_ = sum(exp_mat, dims=1)
    @tullio out[i,j] := log1p(sum_exp_[i,j]) + max_[i,j]
end

n = 1_000; A = rand(n,n);
lsexp_mat(A) ≈ lsexp_mat1(A) ≈ lsexp_mat2(A)
# lsexp_mat(A) ≈ lsexp_mat4(A) # false? 
lsexp_mat(A) ≈ lsexp_mat3(A) ≈ lsexp_mat5(A)

@btime lsexp_mat($A)  # 10.164 ms (13 allocations: 15.41 MiB)
@btime lsexp_mat2($A) # 10.631 ms (6 allocations: 7.65 MiB)
@btime lsexp_mat3($A) #  3.188 ms (78 allocations: 7.66 MiB)
# @btime lsexp_mat4($A) #  3.494 ms (14 allocations: 7.65 MiB)
@btime lsexp_mat5($A) #  2.069 ms (76 allocations: 7.66 MiB)

using Tracker, Zygote #, ForwardDiff
Zygote.@nograd safeeq

gA = Tracker.gradient(sum∘lsexp_mat, A)[1];
Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA

@btime Zygote.gradient(sum∘lsexp_mat1, $A); # 69.199 ms (3003130 allocations: 137.61 MiB)
@btime Zygote.gradient(sum∘lsexp_mat3, $A); # 12.547 ms (253 allocations: 38.23 MiB)
@btime Zygote.gradient(sum∘lsexp_mat5, $A); # 10.309 ms (248 allocations: 38.23 MiB)

The awkward thing about broadcasting is that type information and syntax combined still do not fully specify the behavior: if isone(size(A, n)), then axis n is broadcasted.
Currently LoopVectorization handles broadcasting by setting strides corresponding to dimensions of size 1 to 0. The linear index A[n...] equals dot(n, strides(A)), so by setting the stride to 0, it’ll ignore indexes on that axis and “broadcast” along it.
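As a toy illustration of that stride-0 trick (plain Julia indexing arithmetic, not LoopVectorization’s actual code):

# With the stride of the size-1 axis set to 0, the same column is read for every j.
A = reshape(collect(1.0:3.0), 3, 1)          # 3×1 column
s = (1, 0)                                   # pretend strides: dim 2 is "broadcast"
offset(i, j) = 1 + (i - 1) * s[1] + (j - 1) * s[2]
[A[offset(i, j)] for i in 1:3, j in 1:4]     # 3×4 result, the column repeated 4 times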

However, this doesn’t work well when loading data along a contiguous axis.
For this to be efficient, we need to use vmovup* instructions. These load contiguous elements, so the stride = 0 trick won’t work.
It would work with gather instructions, available in AVX2 and AVX512, but they’re many times slower. While better than nothing, they cripple performance on a comparative basis.
This means to be efficient, we have to use the contiguous loads (and stores, but stores aren’t a problem when broadcasting).

Obviously Julia+LLVM don’t have a problem with this. I’m guessing it uses a few runtime checks to switch between different versions of the loop. I should probably follow that approach.

But for now, a workaround (incompatible with the dims argument) is to make it known at compile time that isone(size(max_,1)):

function lsexp_mat4(mat; dims=1) # @avx broadcasting, with the dims=1 workaround
    @assert dims == 1
    max_ = vec(maximum(mat, dims=1))' # requires dims=1
    # zero1_mat = (mat .== max_)
    exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) # should now work
    sum_exp_ = sum(exp_mat, dims=dims)
    @avx sum_exp_ .= log1p.(sum_exp_) .+ max_ # re-uses sum_exp_
end

So now I get:

julia> n = 1_000; A = rand(n,n);

julia> lsexp_mat(A) ≈ lsexp_mat1(A) ≈ lsexp_mat2(A)
true

julia> lsexp_mat(A) ≈ lsexp_mat4(A)
true

julia> lsexp_mat(A) ≈ lsexp_mat3(A) ≈ lsexp_mat5(A)
true

julia> @btime lsexp_mat($A);
  10.844 ms (13 allocations: 15.41 MiB)

julia> @btime lsexp_mat2($A);
  10.726 ms (6 allocations: 7.65 MiB)

julia> @btime lsexp_mat3($A);
  8.750 ms (21 allocations: 7.65 MiB)

julia> @btime lsexp_mat4($A);
  2.402 ms (17 allocations: 7.65 MiB)

julia> @btime lsexp_mat5($A);
  2.557 ms (19 allocations: 7.65 MiB)

julia> using Tracker, Zygote #, ForwardDiff

julia> Zygote.@nograd safeeq

julia> gA = Tracker.gradient(sum∘lsexp_mat, A)[1];

julia> Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA
true

julia> @btime Zygote.gradient(sum∘lsexp_mat1, $A);
  67.391 ms (3003130 allocations: 137.61 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat3, $A);
  22.924 ms (133 allocations: 38.22 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat5, $A);
  10.519 ms (128 allocations: 38.22 MiB)

Overall, our performance numbers are really similar except for lsexp_mat3, where my computer is much slower.


That’s interesting, normally your computer smokes mine. Perhaps it’s because lsexp_mat3 is launching too many threads? It has an extremely crude calculation which here decides that up to 1000^2 / 3636 = 275 threads would be worthwhile. (And the same in the avx case, but I guess compensated.) Keyword threads=200_000 will stop it at about 5 threads (meaning 4 or 8).
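For concreteness, passing that keyword through @tullio might look like this sketch (the lsexp_mat_*_threads functions benchmarked below are presumably variants of this form; the name lsexp_mat5_t is just a placeholder):

# Sketch: same as lsexp_mat5 above, but capping Tullio's threading heuristic.
function lsexp_mat5_t(mat)
    max_ = maximum(mat, dims=1)
    @tullio exp_mat[i,j] := exp(mat[i,j] - max_[1,j]) - (mat[i,j] == max_[1,j]) threads=200_000
    sum_exp_ = sum(exp_mat, dims=1)
    @tullio out[i,j] := log1p(sum_exp_[i,j]) + max_[i,j] threads=200_000
end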

Thanks for the explanation about broadcasting. So the problem is particular to trivial dimensions which have stride 1, regardless of the other strides involved?

@avx ones(10,10)' .* rand(10,1)  # no problem
@avx ones(10,10)' .* rand(10,1)' # no problem
@avx ones(10,10)' .* rand(1,10)  # problem
@avx ones(10,10)' .* rand(1,10)' # problem

Thanks, that’s why I thought it was curious/worth pointing out. And that seems to be it.
In particular, Threads.nthreads() was 1. Now using 18 threads:

julia> @btime lsexp_mat($A);
  11.372 ms (13 allocations: 15.41 MiB)

julia> @btime lsexp_mat2($A);
  11.212 ms (6 allocations: 7.65 MiB)

julia> @btime lsexp_mat3($A);
  2.105 ms (332 allocations: 7.68 MiB)

julia> @btime lsexp_mat4($A);
  2.385 ms (17 allocations: 7.65 MiB)

julia> @btime lsexp_mat5($A);
  1.613 ms (328 allocations: 7.68 MiB)

julia> using Tracker, Zygote #, ForwardDiff

julia> Zygote.@nograd safeeq

julia> gA = Tracker.gradient(sum∘lsexp_mat, A)[1];

julia> Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA
true

julia> @btime Zygote.gradient(sum∘lsexp_mat1, $A);
  65.850 ms (3003130 allocations: 137.61 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat3, $A);
  8.496 ms (788 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat5, $A);
  7.639 ms (778 allocations: 38.28 MiB)

Trying different values for the threads argument:

julia> @btime lsexp_mat_25_000_threads($A);
  1.669 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_50_000_threads($A);
  1.681 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_100_000_threads($A);
  1.679 ms (328 allocations: 7.68 MiB)

julia> @btime lsexp_mat_200_000_threads($A);
  1.792 ms (160 allocations: 7.67 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_25_000_threads, $A);
  7.700 ms (779 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_50_000_threads, $A);
  7.689 ms (779 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_100_000_threads, $A);
  7.699 ms (780 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_200_000_threads, $A);
  7.902 ms (424 allocations: 38.25 MiB)

It’s definitely an improvement, but lags well behind the 18 threads. With avx=false:

julia> @btime lsexp_mat_25_000_threads_noavx($A);
  2.124 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_50_000_threads_noavx($A);
  2.114 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_100_000_threads_noavx($A);
  2.123 ms (330 allocations: 7.68 MiB)

julia> @btime lsexp_mat_200_000_threads_noavx($A);
  2.602 ms (161 allocations: 7.67 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_25_000_threads_noavx, $A);
  8.451 ms (784 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_50_000_threads_noavx, $A);
  8.450 ms (782 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_100_000_threads_noavx, $A);
  8.502 ms (785 allocations: 38.28 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat_200_000_threads_noavx, $A);
  9.605 ms (429 allocations: 38.25 MiB)

Scaling is a lot better than it looks, because the serial portions take a long time, e.g. around half of all the time is spent in maximum:

julia> @btime maximum($A, dims=1);
  903.528 μs (3 allocations: 8.00 KiB)

This could of course also be optimized by making it threaded and/or SIMD. We can try a SIMD version via LoopVectorization:

julia> @btime(vreduce(max, $A, dims=1)) == maximum(A, dims=1)
  333.883 μs (1 allocation: 7.94 KiB)
true

But I didn’t try this within lsexp_mat due to the lack of Zygote support.

So the problem is particular to trivial dimensions which have stride 1, regardless of the other strides involved?

Yes. The first two examples should work, while the latter two will be broken.


I think the core problem here is https://github.com/JuliaDiff/ReverseDiff.jl/issues/135; I should really change the readme not to promise things which don’t work, sorry. (It worked in an earlier incarnation.) That said, I’m a little surprised by those error messages; there may be other details to fix beyond that.

Is the gensym problem related to this error too? I thought ReverseDiff doesn’t support vreduce. Is there any way to speed up maximum?

It would probably be easy to make vreduce work, by defining a gradient as in this PR using perhaps Tracker’s as the model.
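A rough sketch of what such a definition might look like, assuming Tracker’s track/@grad machinery (fastmax1 is just a placeholder name, and ties are handled crudely):

using LoopVectorization, Tracker
using Tracker: @grad, track, TrackedArray, data

fastmax1(A) = vreduce(max, A, dims=1)            # SIMD column-wise maximum
fastmax1(A::TrackedArray) = track(fastmax1, A)   # hook into Tracker

@grad function fastmax1(A)
    Av = data(A)
    m = vreduce(max, Av, dims=1)
    # pullback: send the incoming cotangent to the max position(s) of each column
    m, Δ -> ((Av .== m) .* Δ,)
end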

To make @tullio work you’d have to solve that issue; the easy part is Base.gensym(ex::Expr) = gensym(string(ex)), but I didn’t see how to make @grad deal with callable structs. (The macro wraps the forward & backward functions in a struct Eval so that there’s somewhere to attach gradient definitions.)

Otherwise you may need to go back to write things as broadcasting instead. Or tell Turing to use Zygote not ReverseDiff.
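The backend switch is a one-liner (assuming the Turing versions current at the time of this thread):

using Turing, Zygote
Turing.setadbackend(:zygote)   # instead of :reversediff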


I’m trying to use Tullio with sparse arrays.

using Tullio, LoopVectorization, SparseArrays

function logsumexp(mat; dims=1)
    max_ = maximum(mat, dims=1)
    @tullio exp_mat[i,j] := exp(mat[i,j] - max_[1,j]) - (mat[i,j] == max_[1,j])
    sum_exp_ = sum(exp_mat, dims=dims)
    @tullio out[i,j] := log1p(sum_exp_[i,j]) + max_[i,j]
end

A = sprand(1500, 10000, 0.02);
logsumexp(A)

When I run the above code I get signal (6): Aborted

This is the full stack trace.

double free or corruption (!prev)

signal (6): Aborted
in expression starting at REPL[4]:1
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7fe8d09d8906)
unknown function (ip: 0x7fe8d09df979)
cfree at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
jl_realloc_aligned at /buildworker/worker/package_linux64/build/src/gc.c:249 [inlined]
gc_managed_realloc_ at /buildworker/worker/package_linux64/build/src/gc.c:3369 [inlined]
jl_gc_managed_realloc at /buildworker/worker/package_linux64/build/src/gc.c:3386
array_resize_buffer at /buildworker/worker/package_linux64/build/src/array.c:660
jl_array_grow_at_beg at /buildworker/worker/package_linux64/build/src/array.c:785 [inlined]
jl_array_grow_at at /buildworker/worker/package_linux64/build/src/array.c:929
_growat! at ./array.jl:873 [inlined]
insert! at ./array.jl:1215 [inlined]
_insert! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/SparseArrays/src/sparsematrix.jl:2485
_setindex_scalar! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/SparseArrays/src/sparsematrix.jl:2473
setindex! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/SparseArrays/src/sparsematrix.jl:2447 [inlined]
𝒜𝒸𝓉! at /home/s/.julia/packages/Tullio/HGzih/src/macro.jl:797
unknown function (ip: 0x7fe87f26475e)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2145 [inlined]
jl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2323
block_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:104
block_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:108
block_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:107
block_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:107
block_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:101 [inlined]
thread_halves at /home/s/.julia/packages/Tullio/HGzih/src/threads.jl:91
#184 at ./threadingconstructs.jl:126
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2145 [inlined]
jl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2323
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1700 [inlined]
start_task at /buildworker/worker/package_linux64/build/src/task.c:687
unknown function (ip: (nil))
Allocations: 55469269 (Pool: 55453222; Big: 16047); GC: 56
Aborted (core dumped)

The code runs fine for 100x100 sparse arrays and 1500x10000 normal arrays. My machine has 45 GB free RAM. What is the problem here?

That’s a pretty inelegant failure! I get something similar.

For smaller sizes it does run. But I wouldn’t expect it to be efficient; it’s completely unaware of sparsity and just works through every element. It shouldn’t be hard to write a fast logsumexp(::SparseMatrixCSC) though.

julia> Tullio.storage_type(sprand(10,10,0.1)) # not <: Array{<:BlasFloat}, hence will not use LoopVectorization
SparseMatrixCSC{Float64,Int64}

julia> A[3,4] # but you can index, so fallback seems OK?
0.0

julia> logsumexp(A)
julia(43313,0x700003cdf000) malloc: *** error for object 0x18a8f9000: pointer being freed was not allocated
julia(43313,0x700003cdf000) malloc: *** set a breakpoint in malloc_error_break to debug

signal (6): Abort trap: 6
in expression starting at REPL[73]:1
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 1296329819 (Pool: 1296254828; Big: 74991); GC: 596
Abort trap: 6
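To flesh out the "fast logsumexp(::SparseMatrixCSC)" remark, here is a hedged sketch (not from the thread) that only walks the stored entries of each column; the implicit zeros are folded in as (m - nnz) copies of exp(-max). It uses plain log rather than the subtract-one/log1p form above.

using SparseArrays

# Sketch of a sparsity-aware column-wise logsumexp; logsumexp_sparse is a placeholder name.
function logsumexp_sparse(A::SparseMatrixCSC; dims=1)
    @assert dims == 1
    m, n = size(A)
    vals = nonzeros(A)
    out = zeros(float(eltype(A)), 1, n)
    for j in 1:n
        vcol = view(vals, nzrange(A, j))
        nstored = length(vcol)
        mx = isempty(vcol) ? zero(eltype(A)) : maximum(vcol)
        nstored < m && (mx = max(mx, zero(eltype(A))))   # implicit zeros also compete for the max
        s = (m - nstored) * exp(-mx)                     # contribution of the implicit zeros
        for v in vcol
            s += exp(v - mx)
        end
        out[1, j] = log(s) + mx
    end
    out
end

# logsumexp_sparse(A) ≈ logsumexp(Matrix(A))   # sanity check against the dense version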

I wondered whether ReverseDiff had been quietly fixed, but it seems not; Use with ReverseDiff · Issue #21 · mcabbott/Tullio.jl · GitHub has a simple test. So I think you ought to get errors if Turing is trying to do this at all. Is there a chance Turing isn’t using AD at all?

So lsexp1, which is purely broadcasting, is giving you this corruption error? That might be worth isolating. Is A still sparse here?

I mixed up the function names; I’ve fixed them now.
While importing Tullio with ReverseDiff I get the following warning, but generally the sampling finishes and the estimates are accurate.

┌ Warning: Error requiring ReverseDiff from Tullio:
│ LoadError: LoadError: MethodError: no method matching gensym(::Expr)
│ Closest candidates are:
│   gensym(!Matched::Symbol) at expr.jl:15
│   gensym(!Matched::String) at expr.jl:12
│   gensym() at expr.jl:10
│   ...
│ Stacktrace:
│  [1] @grad(::LineNumberNode, ::Module, ::Any) at /home/swamy/.julia/packages/ReverseDiff/Thhqg/src/macros.jl:182
│  [2] include(::Module, ::String) at ./Base.jl:377
│  [3] include(::String) at /home/s/.julia/packages/Tullio/HGzih/src/Tullio.jl:1
│  [4] top-level scope at REPL[11]:1
│  [5] eval at ./boot.jl:331 [inlined]
│  [6] eval at /home/s/.julia/packages/Tullio/HGzih/src/Tullio.jl:1 [inlined]
│  [7] (::Tullio.var"#150#154")() at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:85
│  [8] err(::Any, ::Module, ::String) at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:42
│  [9] (::Tullio.var"#149#153")() at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:84
│  [10] withpath(::Any, ::String) at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:32
│  [11] (::Tullio.var"#148#152")() at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:83
│  [12] listenpkg(::Any, ::Base.PkgId) at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:15
│  [13] macro expansion at /home/s/.julia/packages/Requires/qy6zC/src/require.jl:81 [inlined]
│  [14] (::Tullio.var"#147#151")() at /home/s/.julia/packages/Requires/qy6zC/src/init.jl:11
│  [15] __init__() at /home/s/.julia/packages/Requires/qy6zC/src/init.jl:18
│  [16] _include_from_serialized(::String, ::Array{Any,1}) at ./loading.jl:697
│  [17] _require_search_from_serialized(::Base.PkgId, ::String) at ./loading.jl:781
│  [18] _require(::Base.PkgId) at ./loading.jl:1006
│  [19] require(::Base.PkgId) at ./loading.jl:927
│  [20] require(::Module, ::Symbol) at ./loading.jl:922
│  [21] eval(::Module, ::Any) at ./boot.jl:331
│  [22] eval_user_input(::Any, ::REPL.REPLBackend) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
│  [23] macro expansion at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/REPL/src/REPL.jl:118 [inlined]
│  [24] (::REPL.var"#26#27"{REPL.REPLBackend})() at ./task.jl:358
│ in expression starting at /home/s/.julia/packages/Tullio/HGzih/src/grad/reverse.jl:8
│ in expression starting at /home/s/.julia/packages/Tullio/HGzih/src/grad/reverse.jl:8
└ @ Requires ~/.julia/packages/Requires/qy6zC/src/require.jl:44
  1. The NUTS sampler needs to get gradients from AD, so it’s somehow getting gradients, or else it would output a “gradient undefined” error or something. I’m sure ReverseDiff is being used, because if I use ForwardDiff instead of ReverseDiff it takes much longer to finish sampling.

  2. The function using Tullio gives the corruption error.

  3. A is not sparse; it’s a dense array.

I’ve written two versions of logsumexp: one with a maximum function using @avx and one with the normal maximum.

This code defines a gradient for the avx_max function using the @grad macro in ReverseDiff:

using LoopVectorization, ReverseDiff
using ReverseDiff: @grad, TrackedArray


function avx_max(A; j=1)
    max_ = zeros(size(A, 2))
    @avx for i ∈ 1:size(A, 2)
        j = 1
        max_el = A[j, i]
        for j ∈ 1:size(A, 1)
            max_el = max(max_el, A[j, i])
        end
        max_[i] = max_el
    end
    reshape(max_, 1, :)
end

fast_max(x) = avx_max(x)
fast_max(x::TrackedArray) = ReverseDiff.track(fast_max, x) 
@grad function fast_max(x::AbstractArray)
    xv = ReverseDiff.value(x)
    T = Array{Float64, 2}
    max_ = avx_max(xv)
    max_ret = T(xv .== max_)
    max_, Δ -> (max_ret, )
end

logsumexp with and without @avx:

function logsumexp_avx(mat; dims=1) 
    @assert dims == 1
    max_ = vec(fast_max(mat, dims=1))' # requires dims=1
    exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) 
    sum_exp_ = sum(exp_mat, dims=dims)
    @avx sum_exp_ .= log1p.(sum_exp_) .+ max_
end

function logsumexp_no_avx(mat; dims=1) 
    max_ = maximum(mat, dims=1)
    exp_mat = exp.(mat .- max_) .- (mat .== max_)
    sum_exp_ = sum(exp_mat, dims=dims)
    sum_exp_ .= log1p.(sum_exp_) .+ max_
end
x = rand(3,3);
logsumexp_avx(x) ≈ logsumexp_no_avx(x) #true
ReverseDiff.gradient(sum∘logsumexp_avx, x) ≈ ReverseDiff.gradient(sum∘logsumexp_no_avx, x) # false

I don’t know why the gradients of the two functions don’t match. The gradients of avx_max and Julia’s maximum do match.

The gradient of logsumexp_avx has 1 added to all the columnwise maximum positions.

For example

Gradient of logsumexp_no_avx

 0.204813  0.384189  0.500139
 0.420175  0.21128   0.242529
 0.375012  0.404531  0.257332

Gradient of logsumexp_avx

 0.204813  0.384189  1.50014
 1.42017   0.21128   0.242529
 0.375012  1.40453   0.257332

Any help regarding this please?

I had to define

function logsumexp_avx(mat; dims=1) 
    @assert dims == 1
    max_ = vec(fast_max(mat))' # requires dims=1
    exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) 
    sum_exp_ = sum(exp_mat, dims=dims)
    @avx sum_exp_ .= log1p.(sum_exp_) .+ max_
end

(i.e., remove the dims = 1 argument from the call to fast_max)

Also, more definitions are needed for the gradient to work. Did you define something like

LoopVectorization.vmaterialize(bc::Base.Broadcast.Broadcasted{<:ReverseDiff.TrackedStyle}, ::Val{_}) where {_} = Base.materialize(bc)

?


My bad, I didn’t verify what I pasted.

No. I didn’t get any errors when I called the gradient of either fast_max or logsumexp_avx.


Surprisingly, if I remove my gradient definition for fast_max, the gradients of logsumexp_avx and logsumexp_no_avx match. I don’t know what’s going on.