Hi, I am running inference with Turing, which draws the uncertain parameters and passes them on to a simulate kernel function. The kernel takes a pre-created ODEProblem and applies the new parameters. The code works but is really slow, and performance matters for this problem. The profiler shows that MTK reinitialises a lot of “things”.
I used setp, and I also had to rebuild the parameter container through its Tunable portion so that it can hold Duals. Problem:
- My p_work container always contains u0, and that value overrides the u0 I pass later in solve(u0=u0, …). So I have to set u0 inside p_work, but I don’t know whether there is a better constructor than the one I used.
- I used setp instead of remake for performance reasons.
Questions:
- Is there anything obviously heavy I do here?
- Why would MTK reinitialise “things”?
- Is there an obvious path here to make this faster?
- Why would recompilation happen so often here?
- Is there any way I can use the setp container without u0?
My goals:
- consider caching
- investigate type instability
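For the type-instability goal, the kind of check I have in mind looks roughly like the sketch below. It is only a toy: `unstable_step` and `stable_step` are hypothetical stand-ins, not part of my kernel, and the only tool used is `@inferred` from the `Test` standard library.

```julia
using Test  # standard library; provides @inferred

# Hypothetical toy functions, not part of the real kernel.
# For a Float64 input this returns either a Float64 or the Int literal 0,
# so the inferred return type is a Union (a type instability).
unstable_step(x) = x > 0 ? x : 0

# Same logic, but zero(x) keeps the return type equal to typeof(x).
stable_step(x) = x > 0 ? x : zero(x)

# Passes silently: inferred and actual return types agree.
@inferred stable_step(1.0)

# @inferred throws when the inferred type is wider than the actual result.
unstable_detected = try
    @inferred unstable_step(1.0)
    false
catch
    true
end
```

On the real kernel I would presumably start with `@code_warntype kernel(...)` and look for `Any` or `Union` types around the `solve` calls and the setters.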
I attach my full kernel below. Any advice would be appreciated!
function kernel(sampled_uncertain_params, prob, model, opts_prod)
    # NOTE: solver, solver_opts, multiparam_length, multiparam_values and
    # prealloc_results_vector are defined outside this function (not shown here).

    # modify container to allow for AD types
    T = eltype(sampled_uncertain_params)
    p_work = replace(Tunable(), prob.p, T.(model.tunable_pflat.tunable_parameters))
    # use previous setter to update uncertain params
    model.uncertain_param_setter!(p_work, sampled_uncertain_params)
    if model.warmup
        # warm-up solve: only the final state is needed
        sol = solve(prob, solver; p=p_work, solver_opts..., save_end=true, save_everystep=false, dense=false)
        u0 = sol.u[end]
        # reconstruct p_work to allow u0 to contain a Dual
        P = typeof(p_work).name.wrapper
        pvec = getfield(p_work, 1)
        u0_old = getfield(p_work, 2)
        f3 = getfield(p_work, 3)
        f4 = getfield(p_work, 4)
        f5 = getfield(p_work, 5)
        f6 = getfield(p_work, 6)
        states = unknowns(model.sys)
        u0_setter! = setu(model.sys, states)
        T = eltype(u0)
        # copy the old u0 buffer into one whose element type matches the new u0
        u0_work = similar(u0_old, T)
        copyto!(u0_work, u0_old)
        # restitching
        p_work = P(pvec, u0_work, f3, f4, f5, f6)
        u0_setter!(p_work[2], u0)
    end
    for i in 1:multiparam_length
        # for each experiment set the input values
        for (j, symbol) in enumerate(model.multiparam_symbols)
            multiparam_values[j] = model.multiparams[symbol][i]
        end
        model.multiparam_setter!(p_work, multiparam_values)
        sol = solve(prob, solver; p=p_work, opts_prod...)
        prealloc_results_vector[i] = sol
    end
    return prealloc_results_vector
end