Speeding up my logsumexp function

What versions are you on? I get the following error whether or not I define the fast_max gradient:

julia> ReverseDiff.gradient(sum∘logsumexp_avx, x)
ERROR: MethodError: no method matching similar(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof
(exp),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.Track
edReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{Reverse
Diff.TrackedStyle,Nothing,typeof(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,F
loat64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, ::Type{ReverseDiff.TrackedReal{Float64,Float64,Nothing}}, ::Tuple{Base.OneTo{I
nt64},Base.OneTo{Int64}})
Closest candidates are:
  similar(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{N},Axes,F,Args} where Args<:Tuple where F where Axes, ::Type{ElType}, ::Any) where {N, ElType} at broadcast.jl:197
  similar(::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayConflict,Axes,F,Args} where Args<:Tuple where F where Axes, ::Type{ElType}, ::Any) where ElType at broadcast.jl:202
  similar(::Base.Broadcast.Broadcasted, ::Type{T}) where T at broadcast.jl:196
  ...
Stacktrace:
 [1] similar(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{Base.Broadcast.Broad
casted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedAr
ray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof
(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Vector{
Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, #unused#::Type{ReverseDiff.TrackedReal{Float64,Float64,Nothing}})
   @ Base.Broadcast ./broadcast.jl:196
 [2] vmaterialize(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{Base.Broadcast.
Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,Trac
kedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,t
ypeof(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Ve
ctor{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, #unused#::Val{:Main})
   @ LoopVectorization ~/.julia/dev/LoopVectorization/src/broadcast.jl:327
 [3] logsumexp_avx(mat::TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}}; dims::Int64)
   @ Main ./REPL[34]:4
 [4] logsumexp_avx(mat::TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}})
   @ Main ./REPL[34]:2 

Output of Pkg.installed():

  "Tullio"            => v"0.2.0"
  "ForwardDiff"       => v"0.10.12"
  "Memoize"           => v"0.4.3"
  "Distributions"     => v"0.23.4"
  "Atom"              => v"0.12.16"
  "BenchmarkTools"    => v"0.5.0"
  "Optim"             => v"0.22.0"
  "StatsPlots"        => v"0.14.6"
  "TimerOutputs"      => v"0.5.6"
  "Juno"              => v"0.8.2"
  "ReverseDiff"       => v"1.2.0"
  "StatsBase"         => v"0.33.0"
  "Memoization"       => v"0.1.4"
  "Tracker"           => v"0.2.8"
  "IJulia"            => v"1.21.2"
  "LoopVectorization" => v"0.8.15"
  "Flux"              => v"0.8.3"
  "Plots"             => v"1.5.4"
  "SCS"               => v"0.6.6"
  "Noise"             => v"0.2.0"
  "StatsFuns"         => v"0.9.5"
  "MacroTools"        => v"0.5.5"
  "StaticArrays"      => v"0.12.4"
  "XLSX"              => v"0.7.2"
  "Convex"            => v"0.13.4"
  "Zygote"            => v"0.4.22"
  "DataFrames"        => v"0.21.4"
  "NNlib"             => v"0.6.6"
  "Turing"            => v"0.13.0"
  "Setfield"          => v"0.7.0"
  "SliceMap"          => v"0.2.3"

I was on ReverseDiff 1.4; I don’t get an error with ReverseDiff 1.2.
The way broadcasting is handled changed between those versions.

I’ll have to look at this more, but my preliminary guess is that LoopVectorization.check_args fails, so that a fallback loop is used instead of LoopVectorization, making it work.

If that’s the case, then the gradients produced with and without @avx should be the same, right?

Yes.
Skimming through elementwise.jl in ReverseDiff@1.2, it looks like they use ForwardDiff for broadcasting. LoopVectorization.check_args will return false on any AbstractArray{<:ForwardDiff.Dual}.
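
A rough illustration of that point (assuming check_args can be called directly on arrays like this):

using LoopVectorization, ForwardDiff

# Plain Float64 arrays pass the check, so the @avx kernels run;
# arrays of Dual numbers do not, so the plain fallback loop runs instead.
LoopVectorization.check_args(rand(3, 3))                              # expected: true
LoopVectorization.check_args(fill(ForwardDiff.Dual(1.0, 1.0), 3, 3))  # expected: false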

@btime of the gradients of both functions is almost the same. This must be the case then.

Alas, I was trying to speed up logsumexp for a 600x1200 matrix as mentioned above; it seems the gradients are going to take the same time either way.


ReverseDiff.gradient(sum∘avx_max, x)   # without @grad, this outputs:

 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

But using avx_max without @grad inside logsumexp_avx still outputs the correct gradient for logsumexp: a wrong gradient for avx_max itself, yet the right answer overall.

It’ll be a while until I can work on it, and there are probably lots of bugs and definitely a lot of unimplemented derivative rules, but see here for an example of getting a reverse pass out of LoopVectorization directly.

But you could also just use closed-form solutions for the @grads using @avx.
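
For instance, a closed-form @grad might look roughly like the sketch below. This is only a sketch, not the thread’s actual code: it assumes your logsumexp_avx reduces over dims=1 and returns a 1×n row, and it uses the fact that the gradient of logsumexp is the column-wise softmax.

using ReverseDiff, LoopVectorization
using ReverseDiff: @grad, TrackedArray

# route tracked arrays through the custom gradient defined below
logsumexp_avx(mat::TrackedArray) = ReverseDiff.track(logsumexp_avx, mat)

@grad function logsumexp_avx(mat::AbstractMatrix)
    matv = ReverseDiff.value(mat)
    max_ = maximum(matv, dims=1)        # 1×n
    exp_ = @avx exp.(matv .- max_)      # m×n
    sum_ = sum(exp_, dims=1)            # 1×n
    lse  = log.(sum_) .+ max_           # 1×n forward value
    # pullback: exp_ ./ sum_ is the column-wise softmax of matv
    lse, Δ -> ((exp_ ./ sum_) .* Δ,)
end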

As for the problem of incorrect gradients when using fast_max:

julia> ReverseDiff.jacobian(y -> vec(maximum(y, dims=1)), x)
 3×9 Matrix{Float64}:
 1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0

julia> ReverseDiff.jacobian(fast_max, x)
 3×9 Matrix{Float64}:
 1.0  0.0  0.0  1.0  0.0  0.0  0.0  1.0  0.0
 1.0  0.0  0.0  1.0  0.0  0.0  0.0  1.0  0.0
 1.0  0.0  0.0  1.0  0.0  0.0  0.0  1.0  0.0

EDIT:
@vkv
You forgot to propagate the sensitivities…

@grad function fast_max(x::AbstractArray)
    xv = ReverseDiff.value(x)
    max_ = avx_max(xv)
    max_, Δ -> (Δ .* (xv .== max_), )
end
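
A quick sanity check, assuming no column of x has tied maxima (ties make the mask pick up more than one entry):

# with the corrected pullback, the Jacobian should match the one from maximum(y, dims=1) above
ReverseDiff.jacobian(fast_max, x) ≈ ReverseDiff.jacobian(y -> vec(maximum(y, dims=1)), x)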

I think I have this problem with gradients of maximum too.

However, if I’m thinking right, can’t you just skip that gradient here and propagate only through sum_exp_ and exp_mat? For example, if I define lsexp3, which is identical except for having nograd=max_ after each @tullio (a sketch follows the outputs below), then:

julia> mat = rand(3,4);

julia> Zygote.gradient(sum∘lsexp2, mat)[1]
3×4 Array{Float64,2}:
 0.228741  0.440927  0.422353  0.299816
 0.371044  0.30398   0.257484  0.27449
 0.400215  0.255093  0.320162  0.425694

julia> Zygote.gradient(sum∘lsexp3, mat)[1]
3×4 Array{Float64,2}:
 0.228741  0.440927  0.422353  0.299816
 0.371044  0.30398   0.257484  0.27449
 0.400215  0.255093  0.320162  0.425694
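
For reference, a rough sketch of the shape of lsexp3 (the exact @tullio lines here are illustrative guesses, not the real definitions; the key bit is that nograd=max_ stops sensitivities from being propagated through max_):

using Tullio, LoopVectorization

function lsexp3(mat)
    @tullio (max) max_[j] := mat[i,j]                 # column-wise maximum
    @tullio exp_mat[i,j] := exp(mat[i,j] - max_[j])   nograd=max_
    @tullio sum_exp_[j] := exp_mat[i,j]
    @tullio lse[j] := log(sum_exp_[j]) + max_[j]      nograd=max_
end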

You’re right. This works with respect to logsumexp:

@grad function fast_max(x::AbstractArray)
    # inside logsumexp the total derivative through the subtracted max is zero,
    # so this pullback can safely drop the sensitivity
    avx_max(ReverseDiff.value(x)), Δ -> (0.0,)
end

Should probably either bake this into the grad definition for logsumexp, or name the method max_for_logsumexp instead of a more generic name that someone might mistakenly use elsewhere.

I haven’t thought super-hard about its stability etc. This fancy logsumexp with subtractions is (IIRC) better behaved in some weird cases (compared to the one with exp.(mat .- max_), which in turn is better than the naive one), and ideally its gradient would be equally well-mannered.

Would be nice to have a good logsumexp as a library function, part of NNlib, perhaps.

Agreed! I tried to write one with KernelAbstractions once but got myself confused.
Also, the gradient of logsumexp is softmax, right? There’s an issue for that one, with some code, here btw: Native Softmax · Issue #175 · JuliaGPU/CUDA.jl · GitHub
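
As a quick numerical check of that identity (nothing authoritative, just ForwardDiff against NNlib’s softmax):

using ForwardDiff, NNlib, StatsFuns

x = randn(5)
# the gradient of logsumexp is the softmax of the same input
ForwardDiff.gradient(StatsFuns.logsumexp, x) ≈ NNlib.softmax(x)   # expected: true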

Oh right! So perhaps that’s an easy solution for this thread: define one @grad.

The @tullio versions of logsumexp will, I believe, work via KernelAbstractions. How fast, I don’t know.


The faster logsumexp function has a bug: whenever there is more than one maximal element in a column, it returns erroneous values.

This is due to using mat .== max_. Using Julia’s findmax and the indices it returns resolves the bug, but degrades the run-time performance. I will try to write a faster findmax function if possible.
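
For reference, a tie-safe variant might look roughly like this sketch. It assumes the problematic version subtracts the mat .== max_ mask inside a log1p-style formula; here a one-hot mask built from findmax’s indices subtracts exactly one term per column, ties or not:

function logsumexp_findmax(mat::AbstractMatrix)
    max_, idx = findmax(mat; dims=1)      # idx holds one CartesianIndex per column
    onehot = falses(size(mat))
    onehot[idx] .= true                   # exactly one true per column, even with ties
    max_ .+ log1p.(sum(exp.(mat .- max_) .- onehot, dims=1))
end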

I have been working on CTC loss, which has similar issues with underflow, stability, etc.

LogSumExp is a very well understood function and is used a lot in ML & NN. So I would not worry about differentiability: just write the fastest and most stable version, and define the chain rule for the derivative yourself.

The derivative is already defined here: https://github.com/FluxML/NNlib.jl/blob/master/src/softmax.jl

I have a tutorial that helps with writing ChainRules for code that mutates arrays (the “Mutating arrays is not supported” case):
https://github.com/rakeshvar/Zygote-Mutating-Arrays-WorkAround.jl

You can also see this:
https://github.com/FluxML/Zygote.jl/blob/6dadc6c1839ca4811ceeb5446d9df434e9c74362/src/lib/nnlib.jl#L35


Adding this comment in case it’s helpful to anyone. I was able to get significant speedups by using LoopVectorization. The code is specialized to the array shape & dimensions, but it might still be useful. See https://github.com/magerton/FastLogSumExp.jl/ and also https://github.com/JuliaSIMD/LoopVectorization.jl/issues/437