Hello, I am trying to run a large-scale maximum likelihood estimation. Luckily, I have analytical gradients and Hessians, so I can speed up the optimization with derivative-based methods.
However, my gradient and my Hessian are very slow to compute, and I would very much appreciate any tips on how to improve things on this front. See the MWE below.
The llike function has already been improved here: Improving Performance of a Loop. However, the gradient and the Hessian are still annoyingly slow.
EDIT: I was able to improve the performance of the gradient and have updated the MWE accordingly. The hess function is still very slow; my guess is that the outer product is the culprit (see the sketch after the MWE for the fix I am considering).
using Random, Optim
using LinearAlgebra
using BenchmarkTools
###############################################################
# Loglikelihood
###############################################################
const J = 75
const Ι = 200   # Ι is the Greek capital iota, not LinearAlgebra's identity I used below
const N = 10000
const Mat1 = ones(Int8,Ι,N)
const Mat2 = ones(Int8,J,N)
const Mat3 = rand(Ι,J,N) .>0.9
const NumberParams = Ι
const τ = rand(Ι,J)
# Independent-variables matrix: xvec[:, i, j] stacks ∂log_vals[i, j]/∂θ = τ[i, j]
# on top of ∂log_vals[i, j]/∂γ (identity rows; the last row is dropped because γ[Ι] ≡ 1)
const XMat = let
    X = zeros(Ι + 1, Ι, J)
    for j = 1:J
        X[:, :, j] = vcat(τ[:, j]', Matrix(I, Ι, Ι))
    end
    X[1:end-1, :, :]
end
@views function Log_Likelihood(N, Ι, J, τ, Mat1, Mat2, Mat3, params)
    llike = 0.0
    θ = params[1]
    γ = vcat(params[2:end], 1.0)        # γ[Ι] is normalized to 1
    log_vals = θ .* τ .+ γ              # sign of θ must be consistent with XMat, grad and hess
    vals = exp.(log_vals)
    cond = sum(Mat1, dims = 1)'
    @fastmath @inbounds for n = 1:N
        if cond[n] > 1
            for j = 1:J
                if Mat2[j, n] == 1
                    denom_index = Mat1[:, n]' * vals[:, j]
                    for i = 1:Ι
                        llike += Mat3[i, j, n] * Mat2[j, n] * Mat1[i, n] * (log_vals[i, j] - log(denom_index))
                    end
                end
            end
        end
    end
    return -llike
end
@views function grad(N, Ι, J, τ, Mat1, Mat2, Mat3, xvec, storage, params)
    θ = params[1]
    γ = vcat(params[2:end], 1.0)
    g = zeros(length(params))           # named g because `vec` shadows Base.vec
    log_vals = θ .* τ .+ γ
    vals = exp.(log_vals)
    cond = sum(Mat1, dims = 1)'
    @fastmath @inbounds for n = 1:N
        if cond[n] > 1
            for j = 1:J
                if Mat2[j, n] == 1
                    denom_index = Mat1[:, n]' * vals[:, j]
                    for i = 1:Ι
                        # in-place accumulation: .+= avoids allocating a new vector per iteration
                        g .+= Mat2[j, n] * Mat1[i, n] * (Mat3[i, j, n] - vals[i, j] / denom_index) .* xvec[:, i, j]
                    end
                end
            end
        end
    end
    storage .= -g
    return storage
end
@views function hess(N, Ι, J, τ, Mat1, Mat2, Mat3, xvec, storage, params)
    θ = params[1]
    γ = vcat(params[2:end], 1.0)
    mat = zeros(length(params), length(params))
    log_vals = θ .* τ .+ γ
    vals = exp.(log_vals)
    cond = sum(Mat1, dims = 1)'
    pvals = zeros(Ι, J, N)              # Ι×J×N of Float64 ≈ 1.2 GB at these sizes
    # first pass: choice probabilities p[i, j, n]
    @fastmath @inbounds for n = 1:N
        if cond[n] > 1
            for j = 1:J
                if Mat2[j, n] == 1
                    denom_index = Mat1[:, n]' * vals[:, j]
                    for i = 1:Ι
                        pvals[i, j, n] = Mat2[j, n] * Mat1[i, n] * vals[i, j] / denom_index
                    end
                end
            end
        end
    end
    # second pass: Hessian = Σ p_i (x_i - x̄)(x_i - x̄)'; this outer product is the slow part
    @fastmath @inbounds for n = 1:N
        if cond[n] > 1
            for j = 1:J
                if Mat2[j, n] == 1
                    xbar = xvec[:, :, j] * pvals[:, j, n]
                    for i = 1:Ι
                        mat += Mat2[j, n] * Mat1[i, n] * pvals[i, j, n] .* (xvec[:, i, j] - xbar) * (xvec[:, i, j] - xbar)'
                    end
                end
            end
        end
    end
    storage .= mat
    return storage
end
f(x) = Log_Likelihood(N,Ι,J,τ,Mat1,Mat2,Mat3,x)
g!(storage,x) = grad(N,Ι,J,τ,Mat1,Mat2,Mat3,XMat,storage,x)
h!(storage,x) = hess(N,Ι,J,τ,Mat1,Mat2,Mat3,XMat,storage,x)
XXX = @time Log_Likelihood(N, Ι, J, τ, Mat1, Mat2, Mat3, vcat(7.0, ones(NumberParams - 1)))
XXX = @time grad(N, Ι, J, τ, Mat1, Mat2, Mat3, XMat, zeros(NumberParams), vcat(7.0, ones(NumberParams - 1)))
XXX = @time hess(N, Ι, J, τ, Mat1, Mat2, Mat3, XMat, zeros(NumberParams, NumberParams), vcat(7.0, ones(NumberParams - 1)))
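Side note on the timings above: since BenchmarkTools is already loaded, @btime with $-interpolated arguments should give steadier numbers than a single @time run, e.g.:

@btime hess($N, $Ι, $J, $τ, $Mat1, $Mat2, $Mat3, $XMat,
    $(zeros(NumberParams, NumberParams)), $(vcat(7.0, ones(NumberParams - 1))))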
func = TwiceDifferentiable(vars -> Log_Likelihood(N, Ι, J, τ, Mat1, Mat2, Mat3, vars),
    vcat(7.0, ones(NumberParams - 1)); autodiff = :forward);
opt1 = @time Optim.optimize(func, vcat(7.0, ones(NumberParams - 1)), show_trace = true)
opt1 = @time Optim.optimize(func, ones(NumberParams), show_trace = true)
opt_explicit = Optim.optimize(f, g!, h!, ones(NumberParams), Optim.Newton(),
    Optim.Options(show_trace = true, iterations = 10))
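For reference, this is the direction I am considering for hess: the mat += ... line allocates two temporaries on every inner iteration (the difference vector and the Ι×Ι outer product). Below is a minimal sketch of just the second loop nest, using a preallocated difference buffer and an in-place rank-1 update via LinearAlgebra.BLAS.ger! (which computes A := α*x*y' + A). The name hess_inner! and its signature are hypothetical, just an extraction of the loop above, not part of the MWE:

using LinearAlgebra: BLAS

# Hypothetical extraction of the second loop nest of hess: accumulate the
# Hessian with an in-place rank-1 update instead of building an Ι×Ι outer
# product on every iteration.
function hess_inner!(mat, xvec, pvals, Mat1, Mat2, cond, Ι, J, N)
    d = zeros(size(mat, 1))                        # reusable buffer for x_i - x̄
    @inbounds for n = 1:N
        if cond[n] > 1
            for j = 1:J
                if Mat2[j, n] == 1
                    xbar = @views xvec[:, :, j] * pvals[:, j, n]   # still allocates once per (j, n)
                    for i = 1:Ι
                        w = Mat2[j, n] * Mat1[i, n] * pvals[i, j, n]
                        @views d .= xvec[:, i, j] .- xbar          # no allocation
                        BLAS.ger!(w, d, d, mat)                    # mat .+= w * d * d' in place
                    end
                end
            end
        end
    end
    return mat
end

Since the Hessian is symmetric, BLAS.syr! (which updates only one triangle) should roughly halve that work again, and xbar could be written into a second preallocated buffer with mul!. Is this the right track, or is there a better way to attack the outer product?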