I am trying to take the derivative of a function that is defined as the solution of an ODE. But gradient throws an error:
using Flux: gradient
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity
function odesolution(x)
f(u,p,t) = 0.1u
prob = ODEProblem(f,x,[0.,1.])
sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
return last(sol)
end
x = 1.0
gs = gradient(odesolution, x)
MethodError: no method matching similar(::Float64, ::Int64)
Closest candidates are:
similar(::Test.GenericArray, ::Integer...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Test/src/Test.jl:1831
similar(::ReverseDiff.TrackedArray, ::Union{Integer, AbstractUnitRange}...) at ~/.julia/packages/ReverseDiff/Y5qec/src/tracked.jl:387
similar(::BitArray, ::Int64...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/bitarray.jl:369
...
Stacktrace:
[1] ODEAdjointProblem(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, g::Function, t::Vector{Float64}, dg::Nothing, callback::Nothing)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:66
[2] _adjoint_sensitivities(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, g::DiffEqSensitivity.var"#df#251"{Vector{Any}, Colon}, t::Vector{Float64}, dg::Nothing; abstol::Float64, reltol::Float64, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:252
[3] adjoint_sensitivities(::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, kwargs::Base.Pairs{Symbol, Union{Nothing, Float64}, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:callback, :reltol, :abstol), Tuple{Nothing, Float64, Float64}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/sensitivity_interface.jl:6
[4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}})(Δ::Vector{Any})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/concrete_solve.jl:249
[5] ZBack
@ ~/.julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:205 [inlined]
[6] (::Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}})(dy::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:231
[7] #212
@ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
[8] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}}}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[9] Pullback
@ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:165 [inlined]
[10] (::typeof(∂(#solve#40)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[11] (::Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
[12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[13] Pullback
@ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:159 [inlined]
[14] (::typeof(∂(solve##kw)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[15] Pullback
@ ./In[5]:4 [inlined]
[16] (::typeof(∂(odesolution)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[17] (::Zygote.var"#56#57"{typeof(∂(odesolution))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
[18] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
[19] top-level scope
@ In[6]:2
[20] eval
@ ./boot.jl:373 [inlined]
[21] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1196