# ReverseDiff gradient for loss functions including ODE

Hey everybody,

I am trying to compute the gradient of a loss function which includes an ODE-solution. A MWE for a second order example is given below:

using DifferentialEquations, DiffEqSensitivity, ReverseDiff
N = 2  # state dimension (non-const global; captured by myODE! below)
# In-place ODE right-hand side for the linear system du/dt = -p*u.
#
# Arguments
# - `du`: preallocated derivative buffer, overwritten in place.
# - `u`:  current state vector.
# - `p`:  square parameter matrix whose size matches `u`.
# - `t`:  time (unused — the system is autonomous).
function myODE!(du, u, p, t)
    # Broadcast into the whole buffer instead of slicing with the
    # non-const global `N`: this generalizes to any state length and
    # avoids the allocating `u[1:N]` copy on the read side.
    du .= -(p * u)
    return nothing
end
x0 = rand(N)                          # random initial state
tspan = (0.0,1.0)                     # integration interval
p = rand(N,N)                         # random N×N parameter matrix
prob = ODEProblem(myODE!,x0,tspan,p)  # base problem; re-parameterised via remake in f
# Loss: sum of the final ODE state, parameterised by the 2×2 matrix `a`.
# Splatting `a` into four scalars and rebuilding the matrix ([x1 x3; x2 x4]
# follows Julia's column-major destructuring order, so the rebuilt matrix
# equals `a`) makes ReverseDiff record individual tracked scalars instead of
# one TrackedArray — this is why this version differentiates without error,
# at the cost of poor scaling for many parameters.
function f(a)
x1,x2,x3,x4 = a
return sum(solve(remake(prob; p=[x1 x3; x2 x4]),Tsit5(),saveat=0.1)[end])
end
a = rand(N,N)  # point at which the gradient of f is evaluated

Running `ReverseDiff.gradient(f, a)` on this code computes the gradient as desired. However, if I define the loss function as

# Same loss, but passing the TrackedArray `a` directly as the ODE
# parameters. Under ReverseDiff this errors with a DimensionMismatch in the
# ForwardDiffSensitivity backpass (see the stacktrace below) — presumably a
# tracked-array adjoint issue in DiffEqSensitivity; TODO confirm.
# NOTE(review): redefining `f` replaces the method defined above.
function f(a)
return sum(solve(remake(prob; p=a),Tsit5(),saveat=0.1)[end])
end

and run the code I get the following error message:

ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 5 and 6")
Stacktrace:
[1] _bcs1
[2] _bcs
[4] combine_axes
[5] _axes
[6] axes
[7] copy
[8] materialize
[9] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Nothing, Vector{Float64}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{DiffEqSensitivity.var"#forward_sensitivity_backpass#263"{0, Float64, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Matrix{Float64}, ODEFunction{true,
typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Matrix{Float64}, Tuple{}, UnitRange{Int64}}, DiffEqBase.var"##solve_up#274#189"{DiffEqBase.var"##solve_up#274#188#190"}, NamedTuple{(:saveat,), Tuple{Float64}}}})
@ DiffEqBase C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\macros.jl:218
[10] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing,
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Nothing, Vector{Float64}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{DiffEqSensitivity.var"#forward_sensitivity_backpass#263"{0, Float64, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Matrix{Float64}, ODEFunction{true, typeof(myODE!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Matrix{Float64}, Tuple{}, UnitRange{Int64}}, DiffEqBase.var"##solve_up#274#189"{DiffEqBase.var"##solve_up#274#188#190"}, NamedTuple{(:saveat,), Tuple{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\tape.jl:93
[11] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\tape.jl:87
[12] reverse_pass!
@ C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\api\tape.jl:36 [inlined]
[13] seeded_reverse_pass!(result::Matrix{Float64}, output::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, input::ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\api\utils.jl:31
[14] seeded_reverse_pass!(result::Matrix{Float64}, t::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}})
@ ReverseDiff C:\Users\Eric Otto.julia\packages\ReverseDiff\Y5qec\src\api\tape.jl:47
[17] top-level scope
@ c:\Users\Eric Otto\Desktop\Julia\startup.jl:30
in expression starting at c:\Users\Eric Otto\Desktop\Julia\startup.jl:30

Can someone explain to me what the problem is, or suggest an alternative solution? I am asking because I would like to compute gradients for systems of much larger order, and extracting the array elements into scalars would be impractical for 1000+ parameters. I can compute the gradient using `ForwardDiff`; however, I think the performance for a high-order system should be better with `ReverseDiff`.

The problem is that you were implicitly doing a struct-of-arrays to array-of-structs transformation by splatting `a`. That's going to be really slow anyway. I would suggest using the recommended Zygote, since that will use the adjoint overloads. See this page for details: