Speeding up my logsumexp function

Oops, I didn’t read the motivation; my apologies.

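using SparseArrays

# For a matrix, the dimension that is not `dim` (1 -> 2, 2 -> 1).
remainder(dim::Int, ::NTuple{2}) = 3 - dim
# For higher-rank arrays, every dimension except `dim`.
function remainder(dim::Int, ::NTuple{N}) where {N}
    ntuple(n -> n < dim ? n : n + 1, Val(N-1))
end

function lsexp_mat(mat::AbstractMatrix; dims=1)
    @assert dims == 1
    remdim = remainder(dims, size(mat))
    # Row index of the maximum in each column, and the maximum itself.
    maxinds_ = map(argmax, eachslice(mat, dims=remdim))
    max_ = getindex.(eachslice(mat, dims=remdim), maxinds_)
    m, n = size(mat)
    # Sparse mask with a 1 at each column's argmax: one entry per column, so n values.
    zero1_mat = sparse(maxinds_, axes(mat,2), ones(n), m, n)
    # max_ is a length-n vector of column maxima; transpose it to broadcast across columns.
    exp_mat = exp.(mat .- max_') - zero1_mat # TODO: generalize me
    log1p.(sum(exp_mat, dims=dims)) .+ max_' # TODO: generalize me
end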

This isn’t a great solution, but I wanted to offer something that addresses your actual question:

Is there a way to create a mask matrix with 1 at all argmax indices and 0 elsewhere without mutation?

by using sparse. You could almost certainly make that more efficient.
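One possible alternative for the mask, as a rough sketch rather than something I have benchmarked: compare every entry against its column maximum with broadcasting. The name argmax_mask is made up for illustration, and I have not checked whether Zygote differentiates through it cleanly.

```julia
# Hypothetical helper (name is illustrative, not from the original code):
# 1.0 where an entry equals its column maximum, 0.0 everywhere else.
# No array is mutated; the mask comes from a single broadcast comparison.
argmax_mask(mat::AbstractMatrix) = Float64.(mat .== maximum(mat, dims=1))

mask = argmax_mask([1.0 5.0; 3.0 2.0])
```

Note that ties behave differently from argmax: every entry equal to the column maximum gets a 1, not just the first one, so a column with repeated maxima would break the log1p trick in lsexp_mat.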

julia> mat = [1e-20 1e-20; log(1e-20) log(1e-20)];

julia> zero1_mat = zeros(size(mat)); zero1_mat[end, :] = zero1_mat[end, :] .+ 1;

julia> lsexp_mat(mat) # new
1×2 Array{Float64,2}:
 2.0e-20  2.0e-20

julia> lsexp_mat(mat, zero1_mat) # original
1×2 Array{Float64,2}:
 2.0e-20  2.0e-20

Worth pointing out that Zygote doesn’t currently support the SparseMatrixCSC constructor, so this answer doesn’t really help you.
It also doesn’t support sort, so you’ll have to do a little work defining your own adjoints. I’d recommend defining the rule for lsexp_mat directly (which would allow you to mutate internally).
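To make the "define the rule directly" suggestion concrete, here is a minimal sketch in plain Julia of the forward value together with its pullback for the dims=1 case. The function name lsexp_with_pullback is hypothetical, the forward pass uses the ordinary max-shifted logsumexp rather than the log1p variant above, and hooking the pair into Zygote’s adjoint machinery is left out on purpose.

```julia
# Hypothetical helper (name is illustrative): column-wise logsumexp and its pullback.
function lsexp_with_pullback(mat::AbstractMatrix)
    mx = maximum(mat, dims=1)                      # 1×n column maxima
    y = mx .+ log.(sum(exp.(mat .- mx), dims=1))   # 1×n max-shifted logsumexp
    # d y[j] / d mat[i, j] = exp(mat[i, j] - y[j]), i.e. the column-wise softmax,
    # so the pullback scales each softmax column by the incoming gradient Δ[j].
    pullback(Δ) = Δ .* exp.(mat .- y)
    return y, pullback
end

mat = [1.0 2.0; 3.0 -1.0; 0.5 0.0]
y, back = lsexp_with_pullback(mat)
g = back(ones(1, 2))   # gradient of sum(y) with respect to mat
```

Because the gradient depends only on mat and the output y, the forward pass inside such a rule is free to mutate or to use sparse internally; the AD system only ever sees the pullback.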