[Review wanted] Adam gradient descent with complicated gradient to fit multivariate gamma convolutions

Hi,

I have written some code that fits a multivariate gamma convolution to a dataset using stochastic gradient descent on a home-made loss. Most of the code is devoted to the computation of the gradient, which is combinatorial in nature (and takes up 95% of the runtime); the last part of the code performs the descent itself. I am currently trying to optimize the whole thing, and I come to you for advice on how to improve the runtime performance of the main loop.

The procedure is quite complex, and I am using StaticArrays and LoopVectorization, which makes it hard to extract a smaller MWE; I'm sorry about the length of the code. However, I have already profiled it a little and extracted the costly components into their own functions:

project!(D,e,data,N,d)
exp_and_mean!(slack,D,N)
prod_and_mean!(slack,D,N)
turbo_aBC!(R, a, B, C)

which are grouped into the second block of code. They account for 75–80% of the runtime, and comments inside them summarize their purpose. We already discussed the first three functions here and there, and I have little hope that they can be optimized further, but if you see something, please tell me! :slight_smile: (A timing sketch for these four kernels follows just below.)
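If it helps, here is a rough sketch (not part of the main code) of how these four kernels can be timed in isolation, assuming a Problem P has already been built as in block 6 below:

using BenchmarkTools
@btime project!($(P.s.D), $(P.s.e), $(P.data), $(P.s.N), $(P.s.d))
@btime exp_and_mean!($(P.s.slack), $(P.s.D), $(P.s.N))
@btime prod_and_mean!($(P.s.slack), $(P.s.D), $(P.s.N)) # mutates slack, so only the timing is meaningful here
@btime turbo_aBC!($(P.s.∇), $(P.s.tmp), $(P.s.J), $(P.s.∇τ))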

The third block is the core of the computational loop (the computation of the gradient) and consists of three functions:

build_model!(s,data)
jac!(s)
loss_and_∇!(par,s)

These functions call the four previous ones, and their own code accounts for the remaining runtime.
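In formulas: for a projection direction e, build_model! estimates (by empirical means over the data) the tilted moments η[k+1] = 𝔼(Dᵏe⁻ᴰ) / (k!·𝔼(e⁻ᴰ)) of the projected sample D = ⟨e, X⟩, with τₑ[1] = log 𝔼(e⁻ᴰ), and converts them through the recursion τₑ[k+1] = k·η[k+1] − Σⱼ₌₁ᵏ⁻¹ τₑ[j+1]·η[k+1−j]. Then jac! builds the Jacobian of the map τ ↦ a (the chain a_from_μ ∘ μ_from_κ ∘ κ_from_τ from ThorinDistributions) that loss_and_∇! uses in the chain rule.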

The rest of the code only interfaces with these core computations and performs the gradient descent, which does not take a significant amount of time. The fit! function contains the main loop.

I am not sure this is the right place to ask, but I did not see a code review section in the forum. Feel free to try out the code (you'll need the unregistered version of ThorinDistributions; see the code for the GitHub link) and tell me what you think. Do not hesitate to ask if you want more details about why I did things one way or another.

I hope this will pique your interest! :slight_smile:

The Full Code.
#########################################################################################################
####### 0) Packages.
#########################################################################################################
using Random, 
      Statistics,  
      ProgressMeter, 
      UnicodePlots,
      StaticArrays,
      HybridArrays, 
      LoopVectorization,
      LinearAlgebra
import ThorinDistributions as TD # add https://github.com/lrnv/ThorinDistributions.jl
#########################################################################################################
####### 1) StaticProjectionStorage structure. 
#########################################################################################################
struct StaticProjectionStorage{T,M,d,MM,dd}
    N::Int
    m::Int
    d::Int
    n::Int
    D::Vector{T}
    η::MVector{M,T}
    slack::Vector{T}
    e::MVector{d,T}
    τₑ::MVector{M,T}
    J::MMatrix{M,M,T,MM}
    P::TD.PreComp{T,M,MM}
    τ::MVector{M,T}
    ∇τ::Matrix{T}
    θ_pow::MVector{M,T}
    tmp::MVector{M,T}
    loss::MVector{1,T}
    ∇::Vector{T}
    inv_facs::SVector{M,T}
    std_data::SMatrix{d,d,T,dd}
    # Static fields: η, e, τₑ, J, P, τ, θ_pow, tmp, loss, inv_facs, std_data.
end
function StaticProjectionStorage(T,N,m,d,n,std_data)
    M        = m+1
    # large standards arrays
    D        = zeros(T,N)
    slack    = zeros(T,N)
    ∇        = zeros(T,(d+1)*n)
    ∇τ       = zeros(T,(M,(d+1)*n))

    # Small static arrays
    P        = TD.get_precomp(T,M) # Some factorials and binomials, to avoid recomputing them later. 
    J        = MMatrix{M,M}(zeros(T,(M,M)))
    η        = MVector{M}(zeros(T,M))
    τₑ       = MVector{M}(zeros(T,M))
    τ        = MVector{M}(zeros(T,M))
    θ_pow    = MVector{M}(zeros(T,M))
    tmp      = MVector{M}(zeros(T,M))
    e        = MVector{d}(zeros(T,d))
    loss     = MVector{1}(zeros(T,1))
    inv_facs = SVector{M}(T.(1 ./ factorial.(big.(0:(M-1)))))
    std_data = SMatrix{d,d,T,d*d}(std_data)
    return StaticProjectionStorage{T,M,d,M*M,d*d}(N,m,d,n,D,η,slack,e,τₑ,J,P,τ,∇τ,θ_pow,tmp,loss,∇,inv_facs,std_data)
end
#########################################################################################################
####### 2) The four functions that make up 80% of the runtime 
#########################################################################################################
@inline function project!(D,e,data,N,d) # <e,D>
    # this function should compute:
        # D .= data'e
    # as efficiently as possible.

    zz = zero(eltype(D))
    @tturbo for i in 1:N # This loop eats 1/8th of my runtime. 
        Dᵢ = zz
        for j in 1:d
            Dᵢ += e[j]*data[j,i]
        end
        D[i] = Dᵢ
    end
    return nothing
end
@inline function exp_and_mean!(slack,D,N) # 𝔼(e⁻ᴰ)
    # this function should compute: 
        # slack .= exp.(-D)
        # return mean(slack)
    # as efficiently as possible... 

    zz = zero(eltype(D))
    @tturbo for i in 1:N # This loop eats 1/8th of my runtime. 
        slack[i] = exp(-D[i])
        zz += slack[i]
    end
    return zz/N
end
@inline function prod_and_mean!(slack,D,N) # 𝔼(Dᵏe⁻ᴰ)
    # this function should compute: 
        # slack .*= D
        # return mean(slack)
    # as efficiently as possible... 
    
    zz = zero(eltype(D))
    @tturbo for i in 1:N # This loop eats 3/8th of my runtime. 
        slack[i] *= D[i]
        zz += slack[i]
    end
    return zz/N
end
@inline function turbo_aBC!(R, a, B, C)
    # this function should compute: 
        # R .= vec(2(a'B)*C)
    # as efficiently as possible.
    zz = zero(eltype(R))
    @tturbo for k in eachindex(R)
        Rₖ = zz
        for j in 1:size(B,2)
            tⱼ = zz
            for i in 1:length(a)
                tⱼ += a[i]*B[i,j]
            end
            Rₖ += tⱼ*C[j,k]
        end
        R[k] = 2Rₖ
    end
end
#########################################################################################################
####### 3) The three core functions that make up the rest of the runtime
#########################################################################################################
function build_model!(s,data)
    # Perform all the computations that depend on e
    # and that do not require the theta parameters yet. 
    s.e ./= sqrt(s.e's.std_data*s.e)
    project!(s.D,s.e,data,s.N,s.d)
    s.η[1] = exp_and_mean!(s.slack,s.D,s.N) # this fills up the slack variable. 
    Cᵧ = 1/s.η[1]
    s.τₑ[1] = log(s.η[1])
    for k in 1:s.m # order does matter.
        s.η[k+1] = prod_and_mean!(s.slack,s.D,s.N) * Cᵧ * s.inv_facs[k+1] # this fills up the slack variable. 
        s.τₑ[k+1] = k*s.η[k+1]
        for j in 1:(k-1)  # order does not matter. 
            s.τₑ[k+1] -= s.τₑ[j+1] * s.η[k+1-j]
        end
    end
    jac!(s)
    return nothing
end
function jac!(s) # much faster than ForwardDiff.
    # this function should compute:
        # ForwardDiff.jacobian!(s.J,x -> TD.a_from_μ(TD.μ_from_κ(TD.κ_from_τ(x,s.P),s.P),s.P),s.τₑ)
    # as efficiently as possible.
    s.J[1,1] = s.tmp[1] = exp(s.τₑ[1])
    zz = zero(eltype(s.D))
    for k in 2:(s.m+1)  # order does matter
        #empty the storage:
        s.tmp[k] = zz
        @tturbo for p in 1:k # order does not matter...
            s.J[k,p] = zz
        end
        #compute the gradient: 
        for j in 1:k-1 # order does matter 
            Cₖⱼ = s.P.FACTS[k-j] * s.P.BINS[j, k-1]
            s.tmp[k] += Cₖⱼ * s.tmp[j] * s.τₑ[k-j+1]
            s.J[k,k-j+1] += Cₖⱼ * s.tmp[j]
            Cₖⱼ *= s.τₑ[k-j+1]
            @tturbo for p in 1:k # order does not matter
                s.J[k,p] += Cₖⱼ * s.J[j,p]
            end
        end
    end
    @tturbo s.J .= sqrt(2) .* (s.P.LAGUERRE's.J)
    return nothing
end
function loss_and_∇!(par,s)
    zz = zero(eltype(par))
    s.τ .= zz
    @turbo s.∇τ[1,:] .= zz
    
    # First part : Construct τ and ∇τ 
    for i in 1:s.n # order does not matter. 
        if par[i] > 0 # otherwise this is pointless to do. 
            α = par[i].^2
            b2e = zz
            @tturbo for j in 1:s.d # Order of instruction does not matter. 
                b2e += par[i+j*s.n]^2*s.e[j]
            end
            b2ep1 = b2e+1 # b^2, projected on e, plus one <-> b2ep1.
            nlb2ep1 = -log(b2ep1)  # negative log of <b^2,e>+1 
            θ = b2e/b2ep1
            Cᵢ = 2α / (b2ep1^2)
            s.θ_pow[1] = 1
            s.τ[1] += α*nlb2ep1
            s.∇τ[1,i] = 2par[i]*nlb2ep1
            for k in 1:s.m # order does matter
                # These four instructions are supposed to be executed in this order,
                # although the two in the middle could be extracted into their own loop
                # (so three loops in total; the first two could be @tturbo'ed).
                s.θ_pow[k+1] = s.θ_pow[k]*θ
                s.∇τ[k+1,i] = 2par[i]*s.θ_pow[k+1]
                s.τ[k+1] += α * s.θ_pow[k+1]
                s.θ_pow[k] *= k # order of instructions matters, be careful. 
            end
            b2ep1 = -b2ep1 # additions are easier than subtractions for the loop. 
            @tturbo for j in 1:s.d # order does not matter. 
                idx = i+j*s.n
                Cᵢⱼ = Cᵢ * par[idx] * s.e[j] 
                s.∇τ[1,idx] += b2ep1 * Cᵢⱼ
                for k in 1:s.m # order does not matter. 
                    s.∇τ[k+1,idx] = s.θ_pow[k] * Cᵢⱼ
                end
            end
        end
    end

    # Second part: construct the final gradient from it. 
    @tturbo s.tmp .= s.J*(s.τ-s.τₑ)
    s.loss[1] = sum(s.tmp.^2)
    # @tturbo s.∇ .= vec(2(s.tmp's.J)*s.∇τ)
    turbo_aBC!(s.∇,s.tmp,s.J,s.∇τ) 
    return nothing
end
#########################################################################################################
####### 4) Adam structure to handle stochastic gradient descent. 
#########################################################################################################
struct Adam{T}
    theta::Vector{T} # Parameter array
    m::Vector{T}     # First moment
    v::Vector{T}     # Second moment
    b1::T            # Exp. decay first moment
    b2::T            # Exp. decay second moment
    a::T             # Step size
    epsilon::T       # Epsilon for stability
    t::MVector{1,Int}           # Time step (iteration)

end
function Adam(theta::AbstractArray{T}) where T
    m       = zeros(T,size(theta))
    v       = zeros(T,size(theta))
    b1      = T(0.9)
    b2      = T(0.999)
    a       = T(0.001)
    epsilon = eps(T)
    t       = MVector{1}([0])
    Adam(theta, m, v, b1, b2, a, epsilon, t)
end
function step!(opt::Adam{T},grad) where T
    opt.t[1] += 1
    mfac = 1 / (1 - opt.b1^opt.t[1])
    vfac = 1 / (1 - opt.b2^opt.t[1])
    @turbo for i in eachindex(opt.theta)
        opt.m[i] = opt.b1 * opt.m[i] + (1-opt.b1) * grad[i]
        opt.v[i] = opt.b2 * opt.v[i] + (1-opt.b2) * grad[i]^2
        opt.theta[i] -= opt.a * (opt.m[i] * mfac) / sqrt(opt.v[i]*vfac + opt.epsilon)
    end
    return nothing
end
#########################################################################################################
####### 5) Problem structure and final fitting functions
#########################################################################################################
function build_p0(n,d)
    return randn(n*(d+1))
end
struct Problem{T,M,d,MM,dd}
    s::StaticProjectionStorage{T,M,d,MM,dd}
    data::HybridMatrix{d, StaticArrays.Dynamic(), T, 2, Matrix{T}}
    opt::Adam{T}
end
function Problem(par,data,m)
    T = promote_type(eltype(par),eltype(data))
    par = T.(par)
    data = T.(data)
    d,N = size(data)
    l = length(par)
    @assert l % (d+1) == 0
    n = Int(l//(d+1))
    data = HybridArray{Tuple{d,StaticArrays.Dynamic()}}(data)
    opt   = Adam(T.(par))
    Storage = StaticProjectionStorage(T,N,m,d,n,cov(data'))
    return Problem{T,m+1,d,(m+1)^2,d^2}(Storage,data,opt)
end
function fit!(P::Problem{T,M,d,MM,dd},epochs) where {T,M,d,MM,dd}
    for i_e = 1:epochs
        Random.rand!(P.s.e)          # Choose a random direction of projection. 
        build_model!(P.s,P.data)     # Construct the projected model (precomputations for the gradient)
        loss_and_∇!(P.opt.theta,P.s) # Evaluate the gradient and store it in P.s.∇
        step!(P.opt,P.s.∇)           # Perform the gradient descent step.
    end
    return nothing
end
function extract_rez(P::Problem{T,M,d,MM,dd}) where {T,M,d,MM,dd}
    n = Int(length(P.opt.theta)//(d+1))
    α = P.opt.theta[1:n].^2
    θ = reshape(P.opt.theta[(n+1):end].^2,(n,d))
    return TD.MultivariateGammaConvolution(α,θ)
end
#########################################################################################################
####### 6) Direct application. 
#########################################################################################################
# Construct a mockup dataset and launch the algorithm : 
d,N = 15, 10000;
Random.seed!(13);
data = reshape(exp.(randn(d*N)),(d,N)); # Of course, this is mock data. 
p0 = randn(100*(d+1));
P = Problem(p0,data,20);
fit!(P,10000);
@info "Loss: $(P.s.loss[1])"
rez = extract_rez(P);
#########################################################################################################
####### 7) Simple Profiling. 
######################################################################################################### 
using BenchmarkTools
@btime fit!(P,1) # 60μs and 2 allocations per step. 

using ProfileView
ProfileView.@profview fit!(P,100_000) # launch it twice and discard the first run. 

That sounds suspicious. With reverse-mode / adjoint differentiation (either automatic or manual), you should always be able to compute the gradient with a computational cost comparable to that of the loss function. (Essentially, make sure you are evaluating the chain rule from left to right and not from right to left.)
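For concreteness, a toy sketch of what the evaluation order means here (shapes roughly like yours, not your actual code):

M, K = 21, 1600
J = randn(M, M); ∇τ = randn(M, K); r = randn(M)
g_left  = 2 .* vec(((J*r)' * J) * ∇τ)   # left-to-right: only vector–matrix products, O(M² + M·K) flops
g_right = 2 .* vec((J*r)' * (J * ∇τ))   # right-to-left: forms the M×K product J*∇τ first, O(M²·K) flops
g_left ≈ g_right                         # same gradient, very different cost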

Not sure why you are using stochastic gradient descent, as opposed to a deterministic minimization algorithm, unless you are sampling a huge dataset in random batches during fitting?

In fact I compute both of them at the same time, so this timing includes both. I think I'm indeed doing it in the right order, but I'll read up on the reference, thanks.

Well, I have a stochastic component in my loss (the loss is written as an expectation over the random variable s.e, which the fit!() function re-randomizes at each iteration), so this is indeed a stochastic gradient descent.
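Schematically, with the notation of the code above, the quantity I am minimizing is

L(θ) = 𝔼ₑ[ ‖ J(e) · (τ(θ, e) − τₑ(e)) ‖² ]

and each pass of the main loop draws a fresh direction e, evaluates the gradient of the inner term with respect to θ, and hands it to Adam.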