I am doing very simple calculations in loops, and it looks like extracting slices of an array is the most expensive part. Is there an obvious way to speed it up? Here is an MWE of what I am doing. Any help would be greatly appreciated!
`Is1` and `Is2` are 4-dimensional arrays of indices, and `Pis1` and `Pis2` are arrays of probabilities, all indexed as `[iy, iP, ip, iA]`. I compute the following:
```julia
Threads.@threads for iA = 1:Nstates
    @simd for iy = 1:nY
        @simd for iP = 1:nP
            @simd for ip = 1:np
                # relevant continuation value
                I1, PI1, I2, PI2 = Is1[iy, iP, ip, iA], Pis1[iy, iP, ip, iA], Is2[iy, iP, ip, iA], Pis2[iy, iP, ip, iA]
                # argument order matches the signature compute_cont(X, I1, I2, PI1, PI2)
                Xcont = compute_cont(X, I1, I2, PI1, PI2)
                # then I use Xcont in other computations, but the line above is what is slow
            end
        end
    end
end
```
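Julia arrays are stored column-major, so `Is1[iy, iP, ip, iA]` and the other index/probability arrays are contiguous in the first index `iy`, while the loop nest above makes `ip` (the third index) the innermost loop. The sketch below uses a hypothetical helper, `visit_order`, that records the memory (linear) position of each element visited: with `iy` innermost the positions are sequential, whereas the MWE's order jumps by `nY * nP` elements per step. `Threads.@threads` can stay on the outermost `iA` loop either way. This is a general cache-locality point and is probably a smaller effect than the allocations inside `compute_cont`; note also that the Julia manual states an `@simd` loop must be an innermost loop.

```julia
# Hypothetical helper (not from the original post): walk a 4-D array indexed
# as [iy, iP, ip, iA] and record the linear (memory) position of each element
# in the order the loop nest visits it.
function visit_order(A::AbstractArray{<:Any,4}; iy_innermost::Bool = true)
    nY, nP, np, Nstates = size(A)
    L = LinearIndices(A)
    order = Int[]
    if iy_innermost
        # the last loop listed varies fastest, so iy is innermost here
        for iA in 1:Nstates, ip in 1:np, iP in 1:nP, iy in 1:nY
            push!(order, L[iy, iP, ip, iA])
        end
    else
        # the MWE's order: iA, iy, iP, ip (ip innermost)
        for iA in 1:Nstates, iy in 1:nY, iP in 1:nP, ip in 1:np
            push!(order, L[iy, iP, ip, iA])
        end
    end
    return order
end
```

For example, with `A = zeros(2, 3, 4, 5)` the `iy`-innermost order visits positions `1, 2, 3, …` in sequence, while the MWE's order goes `1, 7, 13, …`.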
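```julia
function compute_cont(X, I1, I2, PI1, PI2)
    # nY and nP are globals here; the next grid point is clamped to the last index
    I1_next, I2_next = min(nY, I1 + 1), min(nP, I2 + 1)
    # weighted average of the four neighboring slices of X
    # (the trailing `+` continues the expression onto the next line)
    output = @views X[:, :, I1, I2, :] * PI1 * PI2 + X[:, :, I1_next, I2, :] * (1 - PI1) * PI2 +
        X[:, :, I1, I2_next, :] * PI1 * (1 - PI2) + X[:, :, I1_next, I2_next, :] * (1 - PI1) * (1 - PI2)
    return output
end
```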
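In `compute_cont`, `@views` avoids copying the slices, but each `*` and `+` still builds a new temporary array, so every call allocates several full-size intermediate arrays; that is a likely source of the cost. A minimal sketch of an in-place variant follows. It fuses all the arithmetic into one broadcast that writes into a preallocated buffer. The names `compute_cont!` and `out` are hypothetical, and it assumes the grid sizes `nY` and `nP` equal `size(X, 3)` and `size(X, 4)`, which also removes the dependence on global variables.

```julia
# Hypothetical in-place variant of compute_cont: writes the weighted average
# of the four neighboring slices of X into the preallocated buffer `out`.
function compute_cont!(out, X, I1, I2, PI1, PI2)
    # assumption: the grids have sizes size(X, 3) == nY and size(X, 4) == nP
    nY, nP = size(X, 3), size(X, 4)
    I1n, I2n = min(nY, I1 + 1), min(nP, I2 + 1)
    # scalar weights, computed once per call
    w11 = PI1 * PI2
    w21 = (1 - PI1) * PI2
    w12 = PI1 * (1 - PI2)
    w22 = (1 - PI1) * (1 - PI2)
    # views share memory with X, so no slice is copied
    a = view(X, :, :, I1, I2, :)
    b = view(X, :, :, I1n, I2, :)
    c = view(X, :, :, I1, I2n, :)
    d = view(X, :, :, I1n, I2n, :)
    # @. turns this into out .= w11 .* a .+ ..., one fused loop with no temporaries
    @. out = w11 * a + w21 * b + w12 * c + w22 * d
    return out
end
```

In the loop, the buffer could be created once inside the `iA` body, e.g. `out = similar(X, size(X, 1), size(X, 2), size(X, 5))`, and reused for every `(iy, iP, ip)`. Because each `@threads` iteration owns its own buffer, nothing is shared between threads. If `Xcont` is needed after the next call overwrites `out`, it has to be copied.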