Hi!
I am a researcher working on optimal stopping/switching problems. Inside a for loop, I need to advance an SDE by one time step of size dt for many previous states at once. Because I do this recursively, in a dynamic-programming fashion, I cannot generate complete trajectories up front. At each iteration of the loop I then do further computations on the resulting states.
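For reference, the update applied to each column $x$ of the state matrix is the Euler–Maruyama step below. Here $\mu$ stands for `StepProc.drift` and $\sigma$ for `StepProc.dispersion`, which is assumed to be $d \times d$, as the `randn(Float32,d)` call in the code implies:

$$
x_{\text{next}} = x + \mu(x, p, t)\,\Delta t + \sigma(x, p, t)\,\sqrt{\Delta t}\,Z, \qquad Z \sim \mathcal{N}(0, I_d).
$$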
So far I have been doing it using a function like this:
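```julia
# One Euler-Maruyama step of size dt for every column (state) of the d×M matrix X_prev:
# x + drift(x,p,t)*dt + dispersion(x,p,t) * randn(d) * sqrt(dt)
function cpu_euler_maryama_alt(X_prev, time, dt, StepProc, p)
    d, M = size(X_prev)
    dX = [col .+ StepProc.drift(col, p, time) .* dt .+ StepProc.dispersion(col, p, time) * randn(Float32, d) * sqrt(dt) for col in eachcol(X_prev)]
    return reduce(hcat, dX)  # reassemble the updated columns into a d×M matrix
end
```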
where StepProc is a struct containing the drift and dispersion of the SDE. This works fine and is quite efficient. I am trying to see whether I can use DifferentialEquations.jl to do the same thing, but my attempts end up allocating a lot. Here is an example of what I have tried:
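As a point of comparison, here is a minimal sketch of the same step written in place, so that the output matrix and the normal draws are preallocated instead of being rebuilt with `hcat` on every call. The `StepProc` definition is a hypothetical stand-in (a struct with `drift(x, p, t)` returning a vector and `dispersion(x, p, t)` returning a d×m matrix), and `euler_maruyama_step!` is a hypothetical name; only `randn!`, `view` and the 5-argument `mul!` from the standard library are used.

```julia
using Random, LinearAlgebra

# Hypothetical stand-in for the StepProc struct described above:
# drift(x, p, t) returns a length-d vector, dispersion(x, p, t) a d×m matrix.
struct StepProc{F,G}
    drift::F
    dispersion::G
end

# One Euler-Maruyama step for every column of X_prev, written into the
# preallocated d×M matrix X_next (which must not alias X_prev).
# Z is a preallocated m×M buffer that is refilled with N(0, 1) draws.
function euler_maruyama_step!(X_next, Z, X_prev, time, dt, proc::StepProc, p;
                              rng = Random.default_rng())
    randn!(rng, Z)
    sq = sqrt(dt)
    for j in axes(X_prev, 2)
        x  = view(X_prev, :, j)
        xn = view(X_next, :, j)
        xn .= x .+ proc.drift(x, p, time) .* dt                          # deterministic part
        mul!(xn, proc.dispersion(x, p, time), view(Z, :, j), sq, true)  # xn += sqrt(dt) * σ * z
    end
    return X_next
end
```

If the previous states are no longer needed, the two state buffers can simply be swapped between iterations (`X_prev, X_next = X_next, X_prev`). Whatever `drift` and `dispersion` allocate internally is still allocated per column; the sketch only removes the comprehension, `hcat` and `randn` allocations.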
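```julia
using DifferentialEquations

function one_step(SDEprob, v0, dt, t)
    # Give trajectory i the i-th column of v0 as its initial state, over [t, t + dt]
    function start_values(prob, i, repeat)
        remake(prob, u0 = v0[:, i], tspan = (t, t + dt))
    end
    d, M = size(v0)
    ensembleprob = EnsembleProblem(SDEprob, prob_func = start_values)
    # EM() with a fixed dt takes exactly one Euler-Maruyama step; with no algorithm given,
    # a default (adaptive) method is chosen and dt is only the initial step size
    sol = solve(ensembleprob, EM(), EnsembleThreads(), trajectories = M, dt = dt, saveat = [t + dt])
    return reduce(hcat, [s.u[end] for s in sol.u])  # d×M matrix of states at t + dt
end
```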
Here SDEprob is an instance of SDEProblem.
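One variant that might avoid the per-trajectory `remake` and the ensemble overhead, assuming the noise is diagonal (each component driven by its own independent Brownian motion), is to treat the whole d×M matrix as the state of a single `SDEProblem` and take one `EM()` step. This is a sketch only: `gbm_drift!`, `gbm_diffusion!`, `one_step_batched` and the `(μ, σ)` parameters are hypothetical geometric-Brownian-motion placeholders, and StochasticDiffEq.jl (which DifferentialEquations.jl re-exports) is assumed to be installed.

```julia
using StochasticDiffEq  # SDEProblem, EM and solve; re-exported by DifferentialEquations.jl

# Hypothetical placeholder coefficients (geometric Brownian motion), applied elementwise,
# so each column of the d×M state matrix evolves independently. Diagonal noise only.
gbm_drift!(du, u, p, t) = (du .= p.μ .* u)
gbm_diffusion!(du, u, p, t) = (du .= p.σ .* u)

# One Euler-Maruyama step of size dt for all M states at once
function one_step_batched(X_prev::AbstractMatrix, t, dt, p)
    prob = SDEProblem(gbm_drift!, gbm_diffusion!, copy(X_prev), (t, t + dt), p)
    sol = solve(prob, EM(); dt = dt, save_everystep = false, save_start = false)
    return sol.u[end]  # d×M matrix of states at t + dt
end
```

With a full (non-diagonal) dispersion matrix, as in `cpu_euler_maryama_alt`, this trick does not carry over directly, since the noise would have to be coupled within each column but not across columns.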
Do you think it is possible to do this efficiently with DifferentialEquations.jl?