Speeding up my logsumexp function

I’ve written two versions of logsumexp: one that computes the column-wise maximum with an @avx kernel (LoopVectorization) and one that uses Julia’s normal maximum.

This code defines a gradient for the avx_max function using the @grad macro from ReverseDiff:

using LoopVectorization, ReverseDiff
using ReverseDiff: @grad, TrackedArray


function avx_max(A)
    # Column-wise maximum via a LoopVectorization kernel.
    max_ = zeros(size(A, 2))
    @avx for i ∈ 1:size(A, 2)
        max_el = A[1, i]
        for j ∈ 1:size(A, 1)
            max_el = max(max_el, A[j, i])
        end
        max_[i] = max_el
    end
    reshape(max_, 1, :)  # 1×n row, matching maximum(A, dims=1)
end

fast_max(x) = avx_max(x)
fast_max(x::TrackedArray) = ReverseDiff.track(fast_max, x)
@grad function fast_max(x::AbstractArray)
    xv = ReverseDiff.value(x)
    max_ = avx_max(xv)
    # 0/1 indicator of each column's maximum position, used as the adjoint.
    max_ret = Float64.(xv .== max_)
    max_, Δ -> (max_ret,)
end
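
To check the custom rule in isolation, I use this kind of sanity check (a hypothetical sketch, not part of the code above; it assumes each column has a single, distinct maximum, and A is just a fresh test matrix):

A = rand(4, 5)

# Forward pass agrees with Base's maximum.
avx_max(A) ≈ maximum(A, dims=1)  # true

# For a strict maximum, the gradient of sum∘max is the 0/1 indicator
# of each column's maximum position.
ReverseDiff.gradient(sum∘fast_max, A) ≈ (A .== maximum(A, dims=1))  # true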

Here are the logsumexp implementations with and without @avx:

function logsumexp_avx(mat; dims=1)
    @assert dims == 1  # only column-wise reduction is implemented
    max_ = vec(fast_max(mat))'  # 1×n row of column maxima
    exp_mat = @avx exp.(mat .- max_) .- (mat .== max_)
    sum_exp_ = sum(exp_mat, dims=dims)
    @avx sum_exp_ .= log1p.(sum_exp_) .+ max_
end

function logsumexp_no_avx(mat; dims=1)
    @assert dims == 1
    max_ = maximum(mat, dims=1)
    exp_mat = exp.(mat .- max_) .- (mat .== max_)
    sum_exp_ = sum(exp_mat, dims=dims)
    sum_exp_ .= log1p.(sum_exp_) .+ max_
end
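
Both versions use the same log1p trick: after subtracting the column maximum, the maximal entry contributes exp(0) = 1 to the sum, so subtracting the indicator (mat .== max_) removes that 1 and log1p adds it back with better precision when the remaining terms are small. A scalar sketch of the identity (hypothetical, just to illustrate):

v = [1000.0, 0.5, -2.0]             # exp.(v) would overflow on its own
m = maximum(v)
s = sum(exp.(v .- m) .- (v .== m))  # drop the 1 contributed by the max entry
log1p(s) + m ≈ log(sum(exp.(v .- m))) + m  # true: both equal logsumexp(v)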
x = rand(3,3);
logsumexp_avx(x) ≈ logsumexp_no_avx(x) #true
ReverseDiff.gradient(sum∘logsumexp_avx, x) ≈ ReverseDiff.gradient(sum∘logsumexp_no_avx, x) # false

I don’t know why the gradients of the two functions don’t match. The gradients of avx_max and Julia’s maximum do match.

The gradient of logsumexp_avx has 1 added at every column-wise maximum position.

For example:

Gradient of logsumexp_no_avx

 0.204813  0.384189  0.500139
 0.420175  0.21128   0.242529
 0.375012  0.404531  0.257332

Gradient of logsumexp_avx

 0.204813  0.384189  1.50014
 1.42017   0.21128   0.242529
 0.375012  1.40453   0.257332
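
In code, the difference between the two gradients appears to be exactly the indicator of the column-wise maxima (a hypothetical check on the same x, assuming distinct maxima per column):

g_avx = ReverseDiff.gradient(sum∘logsumexp_avx, x)
g_ref = ReverseDiff.gradient(sum∘logsumexp_no_avx, x)
# The mismatch is +1 exactly where each column attains its maximum.
g_avx .- g_ref ≈ (x .== maximum(x, dims=1))  # true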

Any help with this would be appreciated.