This is probably expected behaviour, but I don't quite understand why it's happening, and I suspect it would be useful to know.
As part of a small library, I've been using Symbolics.jl to take user-supplied expressions, compile them to Julia functions, and then use these for further processing and computations. So far this mostly works well, and the functions produced by build_function() are type stable (@code_warntype infers Body::Float64). However, if I pass such a function as an argument to another function and call it there, or even just wrap it in another function, type stability seems to be lost (@code_warntype on the wrapper reports Body::Any). Performance nonetheless seems fine: I've benchmarked a number of reasonable examples. MWE at the end.
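For comparison, here is a minimal sketch of the "pass it as an argument" case, assuming the function value reaching the inner function already has a concrete type (it is passed in directly, not read from an untyped global inside the caller). `apply_fn` and `f_stand_in` are hypothetical names, and a plain anonymous function stands in for the RuntimeGeneratedFunction so the snippet runs without Symbolics. `Base.return_types` is used only as a quick way to see what inference concludes.

```julia
# Hypothetical stand-in for the function returned by
# build_function(...; expression=Val{false}): same expression, plain Julia.
f_stand_in = x -> (x[1] - x[2])^3 + (x[3] * x[4] * x[5] - x[1] * x[2])^5

# Hypothetical helper that receives the function as an argument and calls it.
# The type parameter F makes Julia compile a specialization for each
# concrete function type passed in.
apply_fn(f::F, x) where {F} = f(x)

x0 = [1.0, 1.1, 1.2, 1.3, 1.4]
y = apply_fn(f_stand_in, x0)

# Ask inference for the return type given the concrete argument types.
Base.return_types(apply_fn, (typeof(f_stand_in), Vector{Float64}))
```

If an outer function instead looks the compiled function up from a global and then passes it along, that lookup can itself be inferred as `Any`, and the specialization is lost at that point.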
So my questions are:
i. Why is type stability lost? My guess is that, when Julia infers types, it cannot 'see' what the RuntimeGeneratedFunction is doing, but that's quite woolly and I don't think I've properly understood what's going on.
ii. Can this be (easily) remedied?
iii. Since the benchmarks look fine, can I just annotate (assert?) types somewhere sensible so that Any doesn't propagate through the rest of the execution? Because of how the rest of the code is structured, I know what the return types should be.
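On question iii, one option is a type assertion at the call site. A minimal sketch, assuming the caller knows the result is a `Float64`; `g_global`, `wrap_plain` and `wrap_asserted` are hypothetical names, and a plain anonymous function stored in a non-const global stands in for `test_fn`.

```julia
# Hypothetical non-const global standing in for test_fn.
g_global = x -> sum(abs2, x)

# Reads the untyped global, like wrap_test_fn in the MWE.
wrap_plain(x) = g_global(x)

# Same call with a type assertion on the result. The call itself is still
# dispatched at run time, but the assertion is checked (TypeError if wrong)
# and code after it is inferred as Float64.
wrap_asserted(x) = g_global(x)::Float64

x0 = [1.0, 1.1, 1.2, 1.3, 1.4]
wrap_asserted(x0)

Base.return_types(wrap_plain, (Vector{Float64},))
Base.return_types(wrap_asserted, (Vector{Float64},))
```

The assertion does not remove the dynamic dispatch on the call itself; it only fixes the type seen by the code that follows, so `Any` stops propagating there.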
Thanks in advance
MWE
```julia
julia> using RuntimeGeneratedFunctions, SymbolicUtils, Symbolics

julia> @variables x[1:5]
1-element Vector{Symbolics.Arr{Num, 1}}:
 x[1:5]

julia> w = Symbolics.scalarize(x)
5-element Vector{Num}:
 x[1]
 x[2]
 x[3]
 x[4]
 x[5]

julia> test_num = (x[1] - x[2])^3 + (x[3]*x[4]*x[5] - x[1]*x[2])^5
(x[3]*x[4]*x[5] - x[1]*x[2])^5 + (x[1] - x[2])^3

julia> test_fn = build_function(test_num, w; expression=Val{false})
RuntimeGeneratedFunction(#=in Symbolics=#, #=using Symbolics=#, :((ˍ₋arg1,)->begin
          #= C:\Users\peterma\.julia\packages\SymbolicUtils\qulQp\src\code.jl:349 =#
          #= C:\Users\peterma\.julia\packages\SymbolicUtils\qulQp\src\code.jl:350 =#
          #= C:\Users\peterma\.julia\packages\SymbolicUtils\qulQp\src\code.jl:351 =#
          begin
              (+)((^)((+)((*)((*)(-1, (getindex)(ˍ₋arg1, 1)), (getindex)(ˍ₋arg1, 2)), (*)((*)((getindex)(ˍ₋arg1, 3), (getindex)(ˍ₋arg1, 4)), (getindex)(ˍ₋arg1, 5))), 5), (^)((+)((*)(-1, (getindex)(ˍ₋arg1, 2)), (getindex)(ˍ₋arg1, 1)), 3))
          end
      end))

julia> x0 = [ 1.0, 1.1, 1.2, 1.3, 1.4 ]
5-element Vector{Float64}:
 1.0
 1.1
 1.2
 1.3
 1.4

julia> test_fn(x0)
1.4957401577994214

julia> @code_warntype test_fn(x0)
MethodInstance for (::RuntimeGeneratedFunction{(:ˍ₋arg1,), Symbolics.var"#_RGF_ModTag", Symbolics.var"#_RGF_ModTag", (0xab7cee3c, 0x63779d06, 0x2ab0d9ad, 0x964c6cc8, 0x3eff94b6)})(::Vector{Float64})
  from (f::RuntimeGeneratedFunction)(args::Vararg{Any, N}) where N in RuntimeGeneratedFunctions at C:\Users\peterma\.julia\packages\RuntimeGeneratedFunctions\KrkGo\src\RuntimeGeneratedFunctions.jl:117
Static Parameters
  N = 1
Arguments
  f::RuntimeGeneratedFunction{(:ˍ₋arg1,), Symbolics.var"#_RGF_ModTag", Symbolics.var"#_RGF_ModTag", (0xab7cee3c, 0x63779d06, 0x2ab0d9ad, 0x964c6cc8, 0x3eff94b6)}
  args::Tuple{Vector{Float64}}
Body::Float64
1 ─ %1 = Core.tuple(f)::Tuple{RuntimeGeneratedFunction{(:ˍ₋arg1,), Symbolics.var"#_RGF_ModTag", Symbolics.var"#_RGF_ModTag", (0xab7cee3c, 0x63779d06, 0x2ab0d9ad, 0x964c6cc8, 0x3eff94b6)}}
│   %2 = Core._apply_iterate(Base.iterate, RuntimeGeneratedFunctions.generated_callfunc, %1, args)::Float64
└──      return %2

julia> function wrap_test_fn(x)
           return test_fn(x)
       end
wrap_test_fn (generic function with 1 method)

julia> wrap_test_fn(x0)
1.4957401577994214

julia> @code_warntype wrap_test_fn(x0)
MethodInstance for wrap_test_fn(::Vector{Float64})
  from wrap_test_fn(x) in Main at REPL[9]:1
Arguments
  #self#::Core.Const(wrap_test_fn)
  x::Vector{Float64}
Body::Any
1 ─ %1 = Main.test_fn(x)::Any
└──      return %1
```
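The last `@code_warntype` shows `Main.test_fn(x)::Any`, i.e. `wrap_test_fn` reads `test_fn` as a global. A minimal sketch of why that may matter, assuming the cause is that `test_fn` is a non-constant global (such a binding can be rebound to a value of any type, so inference cannot rely on its type). Plain anonymous functions stand in for the RuntimeGeneratedFunction, and all names are hypothetical.

```julia
# Hypothetical stand-ins for test_fn: plain anonymous functions instead of
# a RuntimeGeneratedFunction, so this runs without Symbolics.
f_nonconst = x -> (x[1] - x[2])^3 + (x[3] * x[4] * x[5] - x[1] * x[2])^5
const f_const = x -> (x[1] - x[2])^3 + (x[3] * x[4] * x[5] - x[1] * x[2])^5

# Same shape as wrap_test_fn in the MWE.
# A non-const global may be rebound to a value of any type, so inference
# cannot assume its type here.
wrap_nonconst(x) = f_nonconst(x)
# A const binding always refers to the same object, so its concrete type
# is known when wrap_const is compiled.
wrap_const(x) = f_const(x)

x0 = [1.0, 1.1, 1.2, 1.3, 1.4]
wrap_const(x0) == wrap_nonconst(x0)

Base.return_types(wrap_nonconst, (Vector{Float64},))
Base.return_types(wrap_const, (Vector{Float64},))
```

Applied to the MWE, the corresponding change would presumably be `const test_fn = build_function(test_num, w; expression=Val{false})` at top level, or passing `test_fn` into the code that calls it as an argument.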