I have a two-state ODE system in which one of the state equations contains a neural network, making it a universal differential equation (UDE). I am training on 9 datasets, which gives 9 distinct ODE problems:
# One ODEProblem per training dataset. Each problem has its own right-hand side
# (`UDE_model<i>!`), its own initial state `[SOC0_<i>, T∞<i>]`, and a time span running
# from the first to the last entry of `time<i>`. All nine start from the same parameter
# vector `para_init`.
prob1 = ODEProblem(UDE_model1!, [SOC0_1, T∞1], (first(time1), last(time1)), para_init)
prob2 = ODEProblem(UDE_model2!, [SOC0_2, T∞2], (first(time2), last(time2)), para_init)
prob3 = ODEProblem(UDE_model3!, [SOC0_3, T∞3], (first(time3), last(time3)), para_init)
prob4 = ODEProblem(UDE_model4!, [SOC0_4, T∞4], (first(time4), last(time4)), para_init)
prob5 = ODEProblem(UDE_model5!, [SOC0_5, T∞5], (first(time5), last(time5)), para_init)
prob6 = ODEProblem(UDE_model6!, [SOC0_6, T∞6], (first(time6), last(time6)), para_init)
prob7 = ODEProblem(UDE_model7!, [SOC0_7, T∞7], (first(time7), last(time7)), para_init)
prob8 = ODEProblem(UDE_model8!, [SOC0_8, T∞8], (first(time8), last(time8)), para_init)
prob9 = ODEProblem(UDE_model9!, [SOC0_9, T∞9], (first(time9), last(time9)), para_init)
My loss function looks something like this:
"""
    totalloss_UDE(θ)

Return the summed mean-squared error over the nine training datasets.

For each `(problem, measurements, sample times)` triple, the problem is rebuilt with
parameters `θ`, solved with `Tsit5()` saving at the sample times, and the second row of
the solution array is compared against the measurements via `mean(abs2, ...)`. The
per-dataset errors are added together without weighting.

# Arguments
- `θ`: parameter vector substituted into every problem through `remake(prob, p = θ)`.

# Returns
The sum of the nine per-dataset mean-squared errors.
"""
function totalloss_UDE(θ)
    datasets = [
        (prob1, Tavg1, time1),
        (prob2, Tavg2, time2),
        (prob3, Tavg3, time3),
        (prob4, Tavg4, time4),
        (prob5, Tavg5, time5),
        (prob6, Tavg6, time6),
        (prob7, Tavg7, time7),
        (prob8, Tavg8, time8),
        (prob9, Tavg9, time9),
    ]

    # The sensitivity settings are the same for every solve, so build them once.
    sens = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))

    # NOTE(review): each iteration reads only its own triple and θ, so the solves look
    # independent; an EnsembleProblem solved with EnsembleThreads() may be a way to run
    # them in parallel — confirm it composes with this sensealg before switching.
    loss = 0.0
    for (base_prob, measured, tgrid) in datasets
        current = remake(base_prob, p = θ)
        sol = solve(current, Tsit5(), saveat = tgrid, sensealg = sens)
        # Row 2 is the second state variable (presumably the temperature, given the
        # `[SOC0, T∞]` initial state — verify against the model definition).
        predicted = Array(sol)[2, :]
        loss += mean(abs2, measured .- predicted)
    end
    return loss
end
Is there any way to make this run faster? Right now it takes a long time on my CPU because the nine ODE solves run one after another. Is there a way to run them in parallel?