Turing Inference with MTK ODEProblem

Hi, I am using Turing for inference: it draws the uncertain parameters and passes them to a simulation kernel function. The kernel takes a pre-created ODEProblem and applies the new parameters to it. The code works but is really slow, and performance matters for this problem. The profiler shows that MTK reinitialises a lot of “things”.

I use setp, and I also had to use a Tunable parameter container (via replace) so that the parameters can hold Duals. Problem:

  • My p_work container always contains u0, which overrides any u0 passed later via solve(…; u0=u0, …). So I have to set u0 inside p_work, but I don't know whether there is a better way to construct it than the one I used.
  • I used setp instead of remake for performance reasons.
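Regarding the restitching in the warmup branch of the kernel below: it can be written without hard-coding the six getfield calls. The following is a minimal, package-free sketch. Like the kernel, it assumes the container is a plain struct whose default constructor takes its fields in order. DemoParams and promote_field are hypothetical names for illustration only, not ModelingToolkit API.

```julia
# Hypothetical stand-in for a parameter container with several buffers.
# Field 2 plays the role of the buffer that stores u0.
struct DemoParams{A,B,C}
    tunable::A
    initials::B
    other::C
end

# Return a copy of struct `p` whose `k`-th field is copied into a new array
# with element type `T`; all other fields are reused without copying.
# Assumes the default positional constructor of the struct's UnionAll type.
function promote_field(p, k::Integer, ::Type{T}) where {T}
    fields = ntuple(i -> getfield(p, i), fieldcount(typeof(p)))
    new_field = similar(fields[k], T)
    copyto!(new_field, fields[k])
    wrapper = typeof(p).name.wrapper
    return wrapper(Base.setindex(fields, new_field, k)...)
end
```

Under the same field-layout assumption the kernel already makes, the warmup branch could call something like promote_field(p_work, 2, eltype(u0)) before writing u0 into the promoted buffer.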

Questions:

  • Is there anything obviously expensive in what I am doing here?
  • Why would MTK reinitialise “things”?
  • Is there an obvious way to make this faster?
  • Why would recompilation happen so often here?
  • Is there any way to use the setp container without u0?

My goals:

  • consider caching
  • investigate type instability
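On type instability: the kernel below uses several names that are not arguments (solver, solver_opts, multiparam_length, multiparam_values, prealloc_results_vector). If these are non-const globals, Julia cannot infer their types inside the kernel. The package-free sketch below shows the effect with Base.return_types. gscale, scaled_global, scaled_arg and scaled_const are hypothetical names.

```julia
# Untyped, non-const global: its value (and type) may change at any time,
# so functions that read it cannot be inferred concretely.
gscale = 2.0
scaled_global(x) = gscale * x

# Same computation with the value passed in as an argument.
scaled_arg(x, s) = s * x

# A const global has a fixed type, so it can be inferred as well.
const SCALE = 2.0
scaled_const(x) = SCALE * x

# Inspect the inferred return types for Float64 inputs.
Base.return_types(scaled_global, (Float64,))
Base.return_types(scaled_arg, (Float64, Float64))
Base.return_types(scaled_const, (Float64,))
```

In the REPL, @code_warntype (from InteractiveUtils) applied to the kernel shows the same situation. Passing such values in as arguments or declaring them const lets the compiler specialize on their types.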

I attach my full kernel below. Any advice would be appreciated!

```julia
using ModelingToolkit
using OrdinaryDiffEq                      # assumed source of `solver`
using SciMLStructures: replace, Tunable
using SymbolicIndexingInterface: setu

# solver, solver_opts, multiparam_length, multiparam_values and
# prealloc_results_vector are not arguments: they are defined outside the kernel.
function kernel(sampled_uncertain_params, prob, model, opts_prod)
    # modify container to allow for AD types (e.g. Duals)
    T = eltype(sampled_uncertain_params)
    p_work = replace(Tunable(), prob.p, T.(model.tunable_pflat.tunable_parameters))

    # use previously built setter to update uncertain params
    model.uncertain_param_setter!(p_work, sampled_uncertain_params)

    if model.warmup
        # warmup solve: only the final state is kept
        sol = solve(prob, solver; p=p_work, solver_opts..., save_end=true, save_everystep=false, dense=false)
        u0 = sol.u[end]

        # reconstruct p_work so that its stored u0 (field 2) can hold a Dual
        P = typeof(p_work).name.wrapper

        pvec   = getfield(p_work, 1)
        u0_old = getfield(p_work, 2)
        f3     = getfield(p_work, 3)
        f4     = getfield(p_work, 4)
        f5     = getfield(p_work, 5)
        f6     = getfield(p_work, 6)

        # setter for all unknowns (constructed on every warmup call)
        states = unknowns(model.sys)
        u0_setter! = setu(model.sys, states)

        Tu = eltype(u0)
        u0_work = similar(u0_old, Tu)
        copyto!(u0_work, u0_old)
        # restitching, then write the warmup end state into the new buffer
        p_work = P(pvec, u0_work, f3, f4, f5, f6)
        u0_setter!(u0_work, u0)
    end

    for i in 1:multiparam_length
        # for each experiment set the input values
        for (j, symbol) in enumerate(model.multiparam_symbols)
            multiparam_values[j] = model.multiparams[symbol][i]
        end

        model.multiparam_setter!(p_work, multiparam_values)

        sol = solve(prob, solver; p=p_work, opts_prod...)
        prealloc_results_vector[i] = sol
    end

    return prealloc_results_vector
end
```
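On caching: in the kernel above, model.uncertain_param_setter! and model.multiparam_setter! are built once and stored on model, while u0_setter! is rebuilt with setu(model.sys, states) on every warmup call. The package-free sketch below illustrates the build-once, write-many pattern. build_index_setter is a hypothetical helper. Its findfirst name lookup stands in for whatever index resolution a real setter performs when it is constructed, so this illustrates the pattern rather than MTK internals.

```julia
# Hypothetical setter factory: resolve names to buffer positions once,
# then return a closure that only performs indexed writes.
function build_index_setter(all_names::AbstractVector{Symbol},
                            targets::AbstractVector{Symbol})
    found = [findfirst(==(t), all_names) for t in targets]
    any(isnothing, found) && throw(ArgumentError("unknown name in targets"))
    idxs = Vector{Int}(found)  # resolved once, captured by the closure

    function setter!(buf::AbstractVector, vals)
        for (k, i) in enumerate(idxs)
            buf[i] = vals[k]
        end
        return buf
    end
    return setter!
end

# Build once, outside the per-sample code ...
const NAMES = [:k1, :k2, :k3, :k4]
const SET_K3_K1! = build_index_setter(NAMES, [:k3, :k1])

# ... and reuse it for every sample; only copying and writing happen here.
function apply_sample(base::AbstractVector, sample)
    T = promote_type(eltype(base), eltype(sample))
    buf = similar(base, T)
    copyto!(buf, base)
    return SET_K3_K1!(buf, sample)
end
```

Storing the result of setu(model.sys, states) on model, as is already done for the other two setters, would follow the same pattern, assuming that setter does not depend on the sampled values.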