`rev_cumsum_exp!()` optimization

You are right. Here is a full example, where I extracted the sensitive operation (eating 90% of my runtime) into its own function. It's a mix of the code we were looking at here and the code from the previous thread, where I ended up using your solution.

using RDatasets
using LinearAlgebra
using BenchmarkTools
using StatsBase

# The dataset: 
const colon = let
    df = dataset("survival","colon")
    df = dropmissing(df, [:Nodes, :Differ])
    
    T = Float64.(df.Time)
    Δ = Int64.(df.Status)
    X = Float64.(hcat(df.Sex, df.Age, df.Obstruct,
                      df.Perfor, df.Adhere, df.Nodes,
                      df.Differ, df.Extent, df.Surg,
                      df.Node4, df.EType))
    (T,Δ,X)
end

struct Cox
    X::Matrix{Float64}   # covariates, rows sorted by time
    T::Vector{Float64}   # sorted times
    Δ::Vector{Bool}      # event indicators, sorted
    sX::Vector{Float64}  # Σᵢ Δᵢ xᵢ
    G::Vector{Float64}   # gradient buffer
    η::Vector{Float64}   # linear predictor Xβ
    B::Vector{Float64}   # diagonal Hessian bounds
    I::Vector{Int64}     # ranks of the event times
    J::Vector{Int64}     # number of events with time ≤ Tᵢ
    _A::Vector{Float64}  # scratch buffer
    _B::Vector{Float64}  # scratch buffer
    function Cox(T,Δ,X)

        # Allocate: 
        n,m = size(X)
        G, B  = zeros(m), zeros(m)
        η, _A, _B  = zeros(n), zeros(n), zeros(n)

        # Precompute a few things: 
        # This should also be optimized; it's taking a lot of time right now...
        o = sortperm(T)
        To = T[o]
        Δo = Bool.(Δ[o])
        Xo = X[o,:]
        sX = X'Δ                       # Σᵢ Δᵢ xᵢ (a plain sum, so the unsorted X and Δ are fine here)
        Ro = StatsBase.competerank(To) # ranks; tied values share the smallest rank
        I = Ro[Δo]                     # ranks of the event times

        # J is a form of ranks filtered on Δ: J[i] counts the events with time ≤ To[i].
        # I did not find a ready-made function for it.
        J = zeros(Int64,n)
        for j in 1:n
            if Δo[j]
                for i in 1:n
                    J[i] += (To[i] >= To[j])
                end
            end
        end
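
        # Possible O(n) replacement (my sketch, not part of the original code):
        # since To is sorted, J[i] is just the running count of events, with the
        # value propagated backwards across tied times. The assert checks this.
        J2 = cumsum(Δo)
        for i in n-1:-1:1
            To[i] == To[i+1] && (J2[i] = J2[i+1])
        end
        @assert J2 == J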
        
        # Compute Hessian bounds:
        # (each event contributes (Mₓ-mₓ)²/4, an upper bound on the variance of x over its risk set)
        for l in 1:m # for each dimension
            lastj = n+1
            Mₓ = Xo[end,l]
            mₓ = Xo[end,l]
            for i in n:-1:1
                for j in lastj-1:-1:Ro[i]
                    Mₓ = max(Xo[j,l], Mₓ)
                    mₓ = min(Xo[j,l], mₓ)
                end
                lastj = Ro[i]
                B[l] += (1/4) * Δo[i] * (Mₓ-mₓ)^2 # Δo, not Δ: indices here are in sorted order
            end
        end

        # Instantiate: 
        new(Xo, To, Δo, sX, G, η, B, I, J, _A, _B)
    end
end
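
# Context for most_of_my_runtime! below (my annotation): the iteration maximizes
# the Cox partial log-likelihood. After swapping the order of summation, its
# gradient reads
#   ∇ℓ(β) = Σᵢ Δᵢ xᵢ − Σₖ xₖ e^{ηₖ} · Σ_{events i: Tᵢ ≤ Tₖ} 1 / (Σ_{j: Tⱼ ≥ Tᵢ} e^{ηⱼ})
# which update! assembles as sX .- X' * A, with A computed below.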

function most_of_my_runtime!(A,B,η,I,J)
    # This is equivalent to:
    #   A .= exp.(η)
    #   B .= 1 ./ reverse(cumsum(reverse(A)))
    #   A .*= cumsum(B[I])[J]
    # but about 40% faster for the moment.

    aₖ = bₖ = cₖ = zero(eltype(A))
    n, lastj = length(η), 0

    @inbounds for k in n:-1:1
        aₖ = exp(η[k])
        A[k] = aₖ
        bₖ += aₖ
        B[k] = bₖ
    end
    @inbounds for k in 1:n
        for j in lastj+1:J[k]
            cₖ += inv(B[I[j]])
        end
        lastj = J[k] # robust even when the range above is empty (J is nondecreasing)
        A[k] *= cₖ
    end
end
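
# A hypothetical variant (my sketch, untested here): the exp pass carries no
# loop dependency, so splitting it from the backward cumsum lets it
# SIMD-vectorize, e.g. with LoopVectorization's @turbo. The two running sums
# stay scalar loops, since each carries a dependency between iterations.
using LoopVectorization # extra dependency, only needed for this sketch

function most_of_my_runtime_v2!(A, B, η, I, J)
    n = length(η)
    @turbo for k in 1:n        # vectorized exp, no carried dependency
        A[k] = exp(η[k])
    end
    bₖ = zero(eltype(A))
    @inbounds for k in n:-1:1  # backward cumulative sum of A
        bₖ += A[k]
        B[k] = bₖ
    end
    cₖ = zero(eltype(A))
    lastj = 0
    @inbounds for k in 1:n     # prefix sums of 1/B at the event ranks
        for j in lastj+1:J[k]
            cₖ += inv(B[I[j]])
        end
        lastj = J[k]
        A[k] *= cₖ
    end
    return nothing
end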

function update!(β, M::Cox)
    mul!(M.η, M.X, β)                              # η = Xβ
    most_of_my_runtime!(M._A, M._B, M.η, M.I, M.J) # fill A with the per-observation weights
    mul!(M.G, M.X', M._A)
    M.G .= M.sX .- M.G                             # gradient of the partial log-likelihood
    β .+= M.G ./ M.B                               # step scaled by the diagonal Hessian bounds
    return nothing
end

function getβ(M::Cox; max_iter = 10000, tol = 1e-6)

    β = zeros(size(M.X, 2))
    βᵢ = similar(β)
    for i in 1:max_iter
        βᵢ .= β
        update!(β, M) 
        gap = norm(βᵢ - β)
        # println(gap)
        if gap < tol
            break
        end
    end
    return β
end

M = Cox(colon...)
getβ(M)
@profview [getβ(M) for i in 1:100] # @profview comes from the VS Code Julia extension (or ProfileView.jl)
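
# To time the hot path in isolation (BenchmarkTools is already loaded):
# @btime most_of_my_runtime!($(M._A), $(M._B), $(M.η), $(M.I), $(M.J))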

The `@profview` flame graph (screenshot not reproduced here) confirms that `most_of_my_runtime!` dominates the runtime.

I thought about trying to intertwine the two loops, but I was not able to find a better version yet.
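
One obstacle I see with fusing them: the forward loop reads `B[I[j]]` with `I[j] ≤ k`, while the backward pass only finalizes `B[k]` once it has swept down to index `k`, so `B[1]` is ready last. The two sweeps therefore cannot share a single pass, and any gain probably has to come from the exp pass instead (see the `@turbo` sketch above).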