`rev_cumsum_exp!()` optimization

So @Dan you exploited the fact that reverse(cumsum(reverse(A))) == sum(A) .- cumsum([0,A[1:end-1]...])… That is smart indeed. I thought about it, but since the second cumsum has to run in the reverse order of the first, I do not think it is possible to do it in one pass over the dataset.
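
For concreteness, a quick check of that identity on a toy vector (illustrative data only):

A = [1.0, 2.0, 3.0, 4.0]
reverse(cumsum(reverse(A)))           # [10.0, 9.0, 7.0, 4.0]
sum(A) .- cumsum([0, A[1:end-1]...])  # [10.0, 9.0, 7.0, 4.0]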

I ended up moving the crazy indexing logic to the constructor, to get it out of the hot loop, and realized that the B[i] inside a k-adding-group are in fact all equal. Take a look at this version:

Complete code
using RDatasets
using LinearAlgebra
using BenchmarkTools
using StatsBase

# The dataset: 
const colon = let
    df = dataset("survival","colon")
    df = dropmissing(df, [:Nodes, :Differ])
    
    T = Float64.(df.Time)
    Δ = Int64.(df.Status)
    X = Float64.(hcat(df.Sex, df.Age, df.Obstruct,
                      df.Perfor, df.Adhere, df.Nodes,
                      df.Differ, df.Extent, df.Surg,
                      df.Node4, df.EType))
    (T,Δ,X)
end

struct Cox
    X::Matrix{Float64}
    T::Vector{Float64}
    Δ::Vector{Bool}
    sX::Vector{Float64}
    G::Vector{Float64}
    η::Vector{Float64}
    B::Vector{Float64}
    K::Vector{Int64}
    _A::Vector{Float64}
    _B::Vector{Float64}
    function Cox(T,Δ,X)

        # Allocate: 
        n,m = size(X)
        G, B  = zeros(m), zeros(m)
        η, _A, _B  = zeros(n), zeros(n), zeros(n)

        # Precompute a few things: 
        # This should also be optimized; it's taking a lot of time right now.
        o = sortperm(T)
        To = T[o]
        Δo = Bool.(Δ[o])
        Xo = X[o,:]
        sX = X'Δ
        
        # Indexing hell for I,J and now K
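        # Ro[i] : rank of the i-th sorted time (tied times share a rank),
        # I     : ranks at which events occur,
        # J[j]  : cumulative number of events with rank ≤ j,
        # K[k]  : number of events tied at rank k (i.e. J[k] - J[k-1]).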
        Ro = StatsBase.competerank(To)
        I = Ro[Δo]
        J = zeros(Int64,n)
        for rⱼ in I
            J[rⱼ] += 1
        end
        for j in 2:n
            J[j] += J[j-1]
        end

        K = Int64[]
        Jₖ₋₁ = 0
        for Jₖ in J
            push!(K,length((Jₖ₋₁+1):Jₖ))
            Jₖ₋₁ = Jₖ
        end
        
        # Compute hessian bounds:
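        # For each covariate l, walk i from n down to 1 keeping a running
        # max/min of Xo[j,l] over the risk set j ≥ Ro[i] (i.e. To[j] ≥ To[i]),
        # and accumulate (1/4)·Δ·(Mₓ-mₓ)² into a bound on the Hessian diagonal.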
        for l in 1:m # for each dimension. 
            lastj = n+1
            Mₓ = Xo[end,l]
            mₓ = Xo[end,l]
            Bₗ = 0.0
            for i in n:-1:1
                for j in Ro[i]:(lastj-1)
                    Mₓ = max(Xo[j,l], Mₓ)
                    mₓ = min(Xo[j,l], mₓ)
                end
                lastj = Ro[i]
                Bₗ += (1/4) * Δo[i] * (Mₓ-mₓ)^2
            end
            B[l] = Bₗ
        end

        # Instantiate: 
        new(Xo, To, Δo, sX, G, η, B, K, _A, _B)
    end
end

function most_of_my_runtime!(A,B,η,K)
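    # Baseline version: same computation as most_of_my_runtime_V2! below,
    # but without skipping the nₖ == 0 entries in the second loop.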
    aₖ = bₖ = cₖ = zero(eltype(A))
    n = length(η)
    
    @inbounds for k in n:-1:1
        aₖ = exp(η[k])
        A[k] = aₖ
        bₖ += aₖ
        B[k] = bₖ
    end
    @inbounds for (k, (nₖ, bₖ)) in enumerate(zip(K,B))
        cₖ += nₖ / bₖ
        A[k] *= cₖ
    end
end

function most_of_my_runtime_V2!(A,B,η,K)
    ######### Equivalent to : 
    ####    A .= exp.(η)
    ####    B .= reverse(cumsum(reverse(A)))
    ####    A .*= cumsum(K ./ B)
    ####
    #### Remark that K is half full of zeros
    ################################

    aₖ = bₖ = cₖ = zero(eltype(A))
    n = length(η)
    
    @inbounds for k in n:-1:1
        aₖ = exp(η[k])
        A[k] = aₖ
        bₖ += aₖ
        B[k] = bₖ
    end
    @inbounds for (k, (nₖ, bₖ)) in enumerate(zip(K,B))
        (nₖ > 0) && (cₖ += nₖ / bₖ)
        A[k] *= cₖ
    end
end

function update!(β, M::Cox)
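    # One update step: η = Xβ; _B holds the reverse cumulative sum of exp.(η);
    # _A holds exp.(η) .* cumsum(K ./ _B); G = X'Δ - X'_A is the gradient of the
    # partial log-likelihood; β moves by G ./ B, with B the precomputed Hessian bound.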
    mul!(M.η, M.X, β)
    most_of_my_runtime_V2!(M._A, M._B, M.η, M.K)
    mul!(M.G, M.X', M._A)
    M.G .= M.sX .- M.G
    β .+= M.G ./ M.B 
    return nothing
end

function getβ(M::Cox; max_iter = 10000, tol = 1e-6)
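    # Iterate the bounded update until successive β iterates differ by less than tol.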

    β = zeros(size(M.X, 2))
    βᵢ = similar(β)
    for i in 1:max_iter
        βᵢ .= β
        update!(β, M) 
        gap = norm(βᵢ - β)
        # println(gap)
        if gap < tol
            break
        end
    end
    return β
end

M = Cox(colon...)
getβ(M)
@profview [getβ(M) for i in 1:100] # @profview needs ProfileView.jl or the VS Code Julia extension

Maybe some better structure for K could be found, since 60% of the nₖ are actually zero…
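
One possible direction, just as a sketch (the name scale_by_sparse_K! and the ks/ns layout are mine, and I have not benchmarked it): store only the positions and values of the nonzero nₖ, built once in the constructor as ks = findall(>(0), K); ns = K[ks], and sweep A once while advancing a pointer into that list:

# Hypothetical sparse-K variant of the second loop in most_of_my_runtime_V2!.
# ks holds the ranks with at least one event, ns the corresponding counts.
function scale_by_sparse_K!(A, B, ks::Vector{Int}, ns::Vector{Int})
    cₖ = zero(eltype(A))
    nxt = 1
    @inbounds for k in eachindex(A)
        if nxt ≤ length(ks) && ks[nxt] == k
            cₖ += ns[nxt] / B[k]   # same accumulation as the dense (nₖ > 0) branch
            nxt += 1
        end
        A[k] *= cₖ
    end
    return A
end

Whether the pointer chase actually beats the simple (nₖ > 0) check is not obvious, since the division is skipped either way; it would need benchmarking.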
