ReverseDiff gradient for loss functions involving an ODE solution

Hey everybody,

I am trying to compute the gradient of a loss function that includes an ODE solution. An MWE for a second-order example is given below:

using DifferentialEquations, DiffEqSensitivity, ReverseDiff

N = 2

# Simple linear system: du/dt = -p*u with an N×N parameter matrix p.
function myODE!(du, u, p, t)
    du[1:N] .= -p * u[1:N]
end

x0 = rand(N)
tspan = (0.0, 1.0)
p = rand(N, N)
prob = ODEProblem(myODE!, x0, tspan, p)

# Loss function: splat the parameter matrix into scalars and rebuild it.
function f(a)
    x1, x2, x3, x4 = a
    return sum(solve(remake(prob; p = [x1 x3; x2 x4]), Tsit5(), saveat = 0.1)[end])
end

a = rand(N, N)
ReverseDiff.gradient(f, a)

This code computes the gradient as desired. However, if I define the loss function as

function f(a)
    return sum(solve(remake(prob; p = a), Tsit5(), saveat = 0.1)[end])
end

and run the code, I get the following error message:

ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 5 and 6")
Stacktrace:
[1] _bcs1
@ .\broadcast.jl:516 [inlined]
[2] _bcs
@ .\broadcast.jl:510 [inlined]
[3] broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64}})
@ Base.Broadcast .\broadcast.jl:504
[4] combine_axes
@ .\broadcast.jl:499 [inlined]
[5] _axes
@ .\broadcast.jl:224 [inlined]
[6] axes
@ .\broadcast.jl:222 [inlined]
[7] copy
@ .\broadcast.jl:1072 [inlined]
[8] materialize
@ .\broadcast.jl:860 [inlined]
[9] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Nothing, Vector{Float64}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{DiffEqSensitivity.var"#forward_sensitivity_backpass#263"{0, Float64, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Matrix{Float64}, ODEFunction{true,
typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Matrix{Float64}, Tuple{}, UnitRange{Int64}}, DiffEqBase.var"##solve_up#274#189"{DiffEqBase.var"##solve_up#274#188#190"}, NamedTuple{(:saveat,), Tuple{Float64}}}})
@ DiffEqBase C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\macros.jl:218
[10] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing,
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Nothing, Vector{Float64}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{DiffEqSensitivity.var"#forward_sensitivity_backpass#263"{0, Float64, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Matrix{Float64}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Matrix{Float64}, Tuple{}, UnitRange{Int64}}, DiffEqBase.var"##solve_up#274#189"{DiffEqBase.var"##solve_up#274#188#190"}, NamedTuple{(:saveat,), Tuple{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\tape.jl:93
[11] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\tape.jl:87
[12] reverse_pass!
@ C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\api\tape.jl:36 [inlined]
[13] seeded_reverse_pass!(result::Matrix{Float64}, output::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, input::ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\api\utils.jl:31
[14] seeded_reverse_pass!(result::Matrix{Float64}, t::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\api\tape.jl:47
[15] gradient(f::Function, input::Matrix{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\api\gradients.jl:24
[16] gradient(f::Function, input::Matrix{Float64})
@ ReverseDiff C:\Users\Eric Otto\.julia\packages\ReverseDiff\Y5qec\src\api\gradients.jl:22
[17] top-level scope
@ c:\Users\Eric Otto\Desktop\Julia\startup.jl:30
in expression starting at c:\Users\Eric Otto\Desktop\Julia\startup.jl:30

Can someone explain what the problem is, or suggest an alternative solution? I am asking because I would like to compute gradients for systems of much higher order, and extracting the array elements into scalars by hand would be unwieldy for 1000+ parameters. I can compute the gradient using ForwardDiff (see the sketch below); however, I would expect ReverseDiff to perform better for a high-order system.
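For reference, the ForwardDiff version that works for me is roughly the following (a minimal sketch, using the second definition of f and the a from above):

using ForwardDiff

# Forward-mode AD handles the matrix-valued parameter directly;
# the result has the same shape as a.
ForwardDiff.gradient(f, a)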

Thanks in advance
Eric

The problem is that, by splatting a, you were implicitly doing a struct-of-arrays to array-of-structs transformation. That's going to be really slow anyway. I would suggest using Zygote, the recommended AD for this, since that will use the adjoint overloads (a minimal sketch follows the link below). See this page for details:

https://diffeq.sciml.ai/stable/analysis/sensitivity/
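For example, something along these lines (a minimal sketch, reusing prob and a from the question above; the sensealg keyword is optional, and an adjoint method is chosen automatically when it is omitted):

using DifferentialEquations, DiffEqSensitivity, Zygote

# Same loss as before, with the parameter matrix passed through directly.
function f(a)
    sol = solve(remake(prob; p = a), Tsit5(), saveat = 0.1)
    return sum(sol[end])
end

# Zygote.gradient returns a tuple, one entry per argument.
grad = Zygote.gradient(f, a)[1]

# One of the adjoint choices described on the linked page can also be
# requested explicitly, e.g.
# solve(remake(prob; p = a), Tsit5(), saveat = 0.1, sensealg = InterpolatingAdjoint())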


That works great, thank you!