Does sciml_train require the ensemble interface to compute the MSE over Monte Carlo solves when connected to a neural network, or can I use a for loop?

General DiffEqFlux question.

I have a simulation built with DifferentialEquations.jl and solve that works fine on its own.

I want to put a neural network in front of the simulation setup so that, once trained, it takes the simulation's input parameters and produces the additional optimized parameters the simulation needs. I want to train it by running the solve N times with randomly drawn input parameters and taking the MSE of the cost over those runs; that MSE is the final loss returned to DiffEqFlux.sciml_train.

I am not using the ensemble feature of the DiffEq solver. Instead, I run solve in a for loop, and that part appears to work fine. But once all of the runs have been made on the first pass through sciml_train, it crashes in Zygote's reverse-mode AD.

Do I need to use the ensemble feature for this to work? It seems like Zygote will try to trace through each of the N solves over and over. Before changing everything, I thought I would ask for advice. Right now I get this error in the instrument function of reverse.jl: “ERROR: Compiling Tuple{…} try/catch is not supported.”
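For reference, Zygote can differentiate an ordinary for loop; what it cannot handle is array mutation, so writing each run's result into a preallocated array (`losses[i] = ...`) fails, while accumulating a scalar does not. A minimal sketch of a mutation-free Monte Carlo MSE loop, assuming a hypothetical linear map `predict_extra` standing in for the network and a hypothetical `simulate` standing in for one solve (neither is from the original post):

```julia
using Zygote

# Hypothetical stand-in for the network: maps sampled inputs x
# to the extra parameters the simulation needs.
predict_extra(θ, x) = θ.W * x .+ θ.b

# Hypothetical stand-in for one simulation run returning a scalar cost.
simulate(x, extra) = sum(abs2, extra .- x)

# Monte Carlo mean of the per-run cost. The loop rebinds a scalar
# instead of mutating an array, which Zygote can differentiate.
function mc_loss(θ, samples)
    total = 0.0
    for x in samples
        extra = predict_extra(θ, x)
        total += simulate(x, extra)
    end
    return total / length(samples)
end

samples = [[1.0, 1.0], [3.0, 1.0]]
θ = (W = zeros(2, 2), b = [1.0, 1.0])
g = Zygote.gradient(t -> mc_loss(t, samples), θ)[1]  # NamedTuple with fields W and b
```

In the real problem, `simulate` would wrap the ODE solve, which additionally needs the solve itself to be differentiable.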

Thoughts?
Best Regards,
Allan Baker

You don’t need the ensemble feature for this to work, but you do need to make sure you have differentiable parallel primitives. EnsembleThreads implements a differentiable tmap in SciMLBase for this.
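A minimal sketch of replacing the for loop with an `EnsembleProblem` solved with `EnsembleThreads()`, assuming a hypothetical decay model and randomly sampled initial conditions in place of the real simulation. Only the forward solve is shown; differentiating through it with Zygote also needs the sensitivity package loaded (DiffEqSensitivity at the time of this thread, SciMLSensitivity in later releases).

```julia
using OrdinaryDiffEq

# Hypothetical model: exponential decay with rate p[1] (out-of-place form).
decay(u, p, t) = -p[1] .* u

prob = ODEProblem(decay, [1.0], (0.0, 1.0), [1.0])

# prob_func replaces the body of the original for loop: trajectory i
# gets its own randomly sampled initial condition.
prob_func(prob, i, repeat) = remake(prob, u0 = [1.0 + rand()])

eprob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(eprob, Tsit5(), EnsembleThreads(), trajectories = 10)
```

Each entry of `sim.u` is one trajectory's solution, so the per-run losses can be reduced over `sim.u` much as they were over the loop iterations.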

Thank you. I will look for that in the documentation.

I may also have something else wrong. I'm going to take the solve out of the for loop and see whether it works with just one run.

Is Zygote the best choice for this kind of problem, or is there a better sensitivity method?

It’s probably pretty good here; just avoid try/catch, which is baked into things like @threads.
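Zygote is the outer AD here; the adjoint method used for the solve itself can be chosen separately through the `sensealg` keyword. A minimal sketch, assuming a hypothetical decay model and the current SciMLSensitivity names `InterpolatingAdjoint` and `ZygoteVJP` (package and type names have changed across versions, so check them against the installed release):

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical model: exponential decay with rate p[1].
decay(u, p, t) = -p[1] .* u
prob = ODEProblem(decay, [1.0], (0.0, 1.0), [1.0])

function loss(p)
    # sensealg selects the adjoint used on the reverse pass; ZygoteVJP means
    # Zygote computes the vector-Jacobian products of `decay` inside it.
    sol = solve(prob, Tsit5(); p = p, saveat = 0.1,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return sum(abs2, Array(sol))
end

g = Zygote.gradient(loss, [1.0])[1]

# Central finite difference as a sanity check on the adjoint gradient.
fd = (loss([1.0 + 1e-4]) - loss([1.0 - 1e-4])) / 2e-4
```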

Side note: I took out the for loop so there is only one pass through, and I still get a compile error. Is there any advice on how to debug this type of error?

The Zygote error message says try/catch is not supported. How do I find where that comes from? I don't think it is in my code, and I don't have any threading. Will that long error string point to the culprit? I’ll keep searching.

I do have some table interpolations in the function; could that be a problem?
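One way to answer this is to differentiate the table lookup on its own. A minimal sketch, assuming a hypothetical 1-D table and a hand-written piecewise-linear interpolation (not the interpolation code from the post) that avoids mutation and try/catch; note the derivative jumps at the table nodes, which can slow gradient-based training:

```julia
using Zygote

# Hypothetical 1-D lookup table; xs must be sorted in increasing order.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 10.0, 15.0, 30.0]

# Piecewise-linear interpolation written without mutation or try/catch.
# Outside [xs[1], xs[end]] the first/last segment is extended linearly.
function lininterp(x, xs, ys)
    i = clamp(searchsortedlast(xs, x), 1, length(xs) - 1)
    w = (x - xs[i]) / (xs[i + 1] - xs[i])
    return (1 - w) * ys[i] + w * ys[i + 1]
end

# Check the lookup alone before using it inside the ODE right-hand side.
dydx = Zygote.gradient(x -> lininterp(x, xs, ys), 1.5)[1]
```

If the real interpolation fails this kind of isolated check, it is a likely source of the AD error.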

If I use ReverseDiff, the error is:

ERROR: TypeError: in TrackedReal, in V, expected V<:Real, got Type{Any}

If I use Zygote, the error is:

ERROR: Compiling Tuple{typeof(OrdinaryDiffEq._postamble!),OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Any,1},Float64,Float64,Float64,Array{Array{Float64,1},1},SciMLBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},SciMLBase.ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Any,1},SciMLBase.ODEFunction{true,Main.MLSAGA.SAGA.var"#14#15"{Main.MLSAGA.SAGA.SAGAConst},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#6#7",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64},DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#4#5",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#6#7",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64},DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#4#5",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}}}}},SciMLBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{true,Main.MLSAGA.SAGA.var"#14#15"{Main.MLSAGA.SAGA.SAGAConst},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},SciMLBase.ODEFunction{true,Main.MLSAGA.SAGA.var"#14#15"{Main.MLSAGA.SAGA.SAGAConst},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),Nothing,DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#6#7",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64},DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#4#5",typeof(SciMLBase.terminate!),typeof(SciMLBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),typeof(DiffEqBase.FINALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Tuple{},Tuple{},Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}}: try/catch is not supported.
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] instrument(::IRTools.Inner.IR) at C:\Users\bakerar\.julia\packages\Zygote\KpME9\src\compiler\reverse.jl:121
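Both traces point at the parameter container: the ReverseDiff message complains about `Type{Any}`, and in the Zygote trace the parameter type slot of `SciMLBase.ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Any,1},...}` is `Array{Any,1}`. Whether that is the root cause here is not certain, but it is worth checking. A minimal sketch, with hypothetical values standing in for the sampled inputs and the network output, of how an untyped container becomes `Vector{Any}` while concatenating typed vectors keeps `Float64` (and, unlike `append!`, `vcat` does not mutate, so Zygote can differentiate it):

```julia
# Hypothetical values: sampled simulation inputs and the network's output.
inputs = [1.0, 2.0]
nn_out = [0.1, 0.2]

# Building the parameter vector from an untyped `[]` gives Vector{Any};
# append! is also array mutation, which Zygote cannot differentiate.
p_untyped = []
append!(p_untyped, inputs)
append!(p_untyped, nn_out)

# Non-mutating concatenation of concretely typed vectors keeps Float64
# (under AD it promotes to the dual/tracked number type instead of Any).
p_typed = vcat(inputs, nn_out)
```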

I usually just isolate it: cut the problem down until the failing piece is on its own.
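In practice, isolating means taking the gradient of each stage separately (network output, parameter assembly, right-hand side, loss) before wrapping everything in solve and sciml_train. A minimal sketch of the first check, assuming a hypothetical two-state right-hand side written out-of-place because Zygote does not differentiate array mutation:

```julia
using Zygote

# Hypothetical out-of-place right-hand side standing in for the real model.
rhs(u, p, t) = [-p[1] * u[1], p[2] * u[1] - u[2]]

u = [1.0, 2.0]
p = [0.5, 0.3]

# Does the right-hand side differentiate with respect to p on its own?
g_rhs = Zygote.gradient(q -> sum(rhs(u, q, 0.0)), p)[1]
```

If this check passes, the next step is to add one stage at a time until the error reappears.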