# Softmax using Tullio

I want to port my DIY softmax function

```julia
softmax(x) = exp.(x .- maximum(x)) / sum(exp.(x .- maximum(x)))
```

to a Tullio einsum so that it runs faster and gives fast symbolic gradients.

I split it into multiple lines, and it works:

```julia
using Tullio

function softmax_einsum(x)
    maxx = maximum(x)
    @tullio sumx := exp(x[i] - maxx) verbose=false
    @tullio ret[i] := exp(x[i] - maxx) / sumx verbose=false
end
```

It matches both the broadcast version and NNlib:

```
julia> t1 = [1.0,1,1,1]
4-element Vector{Float64}:
1.0
1.0
1.0
1.0

julia> softmax(t1)
4-element Vector{Float64}:
0.25
0.25
0.25
0.25

julia> NNlib.softmax(t1)
4-element Vector{Float64}:
0.25
0.25
0.25
0.25

julia> softmax_einsum(t1)
4-element Vector{Float64}:
0.25
0.25
0.25
0.25

```

But it gives a different Jacobian:

```
julia> t1 = [1.0,1,1,1]
4-element Vector{Float64}:
1.0
1.0
1.0
1.0

julia> Zygote.jacobian(NNlib.softmax, t1)
([0.1875 -0.0625 -0.0625 -0.0625; -0.0625 0.1875 -0.0625 -0.0625; -0.0625 -0.0625 0.1875 -0.0625; -0.0625 -0.0625 -0.0625 0.1875],)

julia> Zygote.jacobian(softmax, t1)
([0.1875 -0.0625 -0.0625 -0.0625; -0.0625 0.1875 -0.0625 -0.0625; -0.0625 -0.0625 0.1875 -0.0625; -0.0625 -0.0625 -0.0625 0.1875],)

julia> Zygote.jacobian(softmax_einsum, t1)
([0.25 0.0 0.0 0.0; 0.0 0.25 0.0 0.0; 0.0 0.0 0.25 0.0; 0.0 0.0 0.0 0.25],)
```
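For what it's worth, the diagonal 0.25 Jacobian is exactly what the chain rule gives if the scalars `maxx` and `sumx` are treated as constants during differentiation: only the direct `exp(x[i] - maxx)` term survives, so `∂ret[i]/∂x[j] = (exp(x[i] - maxx)/sumx) δ_ij = softmax(x)[i] δ_ij`. A quick sketch (`frozen_jacobian` is just an illustrative name, not anything in Tullio):

```julia
using LinearAlgebra

softmax(x) = exp.(x .- maximum(x)) / sum(exp.(x .- maximum(x)))

# Jacobian under the (wrong) assumption that maxx and sumx are constants:
# only the direct exp(x[i] - maxx) / sumx term survives, giving a diagonal matrix.
frozen_jacobian(x) = diagm(softmax(x))

frozen_jacobian([1.0, 1, 1, 1])  # 0.25 on the diagonal, like the softmax_einsum Jacobian
```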

What happened?

You are computing `exp(...)` twice.

Yeah, I'm not sure why you get the wrong result with Tullio, but you can improve your implementation regardless.

First, I don't think the maximum does much if you are looking at it symbolically; it has a numerical purpose (it keeps `exp` from overflowing on large inputs), but the answer should be the same without it.
And, as was mentioned above, you calculate `exp` of all the elements twice, which is a large waste.
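
Putting both suggestions together, here is a sketch that keeps the max subtraction for numerical stability but evaluates `exp` only once per element (`softmax_einsum2` is just an illustrative name):

```julia
using Tullio

# Sketch: same algorithm, but exp(x[i] - maxx) is computed once per element.
function softmax_einsum2(x)
    maxx = maximum(x)                  # kept for numerical stability
    @tullio ex[i] := exp(x[i] - maxx)
    ex ./ sum(ex)                      # normalize with plain broadcasting
end
```

Moving the normalization outside `@tullio` also means Zygote differentiates `sum` and `./` with its own rules, instead of relying on Tullio's derived gradient for the scalar reduction.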