Understanding compilation with Optim.jl and OrdinaryDiffEq.jl

I’m currently working on numerical optimisation problems that involve solving ODEs, similar to this.

I’m trying to understand the recurring TTFX (time to first X, i.e. compilation latency) that I see when I repeatedly run variations of numerical optimisation problems. Here’s the setup:

using OrdinaryDiffEq
using Optim

function f(du, u, p, t)
    du[1] = dx = p[1] * u[1] - u[1] * u[2]
    du[2] = dy = -3 * u[2] + u[1] * u[2]
end

u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
p = [1.5]
prob = ODEProblem(f, u0, tspan, p)

function loss_func(p)
    remade_prob = remake(prob, p=p)
    sol = try
        solve(remade_prob)
    catch e
        return Inf # ODE simulation diverges
    end
    return (sol.u[end][1] - 2) ^ 2 # arbitrary loss function
end

Here are some timings:

@time res = optimize(loss_func, [0.1], [10.], [1.], Fminbox(BFGS()); autodiff = :forward)
# 14.201909 seconds (56.20 M allocations: 2.848 GiB, 7.87% gc time, 99.98% compilation time)

This is classic TTFX, no surprises here.

alias(p) = loss_func(p)
@time res = optimize(alias, [0.1], [10.], [1.], Fminbox(BFGS()); autodiff = :forward)
# 11.889799 seconds (48.73 M allocations: 2.497 GiB, 6.41% gc time, 99.98% compilation time)

Question 1: What is being compiled here?

Now let’s say I want to perform the optimisation in log space instead. If I reuse an existing function name…

alias(p) = loss_func(exp.(p))
@time res = optimize(alias, [log(0.1)], [log(10.)], [log(1.)], Fminbox(BFGS()); autodiff = :forward)
# 0.155743 seconds (235.77 k allocations: 12.319 MiB, 45.72% gc time, 99.34% compilation time: 34% of which was recompilation)

Most of the time is still spent on compilation, but the total is much shorter.

If I do the same thing with a new function name…

alias2(p) = loss_func(exp.(p))
@time res = optimize(alias2, [log(0.1)], [log(10.)], [log(1.)], Fminbox(BFGS()); autodiff = :forward)
# 12.119719 seconds (48.92 M allocations: 2.506 GiB, 6.72% gc time, 99.99% compilation time)

Question 2: Why does using an existing function name make a difference?

I don’t have experience with profiling, so if the answer is to investigate using profiling, I’d appreciate some guidance on doing that.

Every function in Julia has its own type, so passing a new function to optimize makes it, and everything it calls with that function, re-specialize and recompile. SciML packages mostly have high-level handling to avoid this recompilation (though Optimization.jl notably does not have it yet), but Optim used directly does not.

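Question 2 has the same answer: redefining alias keeps the same function type, so the existing specializations are reused, whereas alias2 is a new type and everything is specialized again.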
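A minimal sketch of why the name matters, using only Base and hypothetical names (g, h and caller stand in for alias, alias2 and optimize): two separately defined functions have distinct types, while redefining a method under an existing name keeps the type and only swaps the body.

```julia
# Hypothetical stand-ins for `alias` and `alias2` from the question
g(x) = x + 1
h(x) = x + 1          # same body, but a brand-new function

# Each function has its own type, so code specialized on typeof(g)
# cannot be reused for h
@assert typeof(g) !== typeof(h)

caller(fun, x) = fun(x) * 2   # specializes on typeof(fun), like optimize
@assert caller(g, 1) == 4

T_before = typeof(g)
g(x) = x + 10         # redefine the method under the existing name

# The function object, and therefore its type, is unchanged; only the
# method body was replaced, so only code depending on that body is recompiled
@assert typeof(g) === T_before
@assert caller(g, 1) == 22
```

This is consistent with the "34% of which was recompilation" note in the timing for the redefined alias: the heavy specializations keyed on typeof(alias) survive, and only the callers of the old body are invalidated.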

I’m afraid I’m not familiar with the details about specialization. As far as I understand, specialization refers to different methods that a function has depending on its argument types. So in my example, is optimize being re-specialized whenever a different function is passed in its first argument?
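To make the distinction concrete: a method is what you write, while a specialization is the code Julia compiles for that one method for a particular concrete tuple of argument types, and the function argument’s type is part of that tuple. The sketch below uses hypothetical names (run_solver stands in for optimize) and only Base; which returns the Method a call would dispatch to.

```julia
# A single generic method, a hypothetical stand-in for `optimize`
run_solver(objective, x0) = objective(x0)

loss_a(p) = sum(abs2, p)
loss_b(p) = sum(abs2, p)   # identical body, different function

# Both calls dispatch to the very same Method object...
m_a = which(run_solver, (typeof(loss_a), Vector{Float64}))
m_b = which(run_solver, (typeof(loss_b), Vector{Float64}))
@assert m_a === m_b
@assert length(methods(run_solver)) == 1

# ...but the concrete argument type tuples differ, and Julia compiles
# one specialization of that method per concrete tuple
@assert Tuple{typeof(loss_a), Vector{Float64}} != Tuple{typeof(loss_b), Vector{Float64}}
@assert run_solver(loss_a, [1.0, 2.0]) == 5.0
```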

A follow-up question on the practical side: to run the two optimisation problems in the example, is the following an advisable way to avoid ‘unnecessary’ compilation?

# assume same setup code, with definition of `loss_func`

optim_func(p) = loss_func(p)
res = optimize(optim_func, [0.1], [10.], [1.], Fminbox(BFGS()); autodiff = :forward)

optim_func(p) = loss_func(exp.(p))
res = optimize(optim_func, [log(0.1)], [log(10.)], [log(1.)], Fminbox(BFGS()); autodiff = :forward)

The command-line options --trace-compile and, on nightly, --trace-compile-timing let you find out exactly what is being compiled.

The first thing I did was remove the non-constant global variables from your example, so that the measurements reflect realistic code (file /tmp/j.jl):

using OrdinaryDiffEq, Optim

function f(du, u, p, t)
    du[1] = dx = p[1] * u[1] - u[1] * u[2]
    du[2] = dy = -3 * u[2] + u[1] * u[2]
end

const prob = let
    u0 = [1.0; 1.0]
    tspan = (0.0, 10.0)
    p = [1.5]
    ODEProblem(f, u0, tspan, p)
end

function loss_func(p)
    remade_prob = remake(prob, p=p); 
    sol = try 
        solve(remade_prob);
    catch e
        return Inf # ODE simulation diverges
    end
    return (sol.u[end][1] - 2) ^ 2 # arbitrary loss function
end

optimize(loss_func, [0.1], [10.], [1.], Fminbox(BFGS()); autodiff = :forward)

alias(p) = loss_func(p)
optimize(alias, [0.1], [10.], [1.], Fminbox(BFGS()); autodiff = :forward)
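As a rough illustration of why the const prob matters (hypothetical names, Base only): Base.return_types reports what type inference can deduce, and a function that reads a non-constant global can only be inferred as Any, while one that reads a const global is inferred concretely.

```julia
# Non-constant global: its type may change at any time, so the compiler
# cannot assume anything about it inside a function
p_global = [1.5]
read_global() = p_global[1]

# Constant global: its type is fixed, so calls can be inferred concretely
const p_const = [1.5]
read_const() = p_const[1]

@assert only(Base.return_types(read_global, ())) === Any
@assert only(Base.return_types(read_const, ())) === Float64
```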

Then I ran Julia like so:

julia -t9 --trace-compile=/tmp/compile.jl --trace-compile-timing /tmp/j.jl

The relevant lines, mentioning alias, are:

#=   20.5 ms =# precompile(Tuple{Type{NLSolversBase.OnceDifferentiable{TF, TDF, TX} where TX where TDF where TF}, typeof(Main.alias), Array{Float64, 1}, Float64, Array{Float64, 1}, Symbol, ForwardDiff.Chunk{1}})
#=    4.4 ms =# precompile(Tuple{Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, V, N} where N where V}, Float64, ForwardDiff.Partials{1, Float64}})
#=    4.3 ms =# precompile(Tuple{Type{ForwardDiff.Partials{1, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}, Tuple{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}})
#=    4.7 ms =# precompile(Tuple{Type{Base.Generator{I, F} where F where I}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}, Base.UnitRange{Int64}})
#=   11.1 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}}})
#=    6.1 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}}, Int64})
#=    4.6 ms =# precompile(Tuple{Type{Base.Generator{I, F} where F where I}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}, Base.UnitRange{Int64}})
#=   11.0 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}}})
#=    6.1 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}}}, Int64})
#=    4.6 ms =# precompile(Tuple{Type{Base.Generator{I, F} where F where I}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}, Base.UnitRange{Int64}})
#=   11.0 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}}})
#=    6.0 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}}, Int64})
#=    4.6 ms =# precompile(Tuple{Type{Base.Generator{I, F} where F where I}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}, Base.UnitRange{Int64}})
#=   11.0 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}}})
#=    6.0 ms =# precompile(Tuple{typeof(Base.iterate), Base.Generator{Base.UnitRange{Int64}, FunctionWrappers.var"#14#15"{Type{Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}}, Int64})
#=    3.8 ms =# precompile(Tuple{Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.8 ms =# precompile(Tuple{typeof(Base.:(*)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, Base.Rational{Int64}})
#=    3.6 ms =# precompile(Tuple{typeof(Base.convert), Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    3.6 ms =# precompile(Tuple{typeof(Base.real), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.1 ms =# precompile(Tuple{typeof(ForwardDiff.value), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.9 ms =# precompile(Tuple{typeof(Base.convert), Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Bool})
#=    4.6 ms =# precompile(Tuple{typeof(Base.:(+)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.8 ms =# precompile(Tuple{typeof(Base.:(/)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.6 ms =# precompile(Tuple{typeof(Base.:(*)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.4 ms =# precompile(Tuple{typeof(Base.:(-)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    5.2 ms =# precompile(Tuple{typeof(Base.sqrt), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.4 ms =# precompile(Tuple{Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Float64})
#=    4.8 ms =# precompile(Tuple{typeof(Base.:(*)), Float64, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.8 ms =# precompile(Tuple{typeof(Base.inv), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    5.2 ms =# precompile(Tuple{typeof(LinearAlgebra.norm), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    3.8 ms =# precompile(Tuple{typeof(Base.float), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    5.3 ms =# precompile(Tuple{typeof(Base.abs), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    5.0 ms =# precompile(Tuple{typeof(Base.abs2), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.6 ms =# precompile(Tuple{typeof(Base.convert), Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Int64})
#=    5.1 ms =# precompile(Tuple{typeof(LinearAlgebra.dot), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.6 ms =# precompile(Tuple{typeof(Base.Broadcast.broadcasted), typeof(Base.identity), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    7.9 ms =# precompile(Tuple{typeof(Base.getproperty), Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(Base.identity), Tuple{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}, Symbol})
#=    3.7 ms =# precompile(Tuple{typeof(Base.getindex), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=    4.9 ms =# precompile(Tuple{typeof(Base.:(-)), ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}})
#=   14.3 ms =# precompile(Tuple{typeof(Polyester.add_var!), Expr, Expr, Expr, Type{Base.SubArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.UnitRange{Int64}}, true}}, Symbol, Symbol, Int64})
#=    4.7 ms =# precompile(Tuple{Type{Polyester.BatchClosure{RecursiveFactorization.var"#apply_permutation!##0#apply_permutation!##1", ManualMemory.Reference{Tuple{Static.StaticInt{1}, Static.StaticInt{1}, Polyester.NoLoop, Polyester.CombineIndices, StrideArraysCore.AbstractPtrArray{Int64, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Int64}, StrideArraysCore.AbstractPtrArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, (1, 2), Tuple{Int64, Int64}, Tuple{Nothing, Nothing}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}}, false, Tuple{}}}, Function})
#=    4.2 ms =# precompile(Tuple{typeof(Base.getproperty), Polyester.BatchClosure{RecursiveFactorization.var"#apply_permutation!##0#apply_permutation!##1", ManualMemory.Reference{Tuple{Static.StaticInt{1}, Static.StaticInt{1}, Polyester.NoLoop, Polyester.CombineIndices, StrideArraysCore.AbstractPtrArray{Int64, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Int64}, StrideArraysCore.AbstractPtrArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, (1, 2), Tuple{Int64, Int64}, Tuple{Nothing, Nothing}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}}, false, Tuple{}}, Symbol})
#=   14.2 ms =# precompile(Tuple{typeof(Polyester.add_var!), Expr, Expr, Expr, Type{Base.SubArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2}, Tuple{Base.UnitRange{Int64}, Base.UnitRange{Int64}}, false}}, Symbol, Symbol, Int64})
#=    4.7 ms =# precompile(Tuple{Type{Polyester.BatchClosure{RecursiveFactorization.var"#apply_permutation!##0#apply_permutation!##1", ManualMemory.Reference{Tuple{Static.StaticInt{1}, Static.StaticInt{1}, Polyester.NoLoop, Polyester.CombineIndices, StrideArraysCore.AbstractPtrArray{Int64, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Int64}, StrideArraysCore.AbstractPtrArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, (1, 2), Tuple{Int64, Int64}, Tuple{Nothing, StrideArraysCore.StrideReset{Int64}}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}}, false, Tuple{}}}, Function})
#=    4.1 ms =# precompile(Tuple{typeof(Base.getproperty), Polyester.BatchClosure{RecursiveFactorization.var"#apply_permutation!##0#apply_permutation!##1", ManualMemory.Reference{Tuple{Static.StaticInt{1}, Static.StaticInt{1}, Polyester.NoLoop, Polyester.CombineIndices, StrideArraysCore.AbstractPtrArray{Int64, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Int64}, StrideArraysCore.AbstractPtrArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 2, (1, 2), Tuple{Int64, Int64}, Tuple{Nothing, StrideArraysCore.StrideReset{Int64}}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}}}, false, Tuple{}}, Symbol})
#= 7685.3 ms =# precompile(Tuple{NLSolversBase.var"#OnceDifferentiable##4#OnceDifferentiable##5"{Float64, typeof(Main.alias), ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}, Array{Float64, 1}, Array{Float64, 1}})
#=   36.2 ms =# precompile(Tuple{typeof(OrdinaryDiffEqCore.alg_cache), OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Type{Float64}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, 
ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Float64, Float64, Float64, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Bool, Base.Val{true}})
#=    6.6 ms =# precompile(Tuple{FunctionWrappers.CallWrapper{Nothing}, SciMLBase.Void{FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, 
Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}, false}}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64})
#=   13.2 ms =# precompile(Tuple{FunctionWrappers.CallWrapper{Nothing}, SciMLBase.Void{typeof(Main.f)}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64})
#=  142.7 ms =# precompile(Tuple{typeof(OrdinaryDiffEqCore.alg_cache), OrdinaryDiffEqRosenbrock.Rosenbrock23{0, ADTypes.AutoFiniteDiff{Base.Val{:forward}, Base.Val{:forward}, Base.Val{:hcentral}}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Base.Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, Type{Float64}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, 
Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, 1}, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Float64, Float64, Float64, Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.alias), Float64}, Float64, 1}, 1}, Bool, Base.Val{true}})
#=    4.5 ms =# precompile(Tuple{typeof(Main.alias), Array{Float64, 1}})

Compiling a single method instance took most of the time: about 7.7 s for an anonymous function (closure) in NLSolversBase.jl that evaluates the ForwardDiff gradient, specialized on typeof(Main.alias).
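Since the question asked for guidance on investigating this, here is a small sketch for sorting a trace file by its timing comments, assuming the "#=   20.5 ms =# precompile(...)" line format shown in the output above. slowest_precompiles is a hypothetical helper, not part of Julia.

```julia
# Hypothetical helper: return the n slowest entries of a file written with
# --trace-compile=<file> --trace-compile-timing
function slowest_precompiles(path::AbstractString; n::Integer = 5)
    entries = Tuple{Float64, String}[]
    for line in eachline(path)
        # Timed lines look like "#=   20.5 ms =# precompile(Tuple{...})"
        m = match(r"^#=\s*([0-9.]+) ms =#\s*(.*)$", line)
        m === nothing && continue  # skip lines without a timing prefix
        push!(entries, (parse(Float64, m.captures[1]), String(m.captures[2])))
    end
    sort!(entries; by = first, rev = true)  # slowest first
    return first(entries, n)
end
```

For example, slowest_precompiles("/tmp/compile.jl"; n = 3) should list the three most expensive entries of the trace, slowest first, which is a quicker way to spot an outlier like the 7685.3 ms closure than scanning the raw file.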

I think that’s as far as I can investigate without getting into the internals of specific packages.

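struct FunctionWrapper
    f  # untyped on purpose: every wrapper has the same type, whatever it wraps
end
(w::FunctionWrapper)(p) = w.f(p)  # make the wrapper callable
wrapped = FunctionWrapper(optim_func)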

Just pass that wrapper to optimize. Or use FunctionWrappers.jl (with FunctionWrappersWrappers.jl to handle the ForwardDiff tags).
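A minimal sketch of the FunctionWrappers.jl route, assuming the package is installed; loss_lin and loss_log are hypothetical stand-ins for the two objectives. FunctionWrapper{Ret, Args} fixes the return and argument types, so every wrapped function gets the same concrete type.

```julia
# Assumes FunctionWrappers.jl is installed; the name is qualified to avoid
# clashing with a local struct called FunctionWrapper
import FunctionWrappers

# Concrete wrapper type: returns Float64, takes one Vector{Float64}
const LossWrapper = FunctionWrappers.FunctionWrapper{Float64, Tuple{Vector{Float64}}}

# Hypothetical stand-ins for the two objectives in the thread
loss_lin(p) = (p[1] - 2)^2
loss_log(p) = loss_lin(exp.(p))

w_lin = LossWrapper(loss_lin)
w_log = LossWrapper(loss_log)

# Both wrappers share one concrete type, so code specialized on it is reused
@assert typeof(w_lin) === typeof(w_log)
@assert w_lin([3.0]) == 1.0
```

Because the argument type is fixed to Vector{Float64}, this wrapper cannot accept vectors of ForwardDiff Dual numbers, which is exactly the issue raised further down for autodiff = :forward and the reason FunctionWrappersWrappers.jl was suggested.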

You don’t need more than that: the answer is clear from that printout, and it’s exactly what I said. It’s the specialization on the function type, which then propagates into specialization on the ForwardDiff Dual type, whose tag is Tag{typeof(Main.alias), Float64}. If you wrap the function, all of those go away.
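A toy model of this mechanism, not ForwardDiff’s actual implementation: the trace shows that ForwardDiff’s dual numbers carry a tag parameterized by the function’s type (Tag{typeof(Main.alias), Float64}), so every new function produces new dual types. ToyTag, ToyDual, dual_type and Wrapped are hypothetical names; Wrapped mirrors the untyped-field wrapper suggested in the thread.

```julia
# Toy tag recording the type of the differentiated function (hypothetical)
struct ToyTag{F, V} end

# Toy dual number whose type depends on the tag (hypothetical)
struct ToyDual{T, V}
    value::V
    partial::V
end

# The dual type produced when differentiating f at inputs of type V
dual_type(f, ::Type{V}) where {V} = ToyDual{ToyTag{typeof(f), V}, V}

# Untyped wrapper: every instance has the same type, whatever it wraps
struct Wrapped
    f
end
(w::Wrapped)(x) = w.f(x)

alias_a(p) = p
alias_b(p) = p

# Bare functions give distinct dual types, so everything parameterized by
# them (solver caches, closures, ...) would be compiled again
@assert dual_type(alias_a, Float64) !== dual_type(alias_b, Float64)

# Wrapped functions share one dual type, so those specializations are reused
@assert dual_type(Wrapped(alias_a), Float64) === dual_type(Wrapped(alias_b), Float64)
```

The price of the untyped field is a dynamic dispatch on every call of the wrapper, which is likely negligible next to an ODE solve. A side effect is that all wrapped functions also share one tag, which is presumably the kind of tag handling FunctionWrappersWrappers.jl is meant to take care of.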

Thanks for showing me the --trace-compile option.

FunctionWrappers.jl seems like a good way forward, I’ll do that. It looks like I’ve stumbled upon a well-known ‘problem’ that others have thought a lot more about.

If I’m not using auto-diff, then I imagine that I would pass FunctionWrapper{Float64, Tuple{AbstractVector{Float64}}}(loss_func) to optimize. But with forward auto-diff, I need this to work for Dual as well, so does this mean I should define a struct? If so, what would that be?

Actually, I’m not sure if tag handling is something that I need to worry about. What would using FunctionWrappers.jl or FunctionWrappersWrappers.jl look like for my example?

It’s a bit too much code to write down… you’d need to do the following: