I need to cache some data during the evaluation of an ODE. However the time gradients make it a bit difficult to me and cuse a MethodError
.
MWE
using DiffEqBase, OrdinaryDiffEq
struct rb{T}
B::T
end
rb(n::Integer)=rb(DiffEqBase.dualcache(zeros(n), Val{3}))
function (f!::rb)(du,u,p,t)
tmp = DiffEqBase.get_tmp(f!.B, u)
tmp .= tmp .+ t # Possible dual number
du .= u.*tmp
end
f! = rb(3)
u0=zeros(3)
p=ODEProblem(f!,u0,(0,1.),p=[])
julia> solve(p,Rodas5())
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, SciMLBase.NullParameters}, Float64}, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
(::Type{T})(::T) where T<:Number at boot.jl:760
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, SciMLBase.NullParameters}, Float64}, Float64, 1})
@ Base ./number.jl:7
[2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, SciMLBase.NullParameters}, Float64}, Float64, 1}, i1::Int64)
@ Base ./array.jl:839
[3] macro expansion
@ ./broadcast.jl:984 [inlined]
[4] macro expansion
@ ./simdloop.jl:77 [inlined]
[5] copyto!
@ ./broadcast.jl:983 [inlined]
[6] copyto!
@ ./broadcast.jl:936 [inlined]
[7] materialize!
@ ./broadcast.jl:894 [inlined]
[8] materialize!
@ ./broadcast.jl:891 [inlined]
[9] (::rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}})(du::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, SciMLBase.NullParameters}, Float64}, Float64, 1}}, u::Vector{Float64}, p::SciMLBase.NullParameters, t::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, SciMLBase.NullParameters}, Float64}, Float64, 1})
@ Main ~/.julia/dev/AnnA/Playground.jl:189
[10] ODEFunction
@ ~/.julia/packages/SciMLBase/9EjAY/src/scimlfunctions.jl:334 [inlined]
[11] TimeGradientWrapper
@ ~/.julia/packages/SciMLBase/9EjAY/src/function_wrappers.jl:7 [inlined]
[12] derivative!(df::Vector{Float64}, f::SciMLBase.TimeGradientWrapper{ODEFunction{true,
...
[16] perform_step!
@ ~/.julia/packages/OrdinaryDiffEq/vxMSM/src/perform_step/rosenbrock_perform_step.jl:1053 [inlined]
[17] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{Rodas5{0, true, DefaultLinSolve, DataType}, true, Vector{Float64}, Nothing, Float64, SciMLBase.NullParameters, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true,
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/vxMSM/src/solve.jl:455
[18] __solve(::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Any}, Tuple{Symbol}, NamedTuple{(:p,), Tuple{Vector{Any}}}}, SciMLBase.StandardODEProblem}, ::Rodas5{0, true, DefaultLinSolve, DataType}; kwargs::Base.Iterators.Pairs{Symbol, Vector{Any}, Tuple{Symbol}, NamedTuple{(:p,), Tuple{Vector{Any}}}})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/vxMSM/src/solve.jl:5
[19] #solve_call#56
@ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:61 [inlined]
[20] solve_call
@ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:48 [inlined]
[21] #solve_up#58
@ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:82 [inlined]
[22] solve_up
@ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:75 [inlined]
[23] #solve#57
@ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:70 [inlined]
[24] solve(prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, rb{DiffEqBase.DiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 3}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Any}, Tuple{Symbol}, NamedTuple{(:p,), Tuple{Vector{Any}}}}, SciMLBase.StandardODEProblem}, args::Rodas5{0, true, DefaultLinSolve, DataType})
@ DiffEqBase ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:68
[25] top-level scope
@ REPL[25]:1
I can circumvent the problem by:
function (f!::rb)(du,u,p,t) #works
tmp =@view DiffEqBase.get_tmp(f!.B, du)[1:3]
tmp .= tmp .+ t
du .= u.*tmp
end
but it dos not look like this is how it shold be done. Can someone push me on the right track?