Zygote gradients with ensemble ODEs

Hey,

I’ve been working on getting Zygote gradients from ensemble problems. I tried extending the first example in the docs for local sensitivities, but I’m getting an error that I don’t know how to solve. I don’t see how to simplify the example further, except perhaps by using a simpler ODE, and that doesn’t seem to be the important factor here. I’m not sure how to continue debugging this. Any advice is appreciated!

MWE:

using DiffEqSensitivity, DifferentialEquations, ForwardDiff, Zygote, Plots

# Reproducing docs example. This works with ForwardDiff and Zygote
function fiip(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())

plot(sol)

function sum_of_solution(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    sum(solve(_prob,Tsit5(),saveat=0.1))
end

Zygote.gradient(sum_of_solution,[u0;p])

# Testing ensemble problem. Works with ForwardDiff. Does not work with Zygote. 
N = 3
eu0 = rand(N,2)
ep = rand(N,4)

ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=ep[i,:],saveat=0.1))

esol = solve(ensemble_prob, Tsit5(), trajectories=N)

plot(esol)

function sum_of_e_solution(p)
    ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:],p=p[i,:],saveat=0.1))
    sol = solve(ensemble_prob,Tsit5(),trajectories=N)
    sum(Array(sol[1])) # only the first trajectory is summed, so gradients for the others should be zero
end

sum_of_e_solution(ep)

Zygote.gradient(sum_of_e_solution, ep)

And the error I get is:

julia> Zygote.gradient(sum_of_e_solution, ep)
ERROR: LoadError: BoundsError: attempt to access Tuple{Int64} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base ./tuple.jl:29
  [2] (::DiffEqBase.var"#184#188"{1, Vector{AbstractMatrix{Float64}}})(i::Int64)
    @ DiffEqBase ./none:0
  [3] iterate
    @ ./generator.jl:47 [inlined]
  [4] collect
    @ ./array.jl:678 [inlined]
  [5] (::DiffEqBase.var"#EnsembleSolution_adjoint#187")(p̄::Vector{AbstractMatrix{Float64}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/oe7VF/src/chainrules.jl:62
  [6] (::DiffEqBase.var"#166#back#191"{DiffEqBase.var"#EnsembleSolution_adjoint#187"})(Δ::Vector{AbstractMatrix{Float64}})
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [7] Pullback
    @ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:110 [inlined]
  [8] (::typeof(∂(#__solve#457)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:103 [inlined]
 [10] (::typeof(∂(__solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
 [12] (::typeof(∂(#__solve#456)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:58 [inlined]
 [14] (::typeof(∂(__solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [15] #209
    @ ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203 [inlined]
 [16] (::Zygote.var"#1746#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(__solve##kw))}})(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [17] Pullback
    @ ~/.julia/packages/DiffEqBase/oe7VF/src/solve.jl:96 [inlined]
 [18] (::typeof(∂(#solve#61)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#61))})(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
 [20] #1746#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [21] Pullback
    @ ~/.julia/packages/DiffEqBase/oe7VF/src/solve.jl:93 [inlined]
 [22] (::typeof(∂(solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/Projects/TauPet/tests/simpleensembleode.jl:38 [inlined]
 [24] (::typeof(∂(sum_of_e_solution)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#46#47"{typeof(∂(sum_of_e_solution))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
 [26] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [27] top-level scope
    @ ~/Projects/TauPet/tests/simpleensembleode.jl:43

If you change Array(sol[1]) to Array(sol)[:,:,1], it will use a different (working) adjoint.
You also need to pass saveat to solve itself, rather than to remake inside prob_func.

using DiffEqSensitivity, DifferentialEquations, Zygote, Plots

# Reproducing docs example. This works with ForwardDiff and Zygote
function fiip(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())

plot(sol)

function sum_of_solution(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    sum(solve(_prob,Tsit5(),saveat=0.1))
end

Zygote.gradient(sum_of_solution,[u0;p])

# Testing ensemble problem. With the two changes in sum_of_e_solution, this now works with Zygote too.
N = 3
eu0 = rand(N,2)
ep = rand(N,4)

ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=ep[i,:],saveat=0.1))

esol = solve(ensemble_prob, Tsit5(), trajectories=N)

plot(esol)

function sum_of_e_solution(p)
    ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:],p=p[i,:]))
    sol = solve(ensemble_prob,Tsit5(),trajectories=N, saveat=0.1)
    sum(Array(sol)[:,:,1]) # only the first trajectory is summed, so gradients for the others should be zero
end

sum_of_e_solution(ep)

Zygote.gradient(sum_of_e_solution, ep)
# 3×4 Matrix{Float64}:
#  203.492  -53.2  42.32  -17.6765
#    0.0      0.0   0.0     0.0
#    0.0      0.0   0.0     0.0
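
For reference, the indexing in sum_of_e_solution relies on the layout of Array(sol): when every trajectory is saved on the same time grid, the result is a 3-D array of size (states, time points, trajectories), so [:,:,1] picks out the first trajectory. A minimal sketch with plain arrays standing in for the solver output (the data and the names n_states, n_times, n_traj are made up for illustration):

```julia
# Hypothetical stand-in for solver output: 2 states, 5 saved time points, 3 trajectories.
n_states, n_times, n_traj = 2, 5, 3
trajectories = [fill(Float64(i), n_states, n_times) for i in 1:n_traj]

# Stack the per-trajectory matrices along a third dimension,
# mimicking the (states, time points, trajectories) layout.
A = cat(trajectories...; dims=3)

# Slice out trajectory 1 as a states × time points matrix.
first_traj = A[:, :, 1]

println(size(A))
println(sum(first_traj))
```

This stacking only works when all trajectories have the same number of saved points, which is why a shared saveat in solve matters here.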

Thanks. That works!

The reason I had saveat in the EnsembleProblem is that in my actual problem I need to save each trajectory at different time points. As far as I know, this isn’t possible using saveat in solve. Do you know of a way to do this?

This is a bug. Open an issue.
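
Until that is fixed, one possible workaround (a sketch that has not been confirmed in this thread) is to save every trajectory on the union of all requested time points through solve, which goes through the working adjoint, and then select each trajectory's own columns afterwards. The names save_times, grid, idxs and sum_at_own_times are hypothetical, and the time points are invented for illustration:

```julia
using DiffEqSensitivity, DifferentialEquations, Zygote

function fiip(du,u,p,t)
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end

tspan = (0.0, 10.0)
prob = ODEProblem(fiip, [1.0, 1.0], tspan, [1.5, 1.0, 3.0, 1.0])

N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

# Hypothetical per-trajectory save times (not from the original post).
save_times = [[0.0, 2.0, 4.0], [1.0, 5.0, 10.0], [0.5, 2.0, 7.5, 10.0]]

# Union of all requested times, plus both ends of tspan, so that the saved
# columns are expected to line up one-to-one with the entries of grid.
grid = sort(unique(vcat(tspan[1], tspan[2], save_times...)))

# Column indices of each trajectory's own time points within the shared grid.
# Computed once, outside the differentiated function.
idxs = [findall(t -> t in ts, grid) for ts in save_times]

function sum_at_own_times(p)
    ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=p[i,:]))
    sol = solve(ensemble_prob, Tsit5(), trajectories=N, saveat=grid)
    A = Array(sol)  # states × length(grid) × N
    total = 0.0
    for i in 1:N
        total += sum(A[:, idxs[i], i])  # keep only trajectory i's own time points
    end
    total
end

grad = Zygote.gradient(sum_at_own_times, ep)[1]
```

The cost is that every trajectory is saved at more time points than it needs, which may be wasteful if the individual grids are long or have little overlap.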