Neural Networks combined with Diffeq

When training a neural network in place with a diffeq solver, does the neural network have to be inside the diffeq solver call, which is then wrapped by sciml_train? Or can I define the neural network so that it generates parameters that are used in the diffeq solver, so that it exists outside the diffeq solver function call?

I’m just hitting roadblock after roadblock on every sensitivity I fail to debug. Zygote experiences try/catch failures even though I don’t think I have any, so I think they are inside the diffeq solver somewhere. ReverseDiff gets a trackedValue error on a variable that is part of the integrator composite type. So I’m thinking maybe this order of neural networks outside the diffeq solver is not supported.

Here is the first bit of the latest ReverseDiff error.

ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Float32 is not defined. Please use `ReverseDiff.value` instead.
Stacktrace:
  [1] convert(#unused#::Type{Float32}, t::ReverseDiff.TrackedReal{Float64, Float64, Nothing})
    @ ReverseDiff ~\.julia\packages\ReverseDiff\iHmB4\src\tracked.jl:261
  [2] fastpow
    @ ~\.julia\packages\DiffEqBase\rN9Px\src\fastpow.jl:92 [inlined]
  [3] PI_stepsize_controller!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\controllers.jl:44 [inlined]
  [4] stepsize_controller!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\controllers.jl:58 [inlined]
  [5] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5, true, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Nothing, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, SciMLBase.ODESolution{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, 2, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, Nothing, Nothing, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, SciMLBase.ODEProblem{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Tuple{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, true, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, OrdinaryDiffEq.Tsit5Cache{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, 
Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, OrdinaryDiffEq.Tsit5ConstantCache{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, DiffEqBase.DEStats}, SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, OrdinaryDiffEq.Tsit5ConstantCache{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, OrdinaryDiffEq.DEOptions{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Float64, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, 
DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryMinHeap{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, DataStructures.BinaryMinHeap{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Nothing, OrdinaryDiffEq.DefaultInit})
    @ OrdinaryDiffEq ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\integrator_utils.jl:188
  [6] loopfooter!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\integrator_utils.jl:168 [inlined]

It seems that the Diffeq step size controller uses the fastpow routine, which is not a generic method that can be used with tracked values. The routine tries to convert the value to Float32. I think there is a missing fallback method in fastpow in Diffeq for tracked values.
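To make the "missing fallback" idea concrete, here is a minimal sketch using only Base Julia. Everything in it is hypothetical: `Wrapped` stands in for a tracked number and `quick_pow` stands in for a fast power routine. It is not the DiffEqBase implementation, only an illustration of how a type-specific method can bypass a Float32 shortcut.

```julia
# Hypothetical stand-in for an AD-tracked number. Like the tracked type in
# the error above, it has no conversion to Float32.
struct Wrapped
    value::Float64
end

# Hypothetical fast path: only plain floats take the Float32 shortcut.
quick_pow(x::AbstractFloat, y::AbstractFloat) = Float64(Float32(x)^Float32(y))

# Fallback for the wrapped type: skip the shortcut, use the generic power,
# and re-wrap the result so it keeps its wrapper.
quick_pow(x::Wrapped, y::AbstractFloat) = Wrapped(x.value^y)

# Converting the wrapper directly is the step that fails without a fallback.
conversion_fails = try
    convert(Float32, Wrapped(2.0))
    false
catch err
    err isa MethodError
end

plain = quick_pow(2.0, 3.0)
wrapped = quick_pow(Wrapped(2.0), 3.0)
```

With the second method in place, dispatch sends wrapped values around the Float32 conversion while plain floats keep the fast path.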

Can you give me an MWE? This is an easy thing to fix, but is something introduced by a recent ReverseDiff patch. I just need to figure out where ReverseDiff wants me to make the change.

This is an issue introduced by Flux v0.12 (@CarloLucibello @dhairyagandhi96).

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
  size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
  if eltype(bias) == eltype(weights)
    return bias
  else
    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
    return broadcast(eltype(weights), bias)
  end
end

is not correct because not all number types can convert like that.

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
  size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
  if eltype(bias) == eltype(weights)
    return bias
  else
    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
    return convert.((eltype(weights),),bias)
  end
end

is needed, but that throws a warning every time you use restructure. I think we’ll need to pirate Flux for now until we find a way to eliminate the warning from legitimate use cases.
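The difference between the two versions can be sketched with Base Julia alone. `Tagged` below is a hypothetical type, not something from Flux or ReverseDiff: it defines `convert` from `Float64` but has no one-argument constructor, so broadcasting the type as a function fails while broadcasting `convert` works.

```julia
# Hypothetical number-like type with a two-field constructor only.
struct Tagged
    value::Float64
    origin::Symbol
end

# Conversion from Float64 is defined, but Tagged(x) with one argument is not.
Base.convert(::Type{Tagged}, x::Float64) = Tagged(x, :converted)

bias = [0.5, 1.5]

# broadcast(T, bias) calls T(x) on each element, which needs a constructor.
constructor_fails = try
    broadcast(Tagged, bias)
    false
catch err
    err isa MethodError
end

# convert.((T,), bias) calls convert(T, x) on each element instead.
converted = convert.((Tagged,), bias)
```

Wrapping the type in a one-element tuple, as in the corrected function above, makes broadcast treat it as a constant rather than something to iterate over.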

It seems that this knocked down our tests, and

should fix it up. And an upstream issue:

Is that from something in my model? I do have a custom type. Or is this a bug in their code?

Should I back up my Flux or DiffEq version to fix this or will it be fixed soon enough?

I’ll merge this fix in the next 2 hours.

You are most excellent. Thank you. I don’t know if it will solve all of my problems, but one more brick on the road to victory!

Tagged.

I’m assuming that putting the neural network outside the diffeq solver should work, right? Or do I need to move it inside the diffeq solver?

Both should be fine
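For readers following along, the two arrangements can be sketched structurally in Base Julia. This is only an illustration under loud assumptions: `net` is a hand-written one-layer map rather than a Flux model, `euler` is a fixed-step loop standing in for the diffeq solver, and all names are hypothetical. No DiffEqFlux call is shown.

```julia
# Tiny one-layer "network": the weights θ map an input vector to an output.
net(θ, x) = tanh.(reshape(θ, 2, 2) * x)

# Stand-in for an ODE solver: fixed-step explicit Euler on du/dt = f(u, p, t).
function euler(f, u0, p, tspan; steps = 100)
    t0, t1 = tspan
    dt = (t1 - t0) / steps
    u, t = copy(u0), t0
    for _ in 1:steps
        u = u .+ dt .* f(u, p, t)
        t += dt
    end
    return u
end

# Arrangement 1: the network runs once, outside the solver, and its output
# becomes the ODE parameters.
decay(u, p, t) = -p .* u
function loss_outside(θ)
    p = net(θ, [1.0, 0.5])
    u_end = euler(decay, [1.0, 1.0], p, (0.0, 1.0))
    return sum(abs2, u_end)
end

# Arrangement 2: the network is evaluated inside the right-hand side, so the
# solver receives the weights themselves as parameters.
nn_rhs(u, θ, t) = -net(θ, u) .* u
function loss_inside(θ)
    u_end = euler(nn_rhs, [1.0, 1.0], θ, (0.0, 1.0))
    return sum(abs2, u_end)
end

θ = [0.3, -0.2, 0.1, 0.4]
outside_before, outside_after = loss_outside(θ), loss_outside(θ .+ 0.1)
inside_before, inside_after = loss_inside(θ), loss_inside(θ .+ 0.1)
```

In both arrangements the loss is a function of the same weights θ, which is the property the training loop relies on.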

I tried updating, which pushed things to Flux 0.12.1, and I got the same error as before. I tried to find the tag to see if it was just because it wasn’t in master yet, and I failed there too, but I may have taken the incorrect tag (dg/den). Do I have the wrong version, or did this not fix my problem?

Flux 0.12.1 should be fine, unless you have a different issue in which case you need to share some code.

Sharing code is challenging. I need to come up with a toy problem that does the same thing.

I do remember someone suggesting in a previous discourse thread converting the p and u0 inputs in the remake command so that it handles the tracked variables better. I can’t find that comment anymore so I don’t remember the syntax.

Could that be causing a problem?

It may be that I have a composite type that holds the setup for some parameters. I think I may be passing that into the solver in a way that doesn’t work for the tracked maybe. I convert the struct to a vector and then back to a struct once inside the ODE_Prob code. I just noticed that I’m converting a “Bool” parameter to a “Real” in there. That may not be compatible with the Tracked conversions.

Inside my prob_callback the parameters get put back into the struct. I think I remember seeing that there is a way to pass that structure in without doing that.
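One way to keep a struct-to-vector round trip friendly to tracked values is to make the struct parametric and to keep Bool flags out of the numeric vector. The sketch below is an assumption about what such a setup could look like, not the actual SAGAConst code: `SimSettings`, `to_vector`, and `from_vector` are hypothetical names, and `BigFloat` merely stands in for any non-Float64 number type.

```julia
# Hypothetical settings struct: numeric fields share a type parameter so they
# can hold whatever number type the vector carries.
struct SimSettings{T}
    gain::T
    drag::T
    use_drag::Bool   # stays a Bool and never enters the parameter vector
end

# Flatten only the numeric fields.
to_vector(s::SimSettings) = [s.gain, s.drag]

# Rebuild with the element type of the incoming vector, without forcing Float64.
from_vector(v::AbstractVector{T}, use_drag::Bool) where {T} =
    SimSettings{T}(v[1], v[2], use_drag)

original = SimSettings(2.0, 0.1, true)
roundtrip = from_vector(to_vector(original), original.use_drag)

# BigFloat stands in for a non-Float64 element type such as a tracked value.
promoted = from_vector(BigFloat[2, 0.1], true)
```

Because nothing in the round trip names a concrete float type, the rebuilt struct carries whatever element type the vector arrived with.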

# Not the real code, but basically does this
prob_noSave = ODEProblem((du, u, p, t)->example!(du, u, p, t),u0,tspan)


# The sc variable has parameters, some of whose values are calculated by a neural network before
# this call. That neural network's weights are the parameters being trained by the
# DiffEqFlux.sciml_train command that is called upstream of this function call
function runSim_noSave(sc::SAGAConst)
    
    u0 = setupSim(sc)
    p = toParameter(sc)

    # Setup ODE Problem
    _prob = remake(prob_noSave; p, u0, callback = cb_noSave)
    sol = solve(_prob,Tsit5(),reltol=1e-9,sensealg=sensitivity)
    #sol = solve(_prob,Tsit5(),reltol=1e-9)
    

    return(;sol,sc)
end

# This code is not my actual code, but basically what it does
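function example!(du, u, p, t)
    a = u[1:3]
    b = u[4:6]
    # ... other intermediate calculations omitted ...

    # Rebuild the settings struct from the flat parameter vector
    sc = fromParameter(p)

    out = examplecalcs(sc, t, a, b)

    du[1:3] = out.var1
    du[4:6] = out.a
end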



_prob = remake(prob_noSave; p, u0, callback = cb_noSave) isn’t correct. _prob = remake(prob_noSave; p=_p, u0=_u0, callback = cb_noSave) etc.

It did not change the results for ReverseDiffAdjoint or ZygoteAdjoint. ReverseDiff gives the same error as listed above and Zygote gives the try-catch error.

Okay. Without a runnable MWE I cannot do much more. But the actual code here doesn’t necessarily matter (nor do I want it :laughing: ). Try to delete as much as possible in a way that still produces the same error. If you do that, you should end up with something that looks nothing like the original model but contains the original error, and that’s runnable code you can share, I can fix, and then I can add to the repo as a test.

One thing to try is just ]add Flux@0.11 and see if it was recently introduced (if it was related to what I just fixed, or not).

Ok, thanks, I appreciate that.

The only sensitivity that seems to run is TrackerAdjoint, but it doesn’t actually train. The error always stays the same. Is it typical for TrackerAdjoint to have problems like that? I’m hoping it points me to my larger problem.

I wish I knew how to understand the debug output better. I don’t feel I get much useful information out of it to know where to zero in on in my own code.

And the reason TrackerAdjoint doesn’t do anything is that if I take the gradient of the loss function with respect to the neural network parameters, it is zero. Boo.
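A zero gradient can mean either that the loss really does not depend on the parameters or that the AD tool lost track of the dependence. A finite-difference check, sketched below in Base Julia, may help tell the two apart. The helper `fd_gradient` and both toy losses are hypothetical and are not part of any SciML package.

```julia
# Central finite-difference gradient of a scalar loss.
function fd_gradient(loss, θ; h = 1e-6)
    g = zeros(length(θ))
    for i in eachindex(θ)
        bump = zeros(length(θ))
        bump[i] = h
        g[i] = (loss(θ .+ bump) - loss(θ .- bump)) / (2h)
    end
    return g
end

# A loss that depends smoothly on θ: finite differences are nonzero.
smooth_loss(θ) = sum(abs2, θ .- 1.0)

# A loss where θ only passes through a comparison: it is piecewise constant,
# so the gradient is zero almost everywhere.
flat_loss(θ) = float(count(>(0), θ))

θ = [0.3, -0.2]
g_smooth = fd_gradient(smooth_loss, θ)
g_flat = fd_gradient(flat_loss, θ)
```

If finite differences on the real loss are also zero, the model itself is insensitive to the weights (for example because they pass through a Bool or a comparison). If they are nonzero while the adjoint returns zero, the dependence is being dropped somewhere between the network and the solver.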