I was using NLsolve and couldn’t get type stability in my model no matter what I tried. I’m new to Julia, so I assumed I wasn’t fully understanding how type stability works. But I think this is part of a bug that dates back to at least 2016 (see #15276).
I just want to confirm it, since it essentially means that the order and placement of function definitions are crucial for type stability.
If it’s not a bug, I’d appreciate it if you could explain how I should think about it (in particular about CASE 1 and CASE 2 below, which completely puzzle me!!!).
I read the documentation, but in my examples I’m only changing the order of the functions. Many thanks!!!
I start with the general issue and then show its consequences for NLsolve.
All cases produce the same results (tested with Julia 1.8 and NLsolve v4.5.1).
```julia
using BenchmarkTools

### NESTED FUNCTIONS - all cases produce the same result

# CASE 1 - type unstable and hence slow
function result1(zz::Int64)
    function intermediate_result1(xx::Int64)
        yy = xx + 1
        ff1(yy)
    end
    ff1(yy) = yy + 2
    intermediate_result1(zz)
end
@code_warntype result1(2)
@btime result1(2)
```
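To see where the instability in CASE 1 comes from, the captured variables of an inner function can be inspected directly, since they become fields of the closure's type. Below is a minimal sketch with hypothetical names (`make_inner_late`, `inner`, `helper`); my assumption, based on #15276, is that a local function bound *after* the closure that calls it ends up stored in an untyped `Core.Box`.

```julia
# Hypothetical reduction of CASE 1: the local function `helper`
# is defined after the closure `inner` that calls it.
function make_inner_late()
    inner(x::Int) = helper(x + 1)   # captures `helper` before it is assigned
    helper(y) = y + 2
    return inner
end

inner_late = make_inner_late()
# Captured variables are stored as fields of the closure's type.
println(Core.Box in fieldtypes(typeof(inner_late)))   # true if `helper` was boxed
println(inner_late(2))                                # 5
```

Reading from a `Core.Box` gives a value of type `Any`, which would explain why `@code_warntype result1(2)` cannot infer a concrete return type.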
```julia
# CASE 2 - type stable and hence faster
# just changing the order of the functions (weird!!!)
function result2(zz::Int64)
    ff2(yy) = yy + 2
    function intermediate_result2(xx::Int64)
        yy = xx + 1
        ff2(yy)
    end
    intermediate_result2(zz)
end
@code_warntype result2(2)
@btime result2(2)
```
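For comparison, here is a sketch of the CASE 2 ordering that asks inference for the return type instead of reading `@code_warntype` output. `result_early` and `helper` are hypothetical names, and `Base.return_types` is an unexported, internal helper, so I'd treat this as a quick check rather than a stable API.

```julia
# Hypothetical reduction of CASE 2: `helper` is bound before the closure `inner`
# captures it, so it can be stored with its concrete (singleton) type.
function result_early(z::Int)
    helper(y) = y + 2
    inner(x::Int) = helper(x + 1)
    return inner(z)
end

# A concrete Int64 here means inference could follow the whole call chain.
println(Base.return_types(result_early, (Int,)))
println(result_early(2))   # 5
```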
```julia
# CASE 3 - type stable and hence similar to CASE 2
# putting function ff3 outside (at global scope)
ff3(yy) = yy + 2
function result3(zz::Int64)
    function intermediate_result3(xx::Int64)
        yy = xx + 1
        ff3(yy)
    end
    intermediate_result3(zz)
end
@code_warntype result3(2)
@btime result3(2)
```
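A sketch of why CASE 3 behaves differently, using hypothetical names (`global_helper`, `make_inner_global`): a function defined at global scope is referenced by its global name, so, assuming my reading of how closures are lowered is right, the inner function captures nothing at all.

```julia
# Hypothetical reduction of CASE 3: `global_helper` is a global function binding,
# so the inner function refers to it by name instead of capturing it.
global_helper(y) = y + 2

function make_inner_global()
    inner(x::Int) = global_helper(x + 1)
    return inner
end

inner_global = make_inner_global()
println(fieldcount(typeof(inner_global)))   # 0: nothing is captured
println(inner_global(2))                    # 5
```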
```julia
# CASE 3b - type stable and hence similar to CASE 2
# defining the function afterwards (at global scope) behaves like CASE 3
function result3b(zz::Int64)
    function intermediate_result3b(xx::Int64)
        yy = xx + 1
        ff3b(yy)
    end
    intermediate_result3b(zz)
end
ff3b(yy) = yy + 2
@code_warntype result3b(2)
@btime result3b(2)
```
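CASE 3b seems to work for the same reason: a global name is looked up when the method is inferred and compiled for an actual call, not when the enclosing function is written. A minimal sketch with hypothetical names (`result_later`, `later_helper`):

```julia
# Hypothetical reduction of CASE 3b: `result_later` is written before
# `later_helper` exists; it only needs the global binding at call time.
function result_later(z::Int)
    inner(x::Int) = later_helper(x + 1)   # `later_helper` is a global name here
    return inner(z)
end

later_helper(y) = y + 2   # defined after `result_later`, at global scope

println(Base.return_types(result_later, (Int,)))   # a concrete type is expected
println(result_later(2))                           # 5
```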
```julia
# CASE 4 - type stable and hence similar to CASE 2
# passing ff4 as an argument of the inner function
function result4(zz::Int64)
    function intermediate_result4(ff4, xx::Int64)
        yy = xx + 1
        ff4(yy)
    end
    ff4(yy) = yy + 2
    intermediate_result4(ff4, zz)
end
@code_warntype result4(2)
@btime result4(2)
```
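A sketch of the CASE 4 idea, with hypothetical names (`result_arg`, `inner`, `helper`): since the helper reaches the inner function as an argument rather than through a captured variable, the order of the two definitions should no longer matter for inference.

```julia
# Hypothetical reduction of CASE 4: the helper is passed as an argument,
# so `inner` captures nothing and the call is inferred with typeof(f).
function result_arg(z::Int)
    inner(f, x::Int) = f(x + 1)   # `f` is an argument, not a captured variable
    helper(y) = y + 2             # defined after `inner`, but never captured by it
    return inner(helper, z)
end

println(Base.return_types(result_arg, (Int,)))
println(result_arg(2))   # 5
```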
What confuses me even more is that if I wrap CASE 3 in a function `execute()` and then run it, it becomes even slower than CASE 1 (while CASES 2 and 4 are not affected).
As for NLsolve: in my model I have to solve several nested systems of equations. The model is quite long and hard to reproduce, so I provide an extremely simple example.
I even added specific numbers, so you can see the problem more clearly. If it’s not a bug, I think it would be really useful to document this behavior in NLsolve.
```julia
using NLsolve, BenchmarkTools

# NLsolve - all cases produce the same result

## CASE 1 - type unstable and hence slow
function ff1(xx2)
    function solving1!(res, sol; xx2=xx2)
        yy1 = sol[1:2]
        yy2 = sol[3:4]
        res[1:2] = yy1 .- xx2
        res[3:4] = yy2 .- xx2
        return res
    end
    sol0 = [[0.99, 0.99]; [0.99, 0.99]]
    solution = nlsolve(solving1!, sol0)
    yy1 = solution.zero[1:2]
    yy2 = solution.zero[3:4]
    (yy1, yy2)
end
@code_warntype ff1([2.0, 3.0])
@btime ff1([2.0, 3.0])
```
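My guess (an assumption, not something I have confirmed in NLsolve's code) is that NLsolve is not the culprit here: `yy1` and `yy2` are assigned both inside `solving1!` and in the body of `ff1`, so under Julia's scoping rules the closure's `yy1`/`yy2` are the *same* variables as the outer ones and have to be shared through a box. The sketch below reproduces this without NLsolve (hypothetical names `shared_name`, `separate_name`, `residual`) and shows that declaring the variable `local` inside the closure is one way to break the sharing.

```julia
# Hypothetical reduction of NLsolve CASE 1 (no NLsolve needed).
# `y` is assigned inside `residual` AND in the body of `shared_name`,
# so both refer to one shared variable, which Julia stores in a Core.Box.
function shared_name(x::Vector{Float64})
    function residual(v)
        y = v[1:2]            # assigns the *outer* `y`
        return y .- x
    end
    r = residual([1.0, 2.0, 3.0])
    y = r .+ 1                # this assignment makes `y` a local of `shared_name`
    return y
end

# Same code, but `local` gives the closure its own private `y`.
function separate_name(x::Vector{Float64})
    function residual(v)
        local y = v[1:2]      # new variable, not shared with the outer `y`
        return y .- x
    end
    r = residual([1.0, 2.0, 3.0])
    y = r .+ 1
    return y
end

println(Base.return_types(shared_name, (Vector{Float64},)))    # expected: not concrete
println(Base.return_types(separate_name, (Vector{Float64},)))  # expected: concrete
println(shared_name([2.0, 3.0]) == separate_name([2.0, 3.0]))  # true
```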
```julia
## CASE 2 - type stable and hence faster
# we only add `let` and `end` (or wrap the second part in a function)
function ff2(xx2)
    function solving2!(res, sol; xx2=xx2)
        yy1 = sol[1:2]
        yy2 = sol[3:4]
        res[1:2] = yy1 .- xx2
        res[3:4] = yy2 .- xx2
        return res
    end
    let # or `function execute()`
        sol0 = [[0.99, 0.99]; [0.99, 0.99]]
        solution = nlsolve(solving2!, sol0)
        yy1 = solution.zero[1:2]
        yy2 = solution.zero[3:4]
        (yy1, yy2)
    end
end
@code_warntype ff2([2.0, 3.0])
@btime ff2([2.0, 3.0])
```
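If that guess is right, CASE 2 works because the outer assignments move into a `let` block, so `yy1`/`yy2` are no longer locals of the function body and the closure keeps its own variables. A reduced sketch with hypothetical names (`with_let`, `residual`):

```julia
# Hypothetical reduction of NLsolve CASE 2: the outer assignments live in a
# `let` block, a sibling scope of `residual`, so `y` inside the closure
# is private to the closure and nothing needs to be boxed.
function with_let(x::Vector{Float64})
    function residual(v)
        y = v[1:2]
        return y .- x
    end
    let
        r = residual([1.0, 2.0, 3.0])
        y = r .+ 1            # local to the `let` block only
        y
    end
end

println(Base.return_types(with_let, (Vector{Float64},)))   # expected: concrete
println(with_let([2.0, 3.0]))
```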
```julia
# CASE 3 - type stable and hence similar to CASE 2
# we define `solving3!` in a separate function
function input(xx2)
    # note: `parameter` is unused; `xx2` is captured from `input`
    function solving3!(res, sol; parameter=xx2)
        yy1 = sol[1:2]
        yy2 = sol[3:4]
        res[1:2] = yy1 .- xx2
        res[3:4] = yy2 .- xx2
        return res
    end
end
function ff3(xx2)
    sol0 = [[0.99, 0.99]; [0.99, 0.99]]
    solution = nlsolve(input(xx2), sol0)
    yy1 = solution.zero[1:2]
    yy2 = solution.zero[1+2 : 2*2]
    (yy1, yy2)
end
@code_warntype ff3([2.0, 3.0])
@btime ff3([2.0, 3.0])
```
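CASE 3 presumably avoids the problem because the closure is created in a separate factory function, so the only thing it can capture is that function's argument. A sketch with hypothetical names (`make_residual`, `residual!`), using a positional-only residual instead of the keyword argument in `input`:

```julia
# Hypothetical reduction of NLsolve CASE 3: the residual is built by a
# factory function, so it captures only the parameter vector `x`,
# which is never reassigned and therefore keeps its concrete type.
function make_residual(x::Vector{Float64})
    function residual!(res, v)
        res .= v[1:2] .- x
        return res
    end
    return residual!
end

f! = make_residual([2.0, 3.0])
println(fieldtypes(typeof(f!)))   # expected: only the captured parameter vector
res = zeros(2)
f!(res, [1.0, 2.0, 3.0])
println(res)
```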