DiffEqFlux: an error when using NNs that take both time and position as input

Hi everyone, I am new to the Julia programming language and ran into the following error.

My problem:
I would like to solve an ODE optimal control problem with a neural-network-based approach. I take the ODE and replace the control variable with a neural network that takes both the position and the time as input. The network is trained with the sciml_train function (from the DiffEqFlux package) on a loss function that sums the squared differences between the solution at the saved time points and the all-ones vector, over 10 trajectories with random initial states. In other words, I want to steer the solution of the ODE towards the all-ones vector.

Code:

using Flux, DiffEqFlux, DiffEqSensitivity, DifferentialEquations

x_size = 6 # Size of the spatial dimensions in the ODE
v_size = 2 # Output size of the control 

# Define Neural Network for the control input
input_size = x_size + 1 # size of the spatial dimensions PLUS one time dimension
nn_initial = Chain(Dense(input_size,32, relu), Dense(32,v_size, tanh)) # The actual neural network
p_nn, model = Flux.destructure(nn_initial)
nn(x,p) = model(p)(x) 

# Define the right hand side of the ODE
const_mat = ones(Float64, (x_size, v_size)) 

function f!(du,u,p,t)
    du .= 2.0.*u + const_mat*nn([u;t],p)
end

# Define ODE problem
u0 = rand(Float64, x_size) # random initial state
tspan = (0.0, 1.0) 
prob = ODEProblem{true}(f!, u0, tspan, p_nn)

# Defining the loss function
function loss(pars, prob)

    function prob_func(prob, i, repeat)
        # Prepare a new random initial state and remake the problem
        u0tmp = rand(Float64, x_size)

        remake(prob, p = pars, u0 = u0tmp)
    end

    ensembleprob = EnsembleProblem(prob, prob_func = prob_func)

    _sol = solve(ensembleprob, Tsit5(), EnsembleThreads(); sensealg = ReverseDiffAdjoint(),
                 saveat = 0:0.1:1, dt = 0.01, trajectories = 10)

    A = convert(Array,_sol)

    return sum(abs2, A .- 1) # sum of squared deviations from the all-ones vector
end

# Training the model
result = DiffEqFlux.sciml_train((p) -> loss(p,prob), p_nn, ADAM(0.01), maxiters=5)

Error when running the code:

ERROR: LoadError: MethodError: no method matching increment_deriv!(::Float64, ::Float64)
Closest candidates are:
  increment_deriv!(::ReverseDiff.TrackedArray, ::Real, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:34
  increment_deriv!(::AbstractArray, ::Real, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:36
  increment_deriv!(::AbstractArray, ::Any) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:38
  ...
Stacktrace:
 [1] increment_deriv! at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:35 [inlined]
 [2] increment_deriv!(::Array{Real,1}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\propagation.jl:40
 [3] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}}, ::Array{Real,1}, ::Array{Float64,2}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\linalg\arithmetic.jl:274
 [4] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}},Array{Real,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Tuple{Array{Float64,2},Array{Float64,1}}}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\derivatives\linalg\arithmetic.jl:265
 [5] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float64,2,Array{Float32,2},Array{Float64,2}},Array{Real,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Tuple{Array{Float64,2},Array{Float64,1}}}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\tape.jl:93
 [6] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\tape.jl:87
 [7] reverse_pass! at C:\Users\Sven\.julia\packages\ReverseDiff\60noS\src\api\tape.jl:36 [inlined]
 [8] reversediff_adjoint_backpass at C:\Users\Sven\.julia\packages\DiffEqSensitivity\LDOtY\src\concrete_solve.jl:413 [inlined]
 [9] #263#back at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65 [inlined]
 [10] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
 [11] (::Zygote.var"#1701#back#182"{Zygote.var"#180#181"{DiffEqBase.var"#263#back#74"{DiffEqSensitivity.var"#reversediff_adjoint_backpass#205"{Tuple{},ReverseDiff.GradientTape{DiffEqSensitivity.var"#reversediff_adjoint_forwardpass#202"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :dt),Tuple{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Float64}}},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float32,1},ODEFunction{true,typeof(f!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,ReverseDiffAdjoint,Tuple{}},Tuple{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},ReverseDiff.TrackedArray{Float32,Float64,1,Array{Float32,1},Array{Float64,1}}},Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},ReverseDiff.TrackedArray{Float32,Float64,1,Array{Float32,1},Array{Float64,1}},Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [12] #solve#57 at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\solve.jl:70 [inlined]
 [13] (::typeof(∂(#solve#57)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [14] (::Zygote.var"#180#181"{typeof(∂(#solve#57)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194
 [15] (::Zygote.var"#1701#back#182"{Zygote.var"#180#181"{typeof(∂(#solve#57)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [16] (::typeof(∂(solve##kw)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [17] #batch_func#452 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:143 [inlined]
 [18] (::typeof(∂(#batch_func#452)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
 [19] #457 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:195 [inlined]
 [20] (::typeof(∂(λ)))(::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [21] (::DiffEqBase.var"#219#227")(::typeof(∂(λ)), ::Array{Array{Float64,1},1}) at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\init.jl:259
 [22] responsible_map(::Function, ::Array{typeof(∂(λ)),1}, ::Vararg{Any,N} where N) at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:188
 [23] (::DiffEqBase.var"#∇responsible_map_internal#226"{Array{typeof(∂(λ)),1}})(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\init.jl:259
 [24] #361#back at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [25] #solve_batch#456 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:194 [inlined]
 [26] (::typeof(∂(#solve_batch#456)))(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
 [27] #solve_batch#459 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:203 [inlined]
 [28] (::typeof(∂(#solve_batch#459)))(::EnsembleSolution{Array{Float64,1},2,Array{Array{Array{Float64,1},1},1}}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
 [29] macro expansion at .\timing.jl:233 [inlined]
 [30] #__solve#451 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\ensemble\basic_ensemble_solve.jl:108 [inlined]
 [31] (::typeof(∂(#__solve#451)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0 (repeats 2 times)
 [32] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
 [33] (::Zygote.var"#1701#back#182"{Zygote.var"#180#181"{typeof(∂(__solve##kw)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing,Nothing}}}})(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [34] #solve#59 at C:\Users\Sven\.julia\packages\DiffEqBase\U3Zj7\src\solve.jl:96 [inlined]
 [35] (::typeof(∂(#solve#59)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [36] (::Zygote.var"#180#181"{typeof(∂(#solve#59)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing,Nothing}}})(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194
 [37] #1701#back at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [38] (::typeof(∂(solve##kw)))(::Array{Float64,3}) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [39] loss at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:36 [inlined]
 [40] (::typeof(∂(loss)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [41] #24 at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:47 [inlined]
 [42] (::typeof(∂(#24)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [43] #69 at C:\Users\Sven\.julia\packages\DiffEqFlux\alPQ3\src\train.jl:3 [inlined]
 [44] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
 [45] #1701#back at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [46] OptimizationFunction at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\problems\basic_problems.jl:107 [inlined]
 ... (the last 3 lines are repeated 1 more time)
 [50] #180 at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:194 [inlined]
 [51] (::Zygote.var"#1701#back#182"{Zygote.var"#180#181"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing},Int64}}})(::Float64) at C:\Users\Sven\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [52] #8 at C:\Users\Sven\.julia\packages\GalacticOptim\JnLwV\src\solve.jl:94 [inlined]
 [53] (::typeof(∂(λ)))(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [54] (::Zygote.var"#69#70"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:252
 [55] gradient(::Function, ::Zygote.Params) at C:\Users\Sven\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:59
 [56] __solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#24#25"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float32,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\Sven\.julia\packages\GalacticOptim\JnLwV\src\solve.jl:93
 [57] #solve#468 at C:\Users\Sven\.julia\packages\SciMLBase\XuLdB\src\solve.jl:3 [inlined]
 [58] sciml_train(::var"#24#25", ::Array{Float32,1}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}) at C:\Users\Sven\.julia\packages\DiffEqFlux\alPQ3\src\train.jl:6
in expression starting at c:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse.jl:47

Cause of the error:
I believe the cause of the error is inside the function f!(du,u,p,t). More specifically, I believe the concatenation [u;t] in nn([u;t],p) causes it: when I set input_size to x_size and replace nn([u;t],p) with nn(u,p), everything works fine.
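The stack trace is consistent with this: frame [4] shows the multiplication * receiving a tracked Float32 matrix and an Array{Real,1}. Presumably that is the first Dense layer's weight matrix applied to [u;t], which would mean the concatenation ended up with the abstract element type Real instead of a concrete one. The plain-Julia sketch below illustrates the general mechanism with a hypothetical stand-in type, FakeTracked; it is not ReverseDiff's actual tracked type, and whether ReverseDiff goes through exactly this path is an assumption. When vcat finds no promotion rule between two element types, it falls back to their common supertype.

```julia
# Hypothetical stand-in for a tracked number type; NOT ReverseDiff's real type.
# It subtypes Real but (initially) defines no promotion rule with Float64.
struct FakeTracked <: Real
    value::Float64
end

u = [FakeTracked(0.1), FakeTracked(0.2)]  # plays the role of the state u
t = 0.5                                    # plain Float64 time

x = [u; t]          # same concatenation pattern as nn([u;t], p)
println(eltype(x))  # no promotion rule, so vcat uses the common supertype: Real

# With a promotion rule, vcat can choose a concrete element type instead
Base.promote_rule(::Type{FakeTracked}, ::Type{Float64}) = FakeTracked
y = [u; t]          # 0.5 is converted via FakeTracked(0.5)
println(eltype(y))  # FakeTracked
```

This only illustrates how an Array{Real,1} can arise from [u;t]; it is not a fix for the ReverseDiff v1.8 problem discussed below.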

Question:
What adjustments can I make to the code to make it work?

What version of ReverseDiff do you have? v1.8? Show the output of ]st -m. This looks like the error from the latest ReverseDiff.jl release, which we specifically block in DiffEqSensitivity.jl because it causes this issue.
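For scripts, the same check can be done with the Pkg API instead of the Pkg REPL. A sketch: Pkg.status with mode = Pkg.PKGMODE_MANIFEST prints the same listing as ]st -m, including indirect dependencies such as ReverseDiff, and Pkg.dependencies() (available since Julia 1.4 and documented as experimental) lets code read the installed version. The helper name reversediff_version is hypothetical.

```julia
using Pkg

# Same listing as `]st -m`: the manifest also shows indirect dependencies,
# which is where ReverseDiff appears when it was not added directly.
Pkg.status(mode = Pkg.PKGMODE_MANIFEST)

# Hypothetical helper: return the ReverseDiff version in the active
# environment, or `nothing` if ReverseDiff is not in the manifest.
function reversediff_version()
    for info in values(Pkg.dependencies())  # maps package UUIDs to package info
        info.name == "ReverseDiff" && return info.version
    end
    return nothing
end

v = reversediff_version()
if v === nothing
    println("ReverseDiff is not installed in the active environment")
elseif v.major == 1 && v.minor == 8
    @warn "ReverseDiff v$v is installed, the release this thread identifies as breaking ReverseDiffAdjoint()"
else
    println("ReverseDiff v", v)
end
```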

Thank you for your response. It is indeed version 1.8. I just checked whether the code works with BacksolveAdjoint(): it does. I will look at TrackerAdjoint() and see if that works.

What exactly do you mean by ‘blocking’ ReverseDiff.jl? When I add:

# Define an identity matrix
mult_mat = zeros((input_size, input_size))
for i = 1:input_size
    mult_mat[i,i] = 1
end

and change my f! to

function f!(du,u,p,t)
    du .= 2.0.*u + const_mat*nn(mult_mat*[u;t],p)
end

the code does run with ReverseDiffAdjoint() and produces output. So ReverseDiff.jl is not blocked entirely, right?
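As an aside, the identity matrix in this workaround can be built without a loop using the I object from the LinearAlgebra standard library. The sketch assumes input_size = 7 (x_size + 1 with x_size = 6, as in the original code) and checks that both constructions agree. Multiplying by the identity leaves [u;t] numerically unchanged, so the workaround presumably changes only how the operation is recorded by ReverseDiff; that is an assumption, not something established in this thread.

```julia
using LinearAlgebra

input_size = 7  # x_size + 1, assuming x_size = 6 as in the original code

# Loop version from the post
mult_mat = zeros((input_size, input_size))
for i = 1:input_size
    mult_mat[i,i] = 1
end

# Loop-free equivalent: a dense Float64 identity matrix
mult_mat_I = Matrix{Float64}(I, input_size, input_size)

println(mult_mat == mult_mat_I)  # element-wise comparison of the two matrices
```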

No, we only block v1.8.

So, if I understand correctly, it should work if I switch to ReverseDiff v1.7?

Yes
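A sketch of the downgrade with the Pkg API, assuming the script runs in its own project environment. Adding ReverseDiff explicitly makes it a direct dependency of that project; the version string "1.7" selects the newest 1.7.x release, and Pkg.pin keeps a later Pkg.update() from moving it back to v1.8. The Pkg REPL equivalents are add ReverseDiff@1.7 and pin ReverseDiff.

```julia
using Pkg

# Install the newest 1.7.x release of ReverseDiff into the active environment.
# "1.7" matches any 1.7.x release, so 1.8 is excluded.
Pkg.add(name = "ReverseDiff", version = "1.7")

# Pin it so a later Pkg.update() does not move it to 1.8 again
Pkg.pin("ReverseDiff")

# Confirm what the resolver picked
for info in values(Pkg.dependencies())
    if info.name == "ReverseDiff"
        println("ReverseDiff v", info.version, info.is_pinned ? " (pinned)" : "")
    end
end
```

Pkg.free("ReverseDiff") removes the pin again once it is no longer needed.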

Great! Thanks for taking the time to help me solve this problem. I really appreciate it.