# DiffEqFlux: an error when using NN's that take both time and position as input

Hi everyone, I am new to the Julia programming language and ran into an error when programming.

My problem:
I would like to solve an ODE optimal control problem with a neural-network-based approach. I take the ODE and replace the control variable with a neural network that takes both the position and the time as input. The neural network is trained with the `sciml_train` function (from the DiffEqFlux package) on a particular loss function: the sum of squared differences between the position and the all-ones vector. In other words, I want to steer the solution of the ODE to the all-ones vector.
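
In symbols, the setup can be written roughly as follows. This is read off the code in this post; the names $C$ (for `const_mat`), $k$ (trajectory index), and $t_i$ (saved time points) are notation introduced here for illustration, not part of the original code.

$$
\dot{u}(t) = 2\,u(t) + C\,\mathrm{NN}_p\big([u(t);\,t]\big), \qquad u(0) = u_0
$$

$$
L(p) = \sum_{k}\sum_{i} \big\lVert u^{(k)}(t_i) - \mathbf{1} \big\rVert_2^2
$$

Here $u^{(k)}$ is the solution for the $k$-th random initial condition in the ensemble, and $\mathbf{1}$ is the all-ones vector.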

Code:

``````
using Flux, DiffEqFlux, DiffEqSensitivity, DifferentialEquations

x_size = 6 # Size of the spatial dimensions in the ODE
v_size = 2 # Output size of the control

# Define Neural Network for the control input
input_size = x_size + 1 # size of the spatial dimensions plus one time dimension
nn_initial = Chain(Dense(input_size,32, relu), Dense(32,v_size, tanh)) # The actual neural network
p_nn, model = Flux.destructure(nn_initial)
nn(x,p) = model(p)(x)

# Define the right hand side of the ODE
const_mat = ones(Float64, (x_size, v_size))

function f!(du,u,p,t)
    du .= 2.0.*u + const_mat*nn([u;t],p)
end

# Define ODE problem
u0 = vec(rand(Float64, (x_size,1)))
tspan = (0.0, 1.0)
prob = ODEProblem{true}(f!, u0, tspan, p_nn)

# Defining the loss function
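function loss(pars, prob)

    function prob_func(prob, i, repeat)
        # Prepare a new initial state and remake the problem
        u0tmp = vec(rand(Float64,(x_size,1)))

        remake(prob, p = pars, u0 = u0tmp)
    end

    ensembleprob = EnsembleProblem(prob, prob_func = prob_func)

    # NOTE: the start of this call was lost from the post. The solver (Tsit5),
    # the ensemble algorithm, and the sensealg are assumptions; only `dt` and
    # `trajectories` survive from the original line.
    _sol = solve(ensembleprob, Tsit5(), EnsembleThreads(), sensealg = ReverseDiffAdjoint(),
                 dt = 0.01, trajectories = 10)

    A = convert(Array,_sol)

    loss = sum(abs2, A .- 1)

    return loss
end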

# Training the model
result = DiffEqFlux.sciml_train((p) -> loss(p,prob), p_nn, ADAM(0.01), maxiters=5)

``````

Error when running the code:

``````
ERROR: LoadError: MethodError: no method matching increment_deriv!(::Float64, ::Float64)
Closest candidates are:
increment_deriv!(::ReverseDiff.TrackedArray, ::Real, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:34
increment_deriv!(::AbstractArray, ::Real, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:36
increment_deriv!(::AbstractArray, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:38
...
Stacktrace:
[1] increment_deriv! at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:35 [inlined]
[2] increment_deriv!(::Array{Real,1}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:40
[3] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}}, ::Array{Real,1}, ::Array{Float64,2}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\linalg\arithmetic.jl:274
[4] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}},Array{Real,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Tuple{Array{Float64,2},Array{Float64,1}}}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\linalg\arithmetic.jl:265
[5] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}},Array{Real,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Tuple{Array{Float64,2},Array{Float64,1}}}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\tape.jl:93
[6] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\tape.jl:87
[7] reverse_pass! at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\api\tape.jl:36 [inlined]
[10] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
[12] #solve#57 at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\solve.jl:70 [inlined]
[13] (::typeof(∂(#solve#57)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[14] (::Zygote.var"#180#181"{typeof(∂(#solve#57)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194
[16] (::typeof(∂(solve##kw)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[17] #batch_func#452 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:143 [inlined]
[18] (::typeof(∂(#batch_func#452)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
[19] #457 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:195 [inlined]
[20] (::typeof(∂(λ)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[21] (::DiffEqBase.var"#219#227")(::typeof(∂(λ)), ::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\init.jl:259
[22] responsible_map(::Function, ::Array{typeof(∂(λ)),1}, ::Vararg{Any,N} where N) at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:188
[23] (::DiffEqBase.var"#∇responsible_map_internal#226"{Array{typeof(∂(λ)),1}})(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\init.jl:259
[25] #solve_batch#456 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:194 [inlined]
[26] (::typeof(∂(#solve_batch#456)))(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
[27] #solve_batch#459 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:203 [inlined]
[28] (::typeof(∂(#solve_batch#459)))(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
[29] macro expansion at .\timing.jl:233 [inlined]
[30] #__solve#451 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:108 [inlined]
[31] (::typeof(∂(#__solve#451)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
[32] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
[34] #solve#59 at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\solve.jl:96 [inlined]
[35] (::typeof(∂(#solve#59)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[36] (::Zygote.var"#180#181"{typeof(∂(#solve#59)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing,Nothing}}})(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194
[38] (::typeof(∂(solve##kw)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[39] loss at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:36 [inlined]
[40] (::typeof(∂(loss)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[41] #24 at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:47 [inlined]
[42] (::typeof(∂(#24)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[43] #69 at C:\Users\Sven\.julia\packages\DiffEqFlux\alPQ3\src\train.jl:3 [inlined]
[44] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
[46] OptimizationFunction at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\problems\basic_problems.jl:107 [inlined]
... (the last 3 lines are repeated 1 more time)
[50] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
[52] #8 at C:\Users\Sven\.julia\packages\GalacticOptim\JnLwV\src\solve.jl:94 [inlined]
[53] (::typeof(∂(λ)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
[54] (::Zygote.var"#69#70"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:252
[56] __solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float32,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\Sven\.julia\packages\GalacticOptim\JnLwV\src\solve.jl:93
[57] #solve#468 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\solve.jl:3 [inlined]
[58] sciml_train(::var"#24#25", ::Array{Float32,1}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}) at C:\Users\Sven\.julia\packages\DiffEqFlux\alPQ3\src\train.jl:6
in expression starting at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:47
``````

Cause of the error:
I believe the cause of the error is inside `function f!(du,u,p,t)`. More specifically, I believe the `[u;t]` in `nn([u;t],p)` causes the error. When I change the variable `input_size` to `x_size` and replace `nn([u;t],p)` with `nn(u,p)`, everything works fine.

Question:
What adjustments can I make to the code to make it work?

Which version of ReverseDiff do you have? v1.8? Please show the output of `]st -m`. This looks like the error from the latest ReverseDiff.jl, which we specifically block in DiffEqSensitivity.jl because it causes this issue.

Thank you for your response. It is indeed version 1.8. I just checked whether the code works with BacksolveAdjoint(): it does. I will look at TrackerAdjoint() and see if that works.

What exactly do you mean by ‘blocking’ ReverseDiff.jl? When I introduce:

``````
# Define an identity matrix
mult_mat = zeros((input_size, input_size))
for i = 1:input_size
    mult_mat[i,i] = 1
end
``````

and change my `f!` to

``````
function f!(du,u,p,t)
    du .= 2.0.*u + const_mat*nn(mult_mat*[u;t],p)
end
``````

the code does run with `ReverseDiffAdjoint()` and produces output. So ReverseDiff.jl is not blocked entirely, right?
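
As a side note on the workaround above, the loop that fills `mult_mat` builds a plain identity matrix. A minimal sketch of an equivalent construction with the `LinearAlgebra` standard library follows; the name `mult_mat_alt` and the hard-coded `input_size = 7` (that is, `x_size + 1` from the post) are chosen here for illustration only. Whether this variant behaves the same as the loop-built matrix under `ReverseDiffAdjoint()` was not tested in this thread.

```julia
using LinearAlgebra

input_size = 7  # x_size + 1, as in the original post (assumed here so the snippet runs alone)

# Loop-built identity matrix, as in the post
mult_mat = zeros((input_size, input_size))
for i = 1:input_size
    mult_mat[i, i] = 1
end

# Dense Float64 identity matrix built from the UniformScaling object `I`
mult_mat_alt = Matrix{Float64}(I, input_size, input_size)

# Both constructions give the same matrix
println(mult_mat_alt == mult_mat)
```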

No, we just block v1.8.

So, if I understand correctly, it should work if I switch to ReverseDiff v1.7?

Yes
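
For readers who land here with the same error, one way to check the installed version and move to the 1.7 series is sketched below using the standard `Pkg` API. It assumes the active environment is the one the script runs in, and that a 1.7.x release is compatible with the other installed packages; the resolver may report conflicts otherwise.

```julia
using Pkg

# Show every package in the manifest, equivalent to `]st -m` in the REPL
Pkg.status(mode = Pkg.PKGMODE_MANIFEST)

# Install a ReverseDiff release from the 1.7 series
Pkg.add(name = "ReverseDiff", version = "1.7")

# Pin it so that a later `Pkg.update()` does not move it back to 1.8
Pkg.pin("ReverseDiff")
```

Once the issue is resolved upstream, `Pkg.free("ReverseDiff")` removes the pin again.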

Great! Thanks for taking the time to help me solve this problem. I really appreciate it.