You are right. Here is a full example in which I isolated the performance-critical operation (about 90% of my runtime) in its own function. It combines the code we were discussing here with the code from the previous thread, where I ended up using your solution.
using RDatasets
using LinearAlgebra
using BenchmarkTools
using StatsBase
# The dataset: R's survival::colon data with rows lacking `Nodes` or `Differ`
# dropped, packed as (times::Vector{Float64}, status::Vector{Int64},
# covariates::Matrix{Float64}) — the argument order expected by `Cox`.
const colon = let
    raw = dropmissing(dataset("survival", "colon"), [:Nodes, :Differ])
    covariates = (:Sex, :Age, :Obstruct, :Perfor, :Adhere, :Nodes,
                  :Differ, :Extent, :Surg, :Node4, :EType)
    times = Float64.(raw.Time)
    status = Int64.(raw.Status)
    design = Float64.(hcat((raw[!, c] for c in covariates)...))
    (times, status, design)
end
"""
    Cox(T, Δ, X)

Precomputed state for fitting a Cox proportional-hazards model with `getβ`.

Observations are reordered by increasing time; every per-observation field is
stored in that sorted order.

# Arguments
- `T`: observation times, length `n`.
- `Δ`: event indicators, length `n`, each convertible to `Bool` (`1` = event).
- `X`: `n × m` covariate matrix.

# Fields
- `X`, `T`, `Δ`: the inputs sorted by increasing time (`Δ` stored as `Bool`).
- `sX`: `X' * Δ`, the covariate sum over events (independent of row order).
- `G`, `η`: work buffers of length `m` and `n`, overwritten by `update!`.
- `B`: per-covariate bounds, `B[l] = Σ_{events i} (max - min of X[r, l] over
  rows r with T[r] ≥ T[i])² / 4`.
- `I`: for each event, in time order, the first sorted index whose time equals
  the event's time; rows `I[j]:n` are exactly those with time ≥ that time.
- `J`: `J[i]` = number of events with time `≤ T[i]`.
- `_A`, `_B`: length-`n` scratch buffers for `most_of_my_runtime!`.

# Throws
- `DimensionMismatch` if `T`, `Δ` and the rows of `X` differ in length.
"""
struct Cox
    X::Matrix{Float64}
    T::Vector{Float64}
    Δ::Vector{Bool}
    sX::Vector{Float64}
    G::Vector{Float64}
    η::Vector{Float64}
    B::Vector{Float64}
    I::Vector{Int64}
    J::Vector{Int64}
    _A::Vector{Float64}
    _B::Vector{Float64}
    function Cox(T, Δ, X)
        n, m = size(X)
        length(T) == n == length(Δ) ||
            throw(DimensionMismatch("T, Δ and the rows of X must have the same length"))

        # Work buffers, filled in by `update!` / `most_of_my_runtime!`.
        G, B = zeros(m), zeros(m)
        η, _A, _B = zeros(n), zeros(n), zeros(n)

        # Sort by time so that "rows with time ≥ t" is always a contiguous suffix.
        o = sortperm(T)
        To = T[o]
        Δo = Bool.(Δ[o])
        Xo = X[o, :]
        sX = X'Δ  # a sum over rows, so the unsorted inputs give the same result

        # Ro[i]: first sorted index whose time equals To[i] (the competition rank
        # of the already-sorted times); rows Ro[i]:n all have time ≥ To[i].
        Ro = Vector{Int64}(undef, n)
        for i in 1:n
            Ro[i] = (i > 1 && To[i] == To[i-1]) ? Ro[i-1] : i
        end
        I = Ro[Δo]

        # J[i]: number of events with time ≤ To[i]. With sorted times this is the
        # running event count at the last row of i's tie group, so a single
        # backward pass replaces an O(n²) double loop.
        nevents = cumsum(Δo)
        J = Vector{Int64}(undef, n)
        groupend = n
        for i in n:-1:1
            if To[i] != To[groupend]
                groupend = i
            end
            J[i] = nevents[groupend]
        end

        # Bounds: for each covariate l and each event i, the squared range of
        # X[:, l] over the rows with time ≥ To[i]. Walking i backwards lets the
        # running max/min grow incrementally; lastj marks rows already scanned.
        for l in 1:m
            lastj = n + 1
            Mₓ = mₓ = Xo[end, l]
            for i in n:-1:1
                for j in lastj-1:-1:Ro[i]
                    Mₓ = max(Xo[j, l], Mₓ)
                    mₓ = min(Xo[j, l], mₓ)
                end
                lastj = Ro[i]
                # Δo, not the unsorted Δ: i indexes the time-sorted rows.
                B[l] += (1/4) * Δo[i] * (Mₓ - mₓ)^2
            end
        end

        return new(Xo, To, Δo, sX, G, η, B, I, J, _A, _B)
    end
end
"""
    most_of_my_runtime!(A, B, η, I, J)

Fill `A` with the risk-weighted terms used by `update!`, using `B` as scratch.

With suffix sums `S[k] = Σ_{k' ≥ k} exp(η[k'])` (left in `B`), this sets
`A[k] = exp(η[k]) * Σ_{j ≤ J[k]} 1 / S[I[j]]`, i.e. the same `A` as

    A .= exp.(η)
    S = reverse(cumsum(reverse(A)))
    A .*= cumsum(1 ./ S[I])[J]

except that `J[k] == 0` yields an empty sum (zero) instead of an index error.
On return `B` holds the suffix sums `S`, not their reciprocals. The original
author reported this loop form as about 40% faster than the broadcast version.

Assumes, as produced by `Cox`, that `J` is non-decreasing with
`0 ≤ J[k] ≤ length(I)` and that every `I[j]` is a valid index into `B`; these
are not checked because the loops run under `@inbounds`.

# Throws
- `DimensionMismatch` if `A`, `B` or `J` differ in length from `η`.
"""
function most_of_my_runtime!(A, B, η, I, J)
    n = length(η)
    length(A) == length(B) == length(J) == n ||
        throw(DimensionMismatch("A, B and J must have the same length as η"))
    T = eltype(A)

    # Backward pass: A[k] = exp(η[k]); B[k] = A[k] + A[k+1] + … + A[n].
    acc = zero(T)
    @inbounds for k in n:-1:1
        a = exp(η[k])
        A[k] = a
        acc += a
        B[k] = acc
    end

    # Forward pass: extend the running sum of 1/B[I[j]] up to j = J[k].
    # `done` counts the terms already added; tracking it with max() instead of
    # the loop variable avoids reading an unassigned variable when an early
    # range is empty (e.g. J[1] == 0).
    c = zero(T)
    done = 0
    @inbounds for k in 1:n
        for j in done+1:J[k]
            c += inv(B[I[j]])
        end
        done = max(done, J[k])
        A[k] *= c
    end
    return nothing
end
"""
    update!(β, M::Cox)

Advance the coefficient vector `β` by one bound-optimization step, in place.

Computes `η = X * β`, the weighted vector `A` via `most_of_my_runtime!`, then
`G = M.sX - X' * A` and `β += G ./ M.B`. The buffers `M.η`, `M.G`, `M._A` and
`M._B` are overwritten; `M.G` is left holding `G`. Returns `nothing`.
"""
function update!(β, M::Cox)
    X, η, A, G = M.X, M.η, M._A, M.G
    sX, bound = M.sX, M.B
    mul!(η, X, β)                                # η ← Xβ
    most_of_my_runtime!(A, M._B, η, M.I, M.J)    # A ← weights from η
    mul!(G, X', A)                               # G ← X'A
    @. G = sX - G
    @. β += G / bound
    return nothing
end
"""
    getβ(M::Cox; max_iter = 10000, tol = 1e-6)

Fit the model coefficients by repeating `update!` from `β = 0` until the
Euclidean norm of the change in `β` between two steps is below `tol`, or
`max_iter` steps have been taken.

Returns the coefficient vector (length `size(M.X, 2)`). If `max_iter` is
exhausted without meeting `tol`, a warning is logged and the last iterate is
returned.
"""
function getβ(M::Cox; max_iter = 10000, tol = 1e-6)
    β = zeros(size(M.X, 2))
    change = similar(β)
    for _ in 1:max_iter
        change .= β
        update!(β, M)
        # In place: avoids allocating `previous - β` on every iteration.
        change .-= β
        norm(change) < tol && return β
    end
    @warn "getβ reached max_iter without converging" max_iter tol
    return β
end
# Build the model from the (T, Δ, X) tuple loaded above.
M = Cox(colon...)
# Presumably a warm-up call so compilation time is excluded from the profile below.
getβ(M)
# Profile 100 repeated fits. NOTE(review): `@profview` is not imported in this file;
# it is presumably provided by the VS Code Julia REPL or ProfileView.jl — confirm.
@profview [getβ(M) for i in 1:100]
My `@profview` output looks like this:
I thought about fusing the two loops, but I have not yet found a faster version.