Oops, I missed the motivation when I first read this; my apologies.
remainder(dim::Int, ::NTuple{2}) = 3 - dim
function remainder(dim::Int, ::NTuple{N}) where {N}
    ntuple(n -> n < dim ? n : n + 1, Val(N-1))
end
using SparseArrays
function lsexp_mat(mat::AbstractMatrix; dims=1)
    @assert dims == 1
    remdim = remainder(dims, size(mat))
    maxinds_ = map(argmax, eachslice(mat, dims=remdim))
    max_ = getindex.(eachslice(mat, dims=remdim), maxinds_)
    m, n = size(mat)
    # one nonzero per column, so the values vector needs length n (not m)
    zero1_mat = sparse(maxinds_, axes(mat, 2), ones(n), m, n)
    # max_ holds one maximum per column; transpose it so it broadcasts across rows
    exp_mat = exp.(mat .- max_') - zero1_mat # TODO: generalize me
    log1p.(sum(exp_mat, dims=dims)) .+ max_' # TODO: generalize me
end
This isn’t a great solution, but I wanted to offer something that addresses your actual question:
"Is there a way to create a mask matrix with 1 at all argmax indices and 0 elsewhere without mutation?"
by using sparse. You could almost certainly make that more efficient.
julia> mat = [1e-20 1e-20; log(1e-20) log(1e-20)];
julia> zero1_mat = zeros(size(mat)); zero1_mat[end, :] = zero1_mat[end, :] .+ 1;
julia> lsexp_mat(mat) # new
1×2 Array{Float64,2}:
2.0e-20 2.0e-20
julia> lsexp_mat(mat, zero1_mat) # original
1×2 Array{Float64,2}:
2.0e-20 2.0e-20
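As one possible alternative to sparse, the same kind of mask can be built by broadcasting a comparison between the row indices and each column's argmax. This is only a sketch: argmax_mask is a name I made up for illustration, it only handles the column-wise (dims=1) case, and I haven't checked how well Zygote handles it.

```julia
# Hypothetical helper (not from the code above): one-hot mask of the column-wise argmax.
function argmax_mask(mat::AbstractMatrix)
    # Row index of the maximum in each column (the first one on ties).
    maxinds = map(argmax, eachcol(mat))
    # Compare every row index against each column's argmax; gives an m×n BitMatrix.
    axes(mat, 1) .== permutedims(maxinds)
end

mat = [1.0 5.0; 3.0 2.0; 2.0 4.0]
mask = argmax_mask(mat)   # exactly one `true` per column
```

No array is mutated here; the mask comes out of a single broadcast, at the cost of materializing a dense m×n BitMatrix.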
Worth pointing out that Zygote doesn’t currently support the SparseMatrixCSC constructor, so this answer doesn’t really help you. It also doesn’t support sort, so you’ll have to do a little work defining your own adjoints. I’d recommend defining the rule for lsexp_mat directly (which would allow you to mutate internally).
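To make that last suggestion a bit more concrete, here is a rough sketch of a forward pass paired with a hand-written pullback, for the column-wise (dims=1) case only. The name lsexp_with_pullback is made up for illustration, and the registration with Zygote (or ChainRules) is left out, so treat this as the math of the rule rather than a drop-in adjoint. It relies on the gradient of a column's log-sum-exp being that column's softmax.

```julia
# Hypothetical helper: column-wise log-sum-exp and its pullback, written by hand.
function lsexp_with_pullback(mat::AbstractMatrix)
    mx = maximum(mat, dims=1)            # 1×n column maxima
    e = exp.(mat .- mx)                  # every column has a 1 at its argmax
    for (j, col) in enumerate(eachcol(mat))
        # Mutation is fine here: the AD tool would never trace through this body.
        e[argmax(col), j] -= 1
    end
    y = log1p.(sum(e, dims=1)) .+ mx     # 1×n result, same log1p trick as above
    # d y[j] / d mat[i, j] = exp(mat[i, j] - y[j]), i.e. the column-wise softmax.
    back(Δ) = Δ .* exp.(mat .- y)
    return y, back
end

mat = [1.0 5.0; 3.0 2.0; 2.0 4.0]
y, back = lsexp_with_pullback(mat)
grad = back(ones(1, size(mat, 2)))       # each column of grad sums to 1
```

The closure back is what you would hand to whichever rule-definition mechanism you end up using; the forward value and the gradient formula do not depend on that choice.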