Speeding up my logsumexp function

I found that `zero1_mat = mat .== max_` creates what I want.

Thank you for your solution, but I want to keep things as simple as possible, and Zygote may not support SparseArrays.

My new code is

function lsexp_mat(mat; dims=1)
    max_ = maximum(mat, dims=dims)  # use the dims keyword, not a hardcoded dims=1
    zero1_mat = mat .== max_        # marks the exp(0) = 1 term in each slice
    exp_mat = exp.(mat .- max_)
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    log1p.(sum_exp_) .+ max_        # log1p adds the subtracted 1 back accurately
end
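To double-check the identity behind the trick: for a vector x with maximum m, log(sum(exp.(x))) = m + log1p(sum(exp.(x .- m)) - 1), because the exp(m - m) = 1 term is subtracted out and restored by log1p. A small standalone check (not using the function above):

```julia
x = [1.0, 2.0, 3.0]
m = maximum(x)

# direct logsumexp vs. the subtract-one / log1p formulation
@assert log(sum(exp.(x))) ≈ m + log1p(sum(exp.(x .- m)) - 1)
```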
n = 10000
A = rand(n, n)

@benchmark lsexp_mat(A, dims=1)

@benchmark mapslices(lsexp_vector, A; dims=1)

lsexp_mat benchmark

BenchmarkTools.Trial: 
  memory estimate:  1.50 GiB
  allocs estimate:  17
  --------------
  minimum time:     1.750 s (0.30% GC)
  median time:      1.784 s (3.12% GC)
  mean time:        1.829 s (5.38% GC)
  maximum time:     1.953 s (11.99% GC)
  --------------
  samples:          3
  evals/sample:     1

lsexp_vector benchmark

BenchmarkTools.Trial: 
  memory estimate:  2.57 MiB
  allocs estimate:  108509
  --------------
  minimum time:     1.127 s (0.00% GC)
  median time:      1.129 s (0.00% GC)
  mean time:        1.133 s (0.00% GC)
  maximum time:     1.143 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

We can see that lsexp_mat is slower than lsexp_vector and uses orders of magnitude more memory (1.50 GiB vs 2.57 MiB). Why is this happening, and how can I make it better and faster? Also, lsexp_vector makes far more allocations than lsexp_mat (108509 vs 17), yet it is still faster. I expected lsexp_mat to be faster since everything is vectorised.
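For comparison, here is a sketch of a fused variant (the name `lsexp_mat_fused` is my own, not from the thread). Broadcast fusion collapses `exp.(mat .- max_)` and the subtraction of the marker matrix into a single pass, so roughly one n×n temporary is materialized instead of three (exp_mat, zero1_mat, and their difference):

```julia
# Sketch: fuse the exponential and the "subtract one exp(0) term" into
# one broadcast, cutting the number of n×n temporaries.
function lsexp_mat_fused(mat; dims=1)
    max_ = maximum(mat, dims=dims)
    # one fused broadcast: exp(mat - max_) minus the max-marker, summed
    sum_exp_ = sum(exp.(mat .- max_) .- (mat .== max_); dims=dims)
    log1p.(sum_exp_) .+ max_
end

# quick check against the naive formula on a small matrix
B = [1.0 2.0; 3.0 4.0]
@assert lsexp_mat_fused(B) ≈ log.(sum(exp.(B), dims=1))
```

Whether this closes the gap I can't say without benchmarking; like the original, it assumes the maximum of each slice is unique (ties would subtract more than one exp(0) term).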