Adding and updating a scalar in the Integrator Cache in DifferentialEquations.jl

Background: I have a nice stochastic jump-piecewise-ODE behaving as it should. Thanks @ChrisRackauckas and team! This would have been completely impossible in any other language/package! Now I want to calculate a likelihood L(θ) for the hidden Markov process in the “dumb” (forward) way. That is, I have observations y[1], …, y[n] at times times[1], …, times[n], and I know the likelihood of each observation given the underlying state, e.g. p(y[2]|x2) where x2 = u(times[2]). I wish to get

L(θ) = p(y,t|θ) = ∫dx1 … ∫dxn π(x1) p(y[1]|x1) ∗ p(x2|x1) p(y[2]|x2) ∗ … ∗ p(xn|x(n−1)) p(y[n]|xn),
which we can write as
L(θ) = 〈 Π_i p(y[i]|xi) | θ 〉, where xi are the states of a trajectory at times[i] and 〈 · | θ 〉 is the average over trajectories obeying the parameters θ.

Let us call the single-trajectory likelihood
Π_i p(y[i]|xi) = traj_likelihood

At the end of integration, I only need traj_likelihood, over which I plan to take an ensemble average. (Yes it’s slow, but I have no choice.) So I would like to allocate a scalar Float64 reference in the integrator cache so that every time I hit a times[i] I calculate the conditional log-likelihood log p(y[i]|xi) and add it to this mutable cache. Then, at the end, all I want to return is this log-likelihood value (the cache at the end of integration).
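In plain code, the single-trajectory quantity is just a running sum of Gaussian observation log-densities along one trajectory (a pure-Julia sketch; `traj_loglik` and σ are my own names, not part of any package):

```julia
# Sum of Gaussian observation log-densities along one trajectory.
# xs[i] is the simulated state at times[i], ys[i] the observation.
function traj_loglik(xs, ys; σ = 2.0)
    llk = 0.0
    for (x, y) in zip(xs, ys)
        llk += -(x - y)^2 / (2σ^2) - log(2π * σ^2) / 2
    end
    return llk
end

traj_loglik([0.0], [0.0]; σ = 1.0)  # -log(2π)/2 ≈ -0.9189
```

The question below is just about where to keep `llk` while the integrator runs.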

I’m thinking all I need is a list of PresetTimeCallback(times[i], affect!), with affect! updating this cache, but I can’t figure out how to allocate the cache and access it inside affect!; the docs don’t quite spell out that part of the interface for me. Thanks for the help!

Edit: It occurs to me that this post doesn’t have enough code. I will post a concrete example later today.

Here’s an example of the idea using an Ornstein-Uhlenbeck process with noisy observations:

using StochasticDiffEq, DiffEqCallbacks  # for solve/EM and PresetTimeCallback

θ = 3.0        # O-U mean, used as a global inside f
u₀ = 1 / 2
f(u, p, t) = -(u - θ)
g(u, p, t) = 2.0
tspan = (0.0, 4.0)   # must cover the observation times below
dt = 1 // 2^4
prob = SDEProblem(f, g, u₀, tspan)


# here's a helper function that builds the CallbackSet for the observations

function callback_maker(ys,ts; σ = 2.0)
    observation_loglikelihood(x,y) = -(x - y)^2 / (2σ^2) - log(2π * σ^2) / 2  # Gaussian log-density
    cbs = []
    for (y,t) in zip(ys,ts)
        function affect!(integrator)
            ######## QUESTION POINT
            # do something like this?
            # integrator.cache += observation_loglikelihood(integrator.u,y)
            # commented out so everything else runs
            integrator.u = 1.0 + integrator.u # observations also modify state in my context.
            nothing
        end
        push!(cbs,PresetTimeCallback(t, affect!))
    end
    CallbackSet(cbs...)
end
# here's a list of observations
y = [3.1,7.2,3.3]
times = [1.0,2.0,3.0]
# we use them to construct a callback set
cb = callback_maker(y,times; σ = 2.0)

sol = solve(prob, EM(); dt = dt, callback = cb, save_everystep = false, save_start = false, save_end = false, initialize_save = false) 

At the end, this all gets wrapped up in a function, and what we want is a bare-bones set of likelihood values, which I might construct like

function likelihoods(θ, y, t; nsamples)
    #
    # construct prob and cb as above, with θ as the new O-U process mean
    #
    return [begin
        # would like as little overhead as possible saving things
        sol = solve(prob, EM(); dt = dt, callback = cb, save_everystep = false, save_start = false, save_end = false, initialize_save = false)
        # all we want to return is the likelihood value
        sol.cache  # (this is the part I don't know how to do)
    end for _ in 1:nsamples]
end

Have you looked at the SavingCallback? Output and Saving Controls · DiffEqCallbacks.jl

That sounds like it might be what you want.


Thanks for the reply. I have seen it, but I have two questions about SavingCallback:

1.) I only wish to allocate one scalar for the whole integration, not an allocation per observation.

2.) I don’t understand how to combine SavingCallback with PresetTimeCallback.

These two concerns have me hoping that I can figure out how to control the caching myself, but I guess the mere existence of SavingCallback shows that it is not so simple?

You can always make a custom callback function that encodes the rate and affect functions, and has internal state. Here is an example for storing and using the previous value of a simulation at every step (but you could do this for a continuous callback too).

FAQ · JumpProcesses.jl?

(Note the “?” is part of the link but Discourse drops it.)

So you could make a PresetTimeCallback that takes such a functor, and can save whatever you want each time the associated affect is called. Since you create the functor you can then access the value(s) stored in it after a simulation.


Okay lemme try.


mutable struct ObservationLikelihood{T}
    llk::T
end

function (obsl::ObservationLikelihood)(x, y; σ = 2.0)
    obsl.llk += -(x - y)^2 / (2σ^2) - log(2π * σ^2) / 2
end

function callback_maker(ys,ts; σ = 2.0)
    obsl = ObservationLikelihood(0.0)
    cbs = []
    for (y,t) in zip(ys,ts)
        function affect!(integrator)
            obsl(integrator.u, y; σ)           # accumulate the observation log-likelihood
            integrator.u = 1.0 + integrator.u  # observations also modify state in my context.
            nothing
        end
        push!(cbs,PresetTimeCallback(t, affect!))
    end
    (CallbackSet(cbs...), obsl)
end


# here's a list of observations
y = [3.1,7.2,3.3]
times = [1.0,2.0,3.0]
# we use them to construct a callback set, and then assign the aggregator to an external variable.
(cb, obsl) = callback_maker(y,times; σ = 2.0)

sol = solve(prob, EM(); dt = dt, callback = cb, save_everystep = false, save_start = false, save_end = false, initialize_save = false) 
obsl.llk # the value we want?

Ok this works, no need to learn deeply about the integrator internals! But are you sure this is idiomatic and fast? Isn’t it basically just a typed global variable?

Just make that your full affect! function, i.e.

function (obsl::ObservationLikelihood)(integrator)
    # update obsl.llk and integrator.u
end

and then use PresetTimeCallback(t, obsl). You can always store/pass σ as another field within the ObservationLikelihood struct.

I believe this approach is how ODEFunctions are defined, which wrap and store user-provided f! and jac! functions and such, and are what is then called in ODE simulations to evaluate the derivative functions.

Note that PresetTimeCallback can take a vector of times so you don’t need to create one for each individual time.

https://docs.sciml.ai/DiffEqCallbacks/stable/timed_callbacks/#Timed-Callbacks

I don’t think I can follow your suggestion in a straightforward way, because each callback also depends on the actual observation y[i], yet they all contribute to the same total likelihood. If each observation has a different functor affect!, they will not share the same accumulator. This also makes it harder to define one callback, because the value of y is different at every observation time.

Therefore, I’m wondering if your pattern has concrete advantages over what I wrote above, and if so, why? (Potentially I could store things as a vector and cycle through the components with more functor fields, but it gets uglier.) Does it make a difference whether your affect! is a functor itself or calls another functor?

Thanks for your patience.

I would generally avoid creating lots of functions or callbacks that are meant to be called once. Can’t you just store your ys and ts data within the ObservationLikelihood structure, and then use integrator.t to figure out which data point you need?

Just to elaborate, under the hood a PresetTimeCallback is just creating a DiscreteCallback with a condition function that checks if the time after a solver timestep is one of your chosen callback times. If you create a ton of them, each of those generated condition functions is called at every timestep of the ODE solver.

See https://github.com/SciML/DiffEqCallbacks.jl/blob/master/src/preset_time.jl

Thanks @isaacsas for these points. I will play around with these ideas and see what is best: a functor with a ticking index::Int64 field that increments each time it is called, which is then used to select y[index] at that time point. I’ll do some benchmarking and then post my solution for posterity.
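Roughly what I have in mind: a single mutable functor holding the observations, noise level, a ticking index, and the running log-likelihood, used as the affect! of one PresetTimeCallback over the whole vector of times (an untested sketch; ObsAffect and the field names are mine):

```julia
# One mutable functor holds the data and the accumulator; its call
# method is the affect! for a single PresetTimeCallback over all times.
mutable struct ObsAffect{T}
    ys::Vector{T}  # observations
    σ::T           # observation noise
    index::Int     # next observation to consume
    llk::T         # running log-likelihood
end

function (oa::ObsAffect)(integrator)
    y = oa.ys[oa.index]
    oa.llk += -(integrator.u - y)^2 / (2 * oa.σ^2) - log(2π * oa.σ^2) / 2
    integrator.u = 1.0 + integrator.u  # observation also modifies the state
    oa.index += 1
    nothing
end

ys = [3.1, 7.2, 3.3]
times = [1.0, 2.0, 3.0]
obsa = ObsAffect(ys, 2.0, 1, 0.0)
# cb = PresetTimeCallback(times, obsa)  # one callback for all observation times
# after solve(..., callback = cb): obsa.llk is the trajectory log-likelihood
```

Each ensemble trajectory would need its own ObsAffect instance (e.g. built in the ensemble prob_func) so they don't share state.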

EDIT: re. multithreading, I’m worried about the global-variable nature of this solution: it would be nice to use the multithreaded ensemble framework, and a callback cache shared across trajectories will not work. I think my original goal of writing to the integrator cache would still be best, if anyone knows how to do it.

EDIT: one other thing I’m having trouble with is turning off the saving at the callback as well. I really only need the likelihood, available via the callback, not the trajectory.

Set the save_positions = (false,false) keyword argument to the callback,

https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/#SciMLBase.DiscreteCallback

and adjust the saving controls for the ODE solution as desired:

https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/#CommonSolve.solve-Tuple{SciMLBase.AbstractDEProblem,%20Vararg{Any}}
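Concretely, assuming the callback setup from earlier in the thread (PresetTimeCallback forwards keyword arguments to the underlying DiscreteCallback), that would look something like this fragment:

```julia
# save_positions = (false, false): don't save the state before/after affect! fires
cb = PresetTimeCallback(times, affect!; save_positions = (false, false))

sol = solve(prob, EM(); dt = dt, callback = cb,
            save_everystep = false, save_start = false, save_end = false)
```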
