So @Dan, you exploited the fact that `reverse(cumsum(reverse(A))) == sum(A) .- cumsum([0,A[1:end-1]...])` (exact in exact arithmetic, up to rounding in floating point).
… That is smart indeed. I thought about it, but since the next cumsum has to run in the reverse order of the first one, I do not think it is possible to do it in a single pass over the dataset.
I ended up moving the crazy indexing logic to the constructor, to keep it out of the hot loop, and realized that all the B[i] inside a k-adding group are in fact equal. Take a look at this version:
Complete code
using RDatasets
using LinearAlgebra
using BenchmarkTools
using StatsBase
# The dataset: times, status codes and covariate matrix of the
# survival/colon data, packed as a `(T, Δ, X)` tuple.
const colon = let
    raw = dataset("survival", "colon")
    # Drop the rows whose Nodes or Differ entry is missing.
    df = dropmissing(raw, [:Nodes, :Differ])
    covariates = (:Sex, :Age, :Obstruct, :Perfor, :Adhere, :Nodes,
                  :Differ, :Extent, :Surg, :Node4, :EType)
    times = Float64.(df.Time)
    status = Int64.(df.Status)
    design = Float64.(hcat((df[!, c] for c in covariates)...))
    (times, status, design)
end
"""
    Cox(T, Δ, X)

Precompute everything the fitting loop (`update!` / `getβ`) needs from the raw
data, so that none of the indexing logic runs inside the hot loop.

# Arguments
- `T`: observation times, one per row of `X`.
- `Δ`: status codes (1 = event, 0 = censored). They are converted with
  `Bool.(…)`, so any other value throws an `InexactError`.
- `X`: `n × m` covariate matrix.

# Fields
- `X`, `T`, `Δ`: the data with rows permuted into increasing-time order.
  `sortperm` is stable, so tied times keep their input order.
- `sX`: `X'Δ`. This is a sum over rows, so the row order does not matter.
- `K`: `K[r]` is the number of events whose competition ("1224") rank among
  the sorted times is `r`. Positions that do not start a tie group holding
  an event are zero.
- `B`: for each covariate `l`, the sum over events `i` (sorted order) of
  `(max - min of X[j, l] over sorted rows j ≥ rank(i))^2 / 4`. These are the
  Hessian bounds used as step denominators in `update!`.
- `G`, `η`, `_A`, `_B`: preallocated work buffers, initially zero.

# Throws
- `DimensionMismatch` if `T`, `Δ` and the rows of `X` differ in length.
"""
struct Cox
    X::Matrix{Float64}
    T::Vector{Float64}
    Δ::Vector{Bool}
    sX::Vector{Float64}
    G::Vector{Float64}
    η::Vector{Float64}
    B::Vector{Float64}
    K::Vector{Int64}
    _A::Vector{Float64}
    _B::Vector{Float64}
    function Cox(T, Δ, X)
        n, m = size(X)
        length(T) == n == length(Δ) || throw(DimensionMismatch(
            "T, Δ and the rows of X must have equal length; got " *
            "$(length(T)), $(length(Δ)) and $n"))
        # Sort once by time; all precomputations below work in this order.
        o = sortperm(T)
        To = T[o]
        Δo = Bool.(Δ[o])
        Xo = X[o, :]
        sX = X'Δ                       # order-independent, so use the raw inputs
        Ro = _sorted_competerank(To)   # To is sorted: ranks in one linear pass
        K = _event_counts_by_rank(Ro, Δo)
        B = _cox_hessian_bounds(Xo, Δo, Ro)
        return new(Xo, To, Δo, sX, zeros(m), zeros(n), B, K, zeros(n), zeros(n))
    end
end

# Competition ("1224") ranks of an already-sorted vector. Each element of a
# run of equal values gets the index of the run's first element. Because the
# input is sorted, this needs no extra sort (unlike a generic ranking routine).
function _sorted_competerank(v::AbstractVector)
    r = Vector{Int}(undef, length(v))
    for i in 1:length(v)
        r[i] = (i > 1 && v[i] == v[i - 1]) ? r[i - 1] : i
    end
    return r
end

# Number of events at each competition rank: K[r] = #{i : Δo[i] and Ro[i] == r}.
# This equals the first differences of the cumulative event counts.
function _event_counts_by_rank(Ro, Δo)
    K = zeros(Int64, length(Ro))
    for (r, δ) in zip(Ro, Δo)
        δ && (K[r] += 1)
    end
    return K
end

# Per-covariate Hessian bounds. Sorted rows are walked from last to first,
# extending the running max/min of column l over rows Ro[i]:n. Each new row is
# visited once per column, so the cost is O(n·m). Each term is weighted by the
# status of the *sorted* row i.
function _cox_hessian_bounds(Xo, Δo, Ro)
    n, m = size(Xo)
    B = zeros(m)
    n == 0 && return B
    for l in 1:m
        lastj = n + 1
        Mₓ = Xo[n, l]
        mₓ = Xo[n, l]
        Bₗ = 0.0
        for i in n:-1:1
            # Only the rows not yet covered by a later i need scanning.
            for j in Ro[i]:(lastj - 1)
                Mₓ = max(Xo[j, l], Mₓ)
                mₓ = min(Xo[j, l], mₓ)
            end
            lastj = Ro[i]
            Bₗ += (1 / 4) * Δo[i] * (Mₓ - mₓ)^2
        end
        B[l] = Bₗ
    end
    return B
end
"""
    most_of_my_runtime!(A, B, η, K)

In-place kernel, equivalent to:

    A .= exp.(η)
    B .= reverse(cumsum(reverse(A)))
    A .*= cumsum(K ./ B)

Overwrites `A` and `B` and returns `nothing`. Bounds checks are disabled, so
`A` and `B` must be at least as long as `η` and `K`.
"""
function most_of_my_runtime!(A, B, η, K)
    T = eltype(A)
    # Backward pass: exponentiate and accumulate the suffix sums into B.
    tail = zero(T)
    @inbounds for k in length(η):-1:1
        w = exp(η[k])
        A[k] = w
        tail += w
        B[k] = tail
    end
    # Forward pass: running sum of K ./ B scales each A[k].
    running = zero(T)
    @inbounds for k in 1:min(length(K), length(B))
        running += K[k] / B[k]
        A[k] *= running
    end
    return nothing
end
"""
    most_of_my_runtime_V2!(A, B, η, K)

In-place kernel, equivalent to:

    A .= exp.(η)
    B .= reverse(cumsum(reverse(A)))
    A .*= cumsum(K ./ B)

The one difference is that terms with `K[k] ≤ 0` are skipped instead of
added. `K` is often zero, and a skipped term can never contribute `0/0 = NaN`.
Overwrites `A` and `B` and returns `nothing`. Bounds checks are disabled.
"""
function most_of_my_runtime_V2!(A, B, η, K)
    T = eltype(A)
    # Backward pass: exponentiate and accumulate the suffix sums into B.
    tail = zero(T)
    @inbounds for k in length(η):-1:1
        w = exp(η[k])
        A[k] = w
        tail += w
        B[k] = tail
    end
    # Forward pass: running sum of K ./ B over the positive K only.
    running = zero(T)
    @inbounds for k in 1:min(length(K), length(B))
        cnt = K[k]
        if cnt > 0
            running += cnt / B[k]
        end
        A[k] *= running
    end
    return nothing
end
# One iteration step. It recomputes η = Xβ, then the weighted vector _A via the
# kernel, then G = sX - X'·_A. Finally it moves β by G divided elementwise by
# the bounds M.B. Mutates β and the work buffers stored in M; returns nothing.
function update!(β, M::Cox)
    η, G = M.η, M.G
    mul!(η, M.X, β)
    most_of_my_runtime_V2!(M._A, M._B, η, M.K)
    mul!(G, M.X', M._A)
    @. G = M.sX - G
    @. β += G / M.B
    return nothing
end
# Run `update!` starting from β = 0. Stops as soon as the Euclidean norm of
# the change between two successive iterates is below `tol`, or after
# `max_iter` updates, whichever comes first, and returns β. Hitting
# `max_iter` is silent: no warning is raised.
function getβ(M::Cox; max_iter = 10000, tol = 1e-6)
    p = size(M.X, 2)
    β = zeros(p)
    βprev = similar(β)
    for _ in 1:max_iter
        copyto!(βprev, β)
        update!(β, M)
        norm(βprev - β) < tol && break
    end
    return β
end
# Build the model from the colon data and fit it once, so compilation happens
# before profiling. Then profile 100 repeated fits. NOTE(review): @profview is
# not imported in this file; presumably it comes from the VS Code Julia REPL
# or ProfileView.jl. Confirm.
M = Cox(colon...)
getβ(M)
@profview map(_ -> getβ(M), 1:100)
Maybe a better structure for K could be found, since about 60% of the n_k are actually zeros.