Suggestions needed: speed up optimization

I need to maximize a likelihood function with 18 parameters on a large sample of 1 million observations. The objective function is defined as

function loglike(θ, x, y, R, c, d, p0, B)
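    # note: r is used below but is never passed in; it is read from global scope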
    prob = [choice_prob_naive(θ, r[i], x[i,:], y[i,:], R[i], c[i], d[i], p0[i])[1] for i = 1:size(B,1)]
    negloglike = - (B'*log.(prob) + (1 .- B)'*log.(1 .- prob))
    return negloglike
end

where each element of prob is the choice probability of an observation given the parameter values in θ, computed via numerical integration. The optimization call I use is

optimize(θ -> loglike(θ, x, y, R, c, d, p0, B), theta_initial, BFGS(), Optim.Options(g_tol = 1e-3, iterations=100_000, show_trace=true, show_every=5))

When I ran this program on just 1% of my sample, it took more than an hour and I got the following:

Iter     Function value   Gradient norm 
     0     3.056383e+03     3.681269e+05
 * time: 0.0
     5     2.445321e+03     3.724024e+05
 * time: 1194.0
    10     2.232518e+03     2.865856e+05
 * time: 1794.3980000019073
    15     2.060120e+03     1.151448e+05
 * time: 2050.8810000419617
    20     1.968748e+03     3.514386e+04
 * time: 2198.8180000782013
    25     1.962631e+03     5.452261e+03
 * time: 2389.8360002040863
    30     1.961778e+03     2.888854e+03
 * time: 2610.739000082016
    35     1.961533e+03     3.544448e+03
 * time: 2861.231000185013
    40     1.960887e+03     2.669545e+01
 * time: 3069.6550002098083
    45     1.960862e+03     2.172777e+02
 * time: 3338.9970002174377
    50     1.960149e+03     1.045716e+03
 * time: 3533.6230001449585
    55     1.959932e+03     7.975524e+02
 * time: 3687.270000219345
    60     1.959879e+03     1.252292e+02
 * time: 3870.1520001888275
    65     1.959872e+03     1.338831e+01
 * time: 4081.643000125885
    70     1.959872e+03     2.457941e-02
 * time: 4305.37700009346
4323.804038 seconds (153.88 G allocations: 2.822 TiB, 9.57% gc time)
 * Status: failure (objective increased between iterations) (line search failed)
 * Candidate solution
    Minimizer: [-1.15e-03, -1.24e-01, -1.06e+00,  ...]
    Minimum:   1.959872e+03

 * Found with
    Algorithm:     BFGS
    Initial Point: [1.00e-03, 2.00e-02, 1.00e-01,  ...]

 * Convergence measures
    |x - x'|               = 9.18e-04 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.11e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 8.07e-10 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 4.12e-13 ≰ 0.0e+00
    |g(x)|                 = 2.46e-02 ≰ 1.0e-05

 * Work counters
    Seconds run:   4305  (vs limit Inf)
    Iterations:    70
    f(x) calls:    266
    ∇f(x) calls:   266

This result looks intimidating, and I don’t even want to try the full sample with the same code. Could you share your suggestions for how to speed up the optimization enough to make it feasible? (e.g., the choice of optimization algorithm, parallel computing, using a GPU, buying a better CPU, using the university’s supercomputer, etc.)

How much faster is it if you have

negloglike = .- (B'*log.(prob) .+ (1 .- B)'*log.(1 .- prob))

The extra two dots will prevent two copies of your data. It may also be worth writing this as

logp = log.(prob)
negloglike = .- (B'*logp .+ (1 .- B)'*(1 ./ logp))

This saves calculating a lot of logs, at the expense of some memory.

Thanks! Why does adding dots to scalar operations speed things up?

For the last equation, why 1 ./ logp?

Wow, never mind on both of those. My brain is not working. I forgot that B'x is a scalar, and how log rules work.
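For the record, since B'x is a scalar, the extra dots would not have helped anyway. A version that avoids the temporary arrays entirely could look like this (a sketch, assuming B holds 0/1 indicators):

negloglike = -sum(b * log(p) + (1 - b) * log1p(-p) for (b, p) in zip(B, prob))

log1p(-p) is also more accurate than log(1 - p) when p is close to 0.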

Can you use automatic differentiation?

Does this require the objective function to be differentiable? In my case, the choice_prob_naive function is not differentiable and that’s why I had to use numerical integration.

Check the automatic differentiation section of the Optim.jl docs.

Basically, just add autodiff=:forward to your call to optimize.
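For example (a sketch based on the call above, with forward-mode AD in place of finite differences):

optimize(θ -> loglike(θ, x, y, R, c, d, p0, B), theta_initial, BFGS(),
         Optim.Options(g_tol = 1e-3, iterations = 100_000, show_trace = true);
         autodiff = :forward)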

Also, could there be unexploited optimizations to the function choice_prob_naive?
Something that looks suspicious is that you have to extract the first element of its return value.

If I’m not mistaken, you are implicitly using finite differences, and I’m not sure that works with non-differentiable functions. Have you tried a derivative-free method?
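For example, Optim's Nelder-Mead needs no gradients at all (a sketch reusing the earlier call; convergence is slower per iteration, but there are no gradient evaluations):

optimize(θ -> loglike(θ, x, y, R, c, d, p0, B), theta_initial, NelderMead(),
         Optim.Options(iterations = 100_000, show_trace = true))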

No, I haven’t tried other methods yet. I am quite uneducated about numerical methods.

The reason I extract the first element of the value returned by choice_prob_naive is the last line in the function:

(expectation, err) = quadgk(ϵ -> integrand(ϵ), -4, 4, rtol = 0.001)
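quadgk returns a tuple (integral estimate, error bound), so only the first element is the probability I need. Equivalently:

expectation = first(quadgk(integrand, -4, 4, rtol = 0.001))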

I’d suggest providing an MWE; that would help.

This should be close to an MWE:

using Optim, QuadGK, StatsFuns   # for optimize and quadgk; normcdf is assumed to come from StatsFuns

function choice_prob_naive(θ, r, x, y, R, c, d, p0)
    α = θ[1]
    βX = θ[2:5]
    βY = θ[6:10]
    γ0 = θ[11]
    γ1 = θ[12:16]
    sqrtτ = θ[17]
    logκ = θ[18]
    s(ϵ1) = d + 1/sqrtτ * ϵ1                                 
    p(ϵ1) = p0/(p0 + (1 - p0) * exp((0.5 - s(ϵ1))*sqrtτ^2))           
    Eu(ϵ1)  = α*r + x'*βX + y'*βY + 0.05*α*R*r - ((γ0 + y'*γ1)*(1 - c) + 0.05*α*(1 + r - c))*p(ϵ1) - (logκ - 1)        # risk-retention status included 
    integrand(ϵ1) = normcdf(0, 1, Eu(ϵ1))*(1/sqrt(2*π)*exp(-(ϵ1^2)/2))   # weighted by the standard normal pdf
    (expectation, err) = quadgk(ϵ -> integrand(ϵ), -4, 4, rtol = 0.001)
    return expectation
end

function loglike(θ, x, y, R, c, d, p0, B)
    prob = zeros(size(B,1))
    Threads.@threads for i = 1:size(B,1)    # parallelizing this loop: 2x faster
        prob[i] = choice_prob_naive(θ, r[i], x[i,:], y[i,:], R[i], c[i], d[i], p0[i])
    end
    negloglike = - (B'*log.(prob) + (1 .- B)'*log.(1 .- prob))
    return negloglike
end

alpha = 0.001;
betaX = [0.02; 0.1;0.00001;0.02];
betaY = [0.00005; -0.01; -0.1; 0.25; 0.11] 
gamma0 = 0.05;
gamma1 = [0.1; -0.05; -0.0; -0.1; -0.1];
tau = 1;
kappa = 2;
theta = [alpha; betaX; betaY; gamma0; gamma1; tau; kappa];

@time theta_naive = optimize(θ -> loglike(θ, x, y, R, c, d, p0, B), theta, BFGS(), Optim.Options(g_tol = 1e-2, iterations=100_000, show_trace=true, show_every=1))

A 100-observation sample can be loaded from the string below:

io = IOBuffer()
string_representation = String(take!(CSV.write(io, convert(DataFrame, data_df))))
"x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16\n750.0,4.5,0.0,460.0,6.0,187.56826015,10.0,1.0,0.14,0.06,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,183.35580705,10.0,1.0,0.15,0.12,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,236.26641334,3.0,1.0,0.2,0.18,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,40.86294988,11.0,1.0,0.17,0.26,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,511.1,4.0,0.0,0.09,0.0,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,555.5,2.0,0.0,0.04,0.01,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,201.18404215,10.0,1.0,0.1,0.0,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,137.13511872,10.0,1.0,0.08,0.07,0.0,0.5,0.0,1.0,0.4113,0.0\n750.0,4.5,0.0,460.0,6.0,90.0,4.0,0.0,0.06,0.0,0.0,0.5,0.23023789,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,513.4,2.0,0.0,0.04,0.0,0.0,0.5,0.24936197,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,510.5,2.0,0.0,0.03,0.0,0.0,0.5,0.2496394,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,512.8,2.0,0.0,0.03,0.0,0.0,0.5,0.24965397,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,515.7,2.0,0.0,0.03,0.0,0.0,0.5,0.25004764,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,559.9,3.0,0.0,0.05,0.0,0.0,0.5,0.27539542,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,130.5,3.0,0.0,0.06,0.0,0.0,0.5,0.34763562,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,828.59,3.0,0.0,0.06,0.01,0.0,0.5,0.40125603,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,521.0,4.0,0.0,0.06,0.0,0.0,0.5,0.41493814,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,723.15,3.0,0.0,0.06,0.01,0.0,0.5,0.47624594,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,415.0,3.0,0.0,0.05,0.01,0.0,0.5,0.5,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,516.6,4.0,0.0,0.07,0.0,0.0,0.5,0.55668277,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,304.75332932,2.0,0.0,0.01,0.0,0.0,0.5,0.58064516,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,307.0,2.0,0.0,0.01,0.0,0.0,0.5,0.58064516,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,302.5,1.0,0.0,0.0,0.0,0.0,0.5,0.58064516,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,362.5,3.0,0.0,0.03,0.0,0.0,0.5,0.67741935,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,614.15,2.0,0.0,0.04,0.0,0.0,0.5,0.83622274,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,407.9,1.0,0.0,0.03,0.0,0.0,0.5,1.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,409.0,2.0,0.0,0.07,0.0,0.0,0.5,1.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,333.75,3.0,0.0,0.1,0.01,0.0,0.5,1.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,309.9,3.0,0.0,0.09,0.01,0.0,0.5,1.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,626.5,3.0,0.0,0.07,0.01,0.0,0.5,1.1440563,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,402.15,3.0,0.0,0.05,0.01,0.0,0.5,1.33333333,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,406.16666669,3.0,1.0,0.1,0.01,0.0,0.5,1.33333333,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,410.18333334,3.0,0.0,0.08,0.01,0.0,0.5,1.33333334,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,340.1,4.0,0.0,0.07,0.0,0.0,0.5,2.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,396.6,2.0,0.0,0.0,0.0,0.0,0.5,2.1,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,44.1,2.0,0.0,0.08,0.0,0.0,0.5,3.0,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,586.5,0.0,0.0,0.0,0.0,1.0,0.5,3.1625,1.0,0.4113,1.0\n750.0,4.5,0.0,460.0,6.0,415.25,3.0,0.0,0.05,0.0,0.0,0.5,5.0,1.0,0.4113,1.0\n275.0,8.0,0.0,300.0,5.75,773.0,3.0,0.0,0.08,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,618.75,3.0,0.0,0.07,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,133.55,3.0,0.0,0.09,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,723.15,3.0,0.0,0.06,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,512.27,3.0,0.0,0.03,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,733.6,3.0,
0.0,0.04,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,414.25,3.0,0.0,0.11,0.03,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,552.5,3.0,0.0,0.09,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,512.341,3.0,0.0,0.08,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,520.78,3.0,0.0,0.08,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,77.0,3.0,0.0,0.06,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,414.0,3.0,0.0,0.11,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,410.18333334,3.0,0.0,0.08,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,102.0,3.0,0.0,0.06,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,815.8,3.0,0.0,0.06,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,513.5,3.0,0.0,0.04,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,669.75,3.0,0.0,0.09,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,563.0,3.0,0.0,0.06,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,566.25,3.0,0.0,0.09,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,464.75,3.0,0.0,0.07,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,109.62853288,10.0,1.0,0.41,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,347.33289865,11.0,1.0,0.1,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,90.75260187,11.0,1.0,0.02,0.15,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,60.0,12.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,271.7199893,11.0,1.0,0.06,0.07,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,201.38320708,11.0,1.0,0.11,0.21,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,53.2,10.0,1.0,0.65,0.2,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,112.06261019,11.0,1.0,0.2,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,109.31681943,12.0,1.0,0.14,0.67,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,393.2905694,10.0,1.0,0.06,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,158.7234222,11.0,1.0,0.32,0.08,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,293.79396572,10.0,1.0,0.09,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,318.01622125,10.0,1.0,0.13,0.08,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,34.5,10.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,39.9,11.0,1.0,0.33,0.19,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,134.55481151,10.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,31.0,11.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,65.0,11.0,1.0,0.0,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,19.04811772,11.0,1.0,0.05,0.92,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,517.125,3.0,0.0,0.08,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,167.81082266,10.0,1.0,0.69,0.31,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,112.28100151,12.0,1.0,0.23,0.6,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,40.86294988,11.0,1.0,0.17,0.26,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,185.77341859,10.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,48.1611834,11.0,1.0,0.08,0.21,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,262.3,10.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,98.38058152,11.0,1.0,0.06,0.06,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,260.18489297,10.0,1.0,0.19,0.13,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,50.0,11.0,1.0,1.0,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.
75,134.0700113,11.0,1.0,0.43,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,300.6707341,10.0,1.0,0.09,0.21,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,107.37034461,10.0,1.0,0.15,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,33.25,11.0,1.0,0.0,1.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,60.35539681,10.0,1.0,0.2,0.04,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,35.97804004,10.0,1.0,0.41,0.09,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,48.0,9.0,1.0,0.0,0.65,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,37.10459332,11.0,1.0,0.86,0.14,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,599.37,3.0,0.0,0.09,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,264.66931206,10.0,1.0,0.07,0.01,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,285.74218908,10.0,1.0,0.08,0.02,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,15.93387704,5.0,1.0,1.0,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n275.0,8.0,0.0,300.0,5.75,52.07135004,11.0,1.0,0.14,0.0,0.0,0.55,0.0,0.0,0.0511,0.0\n"

using

using CSV, DataFrames
df = CSV.read(IOBuffer(string_representation), DataFrame);
r = df.x1;
x = [df.x2 df.x3 df.x4 df.x5];
y = [df.x6 df.x7 df.x8 df.x9 df.x10];
R = df.x11;
c = df.x12;
q = df.x13;
d = df.x14;
p0 = df.x15;
B = df.x16;

By the way, is it still possible to get the Hessian matrix in this case?

Replying to myself:
I made the program 10 times faster. What I have done so far:

  1. make sure functions are type stable (a quick check is sketched after this list)
  2. avoid using global variables (in the loglike function at the top, r was not an argument, so it was read from global scope)
  3. parallelize the main loop across threads
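For item 1, the quick check is Julia's built-in @code_warntype (run in the REPL; any ::Any or red Union entries in the output point to instabilities):

@code_warntype loglike(theta, r, x, y, R, c, d, p0, B)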

My current problem is that when I try to obtain the Hessian matrix using

od = OnceDifferentiable(θ -> loglike(θ, r, x, y, R, c, d, p0, B), theta_naive_estimate; autodiff = :forward);
@time theta_naive = optimize(od, theta, BFGS(), Optim.Options(g_tol = 1e-2, iterations=100_000, show_trace=true, show_every=1))

I got

ERROR: TaskFailedException:
MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#61#62",Float64},Float64,9})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  Float64(::T) where T<:Number at boot.jl:716
  Float64(::Float16) at float.jl:256

which I don’t know how to interpret…

I think it is this line: prob = zeros(size(B,1)). You are allocating a container with element type Float64, while Optim wants to do AD via ForwardDiff, which uses Dual numbers.

In this particular case I would just use map or a comprehension.
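For example (a sketch; the point is that prob's element type must follow θ instead of being hard-coded to Float64):

# Option 1: let a comprehension (or map) infer the element type
prob = [choice_prob_naive(θ, r[i], x[i,:], y[i,:], R[i], c[i], d[i], p0[i]) for i = 1:size(B,1)]

# Option 2: keep the @threads loop, but allocate with θ's element type,
# which will be ForwardDiff.Dual during AD
prob = zeros(eltype(θ), size(B,1))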

Sorry, I still don’t understand after taking a look at How ForwardDiff Works · ForwardDiff. Are you suggesting not pre-allocating prob in this case?

Yes.

If I remove that line (keeping the @threads loop that fills prob), I get

ERROR: TaskFailedException:
UndefVarError: prob not defined

Declaring argument types should do nothing for performance.

Yeah, thanks for pointing this out! So I guess the declaration forced me to be careful with input/output types and this indirectly helped performance.

I learned this fact from reading the discussion ForwardDiff: Using on functions that are defined on floats64 - #3 by IljaK91, although I haven’t figured out how to make autodiff work yet :slight_smile:

I revised the function to use a comprehension, so there is no pre-allocation anymore:

function loglike(θ::Array{T,1}, r::Array{T,1}, x::Array{T,2}, y::Array{T,2}, R::Array{T,1}, c::Array{T,1}, d::Array{T,1}, p0::Array{T,1}, B::Array{T,1}) where {T <: Real}
    prob = [choice_prob_naive(θ, r[i], x[i,:], y[i,:], R[i], c[i], d[i], p0[i]) for i = 1:size(B,1)]
    negloglike = - (B'*log.(prob) + (1.0 .- B)'*log.(1.0 .- prob)) 
    return negloglike
end

Unfortunately, I got a similar error to the one before.

I think you are now bitten by the type declarations :slightly_smiling_face:. Specifically, your data will be Float64, but the parameters (θ) will be Dual when ForwardDiff is used.
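A minimal sketch of a fix along those lines: give θ its own type parameter (or drop its annotation) so it can carry Dual numbers while the data stays Float64. The earlier Hessian question can then be answered with ForwardDiff directly, assuming quadgk and normcdf propagate Duals through choice_prob_naive:

function loglike(θ::AbstractVector{S}, r::Vector{T}, x::Matrix{T}, y::Matrix{T},
                 R::Vector{T}, c::Vector{T}, d::Vector{T}, p0::Vector{T},
                 B::Vector{T}) where {S <: Real, T <: Real}
    prob = [choice_prob_naive(θ, r[i], x[i,:], y[i,:], R[i], c[i], d[i], p0[i]) for i = 1:size(B,1)]
    return -(B'*log.(prob) + (1 .- B)'*log.(1 .- prob))
end

using ForwardDiff
H = ForwardDiff.hessian(θ -> loglike(θ, r, x, y, R, c, d, p0, B), theta_naive_estimate)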