Using Flux's `gradient` on a DifferentialEquations `solve` call results in an error

I am trying to take the derivative of a function that is defined as the solution of an ODE. But gradient throws an error:

using Flux: gradient
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity

# Return the final state of the scalar ODE du/dt = 0.1u on t ∈ [0, 1], started from
# the initial value `x` (exact solution: x * exp(0.1)).
# NOTE: `x` is a plain scalar, so the ODE state is a scalar; reverse-mode `gradient`
# of this function fails with the MethodError shown below.
function odesolution(x)
    f(u,p,t) = 0.1u  # right-hand side; ignores the parameters `p` and the time `t`
    prob = ODEProblem(f,x,[0.,1.])  # initial value `x`, time span from 0.0 to 1.0
    sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)  # tight tolerances
    return last(sol)  # state at the last saved time point
end

x = 1.0
# Reverse-mode gradient through `solve`; this call raises the MethodError shown below.
gs = gradient(odesolution, x)
MethodError: no method matching similar(::Float64, ::Int64)
Closest candidates are:
  similar(::Test.GenericArray, ::Integer...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Test/src/Test.jl:1831
  similar(::ReverseDiff.TrackedArray, ::Union{Integer, AbstractUnitRange}...) at ~/.julia/packages/ReverseDiff/Y5qec/src/tracked.jl:387
  similar(::BitArray, ::Int64...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/bitarray.jl:369
  ...

Stacktrace:
  [1] ODEAdjointProblem(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, g::Function, t::Vector{Float64}, dg::Nothing, callback::Nothing)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:66
  [2] _adjoint_sensitivities(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, g::DiffEqSensitivity.var"#df#251"{Vector{Any}, Colon}, t::Vector{Float64}, dg::Nothing; abstol::Float64, reltol::Float64, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:252
  [3] adjoint_sensitivities(::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#2", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, kwargs::Base.Pairs{Symbol, Union{Nothing, Float64}, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:callback, :reltol, :abstol), Tuple{Nothing, Float64, Float64}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/sensitivity_interface.jl:6
  [4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}})(Δ::Vector{Any})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/concrete_solve.jl:249
  [5] ZBack
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:205 [inlined]
  [6] (::Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}})(dy::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:231
  [7] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
  [8] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}}}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [9] Pullback
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:165 [inlined]
 [10] (::typeof(∂(#solve#40)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
 [12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [13] Pullback
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:159 [inlined]
 [14] (::typeof(∂(solve##kw)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [15] Pullback
    @ ./In[5]:4 [inlined]
 [16] (::typeof(∂(odesolution)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#56#57"{typeof(∂(odesolution))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [18] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [19] top-level scope
    @ In[6]:2
 [20] eval
    @ ./boot.jl:373 [inlined]
 [21] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

This looks like another use case for `Zygote.forwarddiff`:

# using Flux: gradient
using Zygote
using ForwardDiff
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity

# Same scalar ODE as in the question: return the final state of du/dt = 0.1u on
# t ∈ [0, 1] started from `x`. Below it is differentiated with forward-mode AD.
function odesolution(x)
    f(u,p,t) = 0.1u  # right-hand side; ignores the parameters `p` and the time `t`
    prob = ODEProblem(f,x,[0.,1.])  # initial value `x`, time span from 0.0 to 1.0
    sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)  # tight tolerances
    return last(sol)  # state at the last saved time point
end

x = 1.0
# Reverse mode (`gradient(odesolution, x)`) fails for this scalar ODE; use forward mode.
# Forward-mode derivative via ForwardDiff; `odesolution` is passed directly, no closure.
@show ForwardDiff.derivative(odesolution, x)
# The same derivative from within Zygote: `Zygote.forwarddiff` tells Zygote to
# differentiate `odesolution` with forward mode instead of reverse mode.
@show Zygote.gradient(x -> Zygote.forwarddiff(odesolution, x), x)

This seems to work.

Edit: a question for others: I don’t know enough about Flux/DifferentialEquations compatibility to judge whether this is a bug. At least, I didn’t find a related issue on GitHub.

Edit: corrected the usage of `forwarddiff`.

1 Like
using Flux: gradient
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity

# Same ODE du/dt = 0.1u on t ∈ [0, 1], but the state is the one-element vector `[x]`
# instead of a scalar, which the adjoint (reverse-mode) sensitivity code can handle.
# Returns the scalar final state, so `gradient(odesolution, x)` yields d/dx.
function odesolution(x)
    f(u,p,t) = 0.1u  # scalar times vector, so this also works for the vector state
    prob = ODEProblem(f,[x],[0.,1.])  # wrap the scalar initial value in a vector
    sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
    return last(sol)[1]  # unwrap the single component of the final state
end

x = 1.0
# Reverse-mode gradient now succeeds because the ODE state is a vector.
gs = gradient(odesolution, x)

The issue is that adjoints do not work with scalar equations (see Adjoint Sensitivity Out of Place Support · Issue #113 · SciML/DiffEqSensitivity.jl · GitHub); wrapping the initial value in a one-element vector, as above, avoids the error. That said, adjoints are an absolutely terrible idea for scalar ODEs (you need on the order of 100 ODEs for them to pay off; see [1812.01892] A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions), so the right thing to do is to use forward mode. We should make this case internally detect a scalar and switch to forward mode, which would both remove the error and make it optimal.

3 Likes

Thanks a lot for your great answer!

I see. After seeing @goerch's answer I was worried that forwarddiff would slow down my code, but based on your answer it is probably much better than gradient here. Thanks for the great answer!

Unfortunately, I am having a little trouble implementing forward-mode differentiation. Your solution worked in this simple example where x is a scalar, but in general f(u,p,t) will be defined by a neural network. I also found that Zygote.forwarddiff does not return a gradient; it always returns a scalar.

I found another package named ForwardDiff which has a gradient function, but I can’t get that to work.

# Parameters of a single affine layer y = W*x + b, each stored inside a one-element
# vector so that further layers could be appended.
W = [randn(1,1)] #I am using this structure because in the general case I could have multiple layers. 
b = [randn(1)]
ps = [W,b]  # nested parameter container: a vector holding the vectors W and b
# Apply the affine map `W*x + b`, where `ps` destructures into the weights `W` and
# the bias `b` (for example `ps = [W, b]`).
function myfun(x, ps)
    weights, bias = ps
    return weights * x + bias
end
# Integrate du/dt = myfun(u, ps)[1][1] from the initial value `x` over t ∈ [0, 1]
# and return the state at the last saved time point.
# NOTE: the right-hand side closes over `ps` rather than using the ODE parameter
# argument `p`; `ps` is not passed to `ODEProblem`.
function ode(x, ps)
    f(u,p,t) = myfun(u,ps)[1][1]  # extract the scalar from the nested layer output
    prob = ODEProblem(f, x, [0.,1.])
    sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
    return last(sol)
end;

@show ode(0.1, ps) #returns 0.4723608359338894

# Attempts to differentiate `ode` with respect to the parameters (errors noted inline).
#   Zygote.forwarddiff(W ->ode(0.1,ps),   ps) #Not a gradient
# ForwardDiff.gradient needs an array of numbers, but `ps` is a vector of arrays.
ForwardDiff.gradient(ps->ode(0.1,ps),   ps) #MethodError: no method matching one(::Type{Vector})
# ForwardDiff.gradient(W ->ode(0.1,[W,b]), W) #MethodError: no method matching one(::Type{Matrix{Float64}})

I am not sure how to proceed at this point.

This is what I use when I don’t understand Zygote. AFAIU, your code has two minor problems:

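  • I believe you have to flatten your parameters into a vector of numbers: ForwardDiff.gradient differentiates with respect to an array of real numbers, not a vector of arrays (hence the `one(::Type{Vector})` error).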
  • The initial value x in ode has to have the same element type as the parameters (which may be ForwardDiff.Dual numbers); otherwise DifferentialEquations throws an error.

With these changes I get

using Zygote
using ForwardDiff
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity

W = [randn(1,1)] #I am using this structure because in the general case I could have multiple layers. 
b = [randn(1)]
# Flatten the scalar weight and bias into a plain vector of numbers, which is the
# input form ForwardDiff.gradient works with.
ps = [W[1][1,1],b[1][1]]
# Affine map using the scalar weight `ps[1]` and bias `ps[2]` from the flat
# parameter vector. Both are wrapped in one-element vectors, and the result is
# returned nested one level deeper, so callers index it with `[1][1]`.
function myfun(x, ps)
    weight, bias = ps[1], ps[2]
    return [[weight] * x + [bias]]
end
# Integrate du/dt = myfun(u, ps)[1][1] from the initial value `x` over t ∈ [0, 1]
# and return the state at the last saved time point.
function ode(x, ps)
    f(u,p,t) = myfun(u,ps)[1][1]  # scalar right-hand side; closes over `ps`
    # Convert the initial value to the parameters' element type, so that when `ps`
    # holds ForwardDiff.Dual numbers the ODE state is a Dual as well.
    prob = ODEProblem(f, convert(eltype(ps), x), [0.,1.])
    sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
    return last(sol)
end;

@show ode(0.1, ps)

# Gradient of `ode` with respect to the flat parameter vector, computed two ways;
# both use forward-mode AD and print the same result (see the output below).
@show Zygote.gradient(ps->Zygote.forwarddiff(ps->ode(0.1,ps), ps),   ps) 
@show ForwardDiff.gradient(ps->ode(0.1,ps),   ps) 

yielding

ode(0.1, ps) = 0.7162960502619551
Zygote.gradient((ps->begin
            Zygote.forwarddiff((ps->begin
                        ode(0.1, ps)
                    end), ps)
        end), ps) = ([0.4205684897299848, 1.0625333719189938],)
ForwardDiff.gradient((ps->begin
            ode(0.1, ps)
        end), ps) = [0.4205684897299848, 1.0625333719189938]

Question for others: what is the best way to manage W and b as parameters in this scenario?

Edit: finally understood how Zygote.forwarddiff is intended to be used.