# Zygote gradients with ensemble ODEs

Hey,

I’ve been working on getting Zygote gradients from ensemble problems. I tried extending the first example in the local sensitivity docs, but I’m getting an error that I’m not sure how to solve. I don’t see how to simplify the example further, except perhaps by using a simpler ODE, but that doesn’t seem to be the important factor here. I’m not sure how to continue debugging this. Any advice appreciated!

MWE:

``````julia
using DiffEqSensitivity, DifferentialEquations, ForwardDiff, Zygote
using Plots  # needed for the plot calls below

# Reproducing docs example. This works with ForwardDiff and Zygote
function fiip(du,u,p,t)
du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())

plot(sol)

function sum_of_solution(x)
_prob = remake(prob,u0=x[1:2],p=x[3:end])
sum(solve(_prob,Tsit5(),saveat=0.1))
end

# Testing ensemble problem. Works with ForwardDiff. Does not work with Zygote.
N = 3
eu0 = rand(N,2)
ep = rand(N,4)

ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=ep[i,:],saveat=0.1))

esol = solve(ensemble_prob, Tsit5(), trajectories=N)

plot(esol)

function sum_of_e_solution(p)
ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:],p=p[i,:],saveat=0.1))
sol = solve(ensemble_prob,Tsit5(),trajectories=N)
sum(Array(sol[1])) # just test for the first solutions, gradients should be zero for others
end

sum_of_e_solution(ep)

``````

and the error I get is:

``````
julia> Zygote.gradient(sum_of_e_solution, ep)
ERROR: LoadError: BoundsError: attempt to access Tuple{Int64} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:29
[2] (::DiffEqBase.var"#184#188"{1, Vector{AbstractMatrix{Float64}}})(i::Int64)
@ DiffEqBase ./none:0
[3] iterate
@ ./generator.jl:47 [inlined]
[4] collect
@ ./array.jl:678 [inlined]
@ DiffEqBase ~/.julia/packages/DiffEqBase/oe7VF/src/chainrules.jl:62
[7] Pullback
@ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:110 [inlined]
[8] (::typeof(∂(#__solve#457)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:103 [inlined]
[10] (::typeof(∂(__solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
[12] (::typeof(∂(#__solve#456)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/SciMLBase/1aTqd/src/ensemble/basic_ensemble_solve.jl:58 [inlined]
[14] (::typeof(∂(__solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[15] #209
@ ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203 [inlined]
[16] (::Zygote.var"#1746#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(__solve##kw))}})(Δ::Vector{AbstractMatrix{Float64}})
[17] Pullback
@ ~/.julia/packages/DiffEqBase/oe7VF/src/solve.jl:96 [inlined]
[18] (::typeof(∂(#solve#61)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[19] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#61))})(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
[20] #1746#back
[21] Pullback
@ ~/.julia/packages/DiffEqBase/oe7VF/src/solve.jl:93 [inlined]
[22] (::typeof(∂(solve##kw)))(Δ::Vector{AbstractMatrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[23] Pullback
@ ~/Projects/TauPet/tests/simpleensembleode.jl:38 [inlined]
[24] (::typeof(∂(sum_of_e_solution)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[25] (::Zygote.var"#46#47"{typeof(∂(sum_of_e_solution))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
[27] top-level scope
@ ~/Projects/TauPet/tests/simpleensembleode.jl:43
``````

If you change `Array(sol[1])` to `Array(sol)[:,:,1]` it will use a different (working) adjoint.
You need to pass `saveat` to solve as well.

``````julia
using DiffEqSensitivity, DifferentialEquations, Zygote
using Plots  # needed for the plot calls below

# Reproducing docs example. This works with ForwardDiff and Zygote
function fiip(du,u,p,t)
du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())

plot(sol)

function sum_of_solution(x)
_prob = remake(prob,u0=x[1:2],p=x[3:end])
sum(solve(_prob,Tsit5(),saveat=0.1))
end

# Testing ensemble problem. Works with ForwardDiff. Does not work with Zygote.
N = 3
eu0 = rand(N,2)
ep = rand(N,4)

ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=ep[i,:],saveat=0.1))

esol = solve(ensemble_prob, Tsit5(), trajectories=N)

plot(esol)

function sum_of_e_solution(p)
ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:],p=p[i,:]))
sol = solve(ensemble_prob,Tsit5(),trajectories=N, saveat=0.1)
sum(Array(sol)[:,:,1]) # just test for the first solutions, gradients should be zero for others
end

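sum_of_e_solution(ep)

# The gradient is the matrix shown below; only the first trajectory contributes
Zygote.gradient(sum_of_e_solution, ep)[1]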

# 3×4 Matrix{Float64}:
#  203.492  -53.2  42.32  -17.6765
#    0.0      0.0   0.0     0.0
#    0.0      0.0   0.0     0.0
``````
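
The slicing change relies on how an ensemble solution converts to an array. As a rough sketch using plain arrays only (no solver involved, and assuming every trajectory is saved at the same time points so that the result is a dense `state × time × trajectory` array), taking `[:, :, 1]` from the stacked array gives the same numbers as the first trajectory on its own. Only the code path that Zygote differentiates through differs. The names `trajs` and `stacked` are made up for illustration.

```julia
# Stand-ins for three trajectories of a 2-state system saved at 5 time points.
# These are hypothetical arrays, not real solver output.
trajs = [fill(Float64(i), 2, 5) for i in 1:3]

# Stack along a third dimension: state × time × trajectory.
stacked = cat(trajs...; dims=3)

first_direct = trajs[1]          # plays the role of Array(sol[1])
first_sliced = stacked[:, :, 1]  # plays the role of Array(sol)[:, :, 1]

println(size(stacked))
println(first_direct == first_sliced)
```

If the trajectories were saved at different time points, they could not be stacked into one dense array like this, which is why `saveat` is passed to `solve` in the working version.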

Thanks. That works!

The reason I had `saveat` in the `EnsembleProblem` is that in my actual problem I need to save each trajectory at different time points. As far as I know, this isn’t possible using `saveat` in `solve`. Do you know of a way to do this?

This is a bug. Open an issue.

For reference, this was solved in DiffEqBase.jl pull request #685, “Handle all prob.kwargs at the interface level”, by ChrisRackauckas (SciML/DiffEqBase.jl on GitHub).
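
On package versions that predate that fix, one possible workaround (a sketch that was not proposed in this thread) is to save every trajectory on the union of all required time points and then pick out each trajectory's own times by index. The observation times and the names `obs_times`, `grid` and `idx` below are hypothetical.

```julia
# Hypothetical observation times for two trajectories.
obs_times = [[0.0, 1.0, 2.0], [0.0, 0.5, 1.9]]

# One shared grid containing every required time point.
grid = sort(unique(vcat(obs_times...)))

# For each trajectory, the positions of its own times within the shared grid.
# indexin returns `nothing` for missing entries; here every time is in the grid.
idx = [Int.(indexin(t, grid)) for t in obs_times]

println(grid)
println(idx)
```

Under these assumptions, `grid` would be passed as `saveat` to `solve`, and trajectory `i` would then be read as `Array(sol)[:, idx[i], i]`. The cost is that every trajectory is saved at some time points it does not need.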


Hey,

I’ve run into a different issue with Zygote gradients now that I’ve built up to a more complicated ensemble model with Turing inference.

A working example is below. I’ve tried to reduce it as much as I can while still getting the same error. I’m not sure what’s going wrong; the gradients work for some `saveat` values but not others. For example, if the `saveat` in the code below is changed from `[0.0,1.0,2.0]` to `[0.0,1.0,1.9]`, Zygote throws an error (also below). Additionally, I don’t get the error in the non-hierarchical case, when I use regular `ODEProblem`s instead of ensemble problems, or if I use ForwardDiff.

``````julia
using SimpleWeightedGraphs, LightGraphs
using Plots  # needed for the plot call below
using Turing, Zygote, ForwardDiff
using DifferentialEquations, DiffEqSensitivity
using Random
Random.seed!(1)

N = 10
P = 0.5
L = erdos_renyi(N, P) |> laplacian_matrix

function NetworkFKPP(du, u, p, t)
du .= -p[1] * L * u .+ p[2] .* u .* (1 .- u)
end

prob = ODEProblem(NetworkFKPP, rand(N), (0.0,2.0), [0.1,1.5])
sol = solve(prob, Tsit5(), saveat=0.2)

plot(sol, labels=false)

function make_prob_func(k, a)
function prob_func(prob,i,repeat)
remake(prob, p=[k[i], a[i]], saveat=[0.0,1.0,2.0])
end
end

function output_func(sol,i)
(vec(sol),false)
end

ensemble_prob = EnsembleProblem(prob, prob_func=make_prob_func([1.0], [1.0]), output_func=output_func)

esol = solve(ensemble_prob, Tsit5(), trajectories=1, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

@model function fitdata(data, prob, N)
σ ~ InverseGamma(2, 3)

Km ~ truncated(Normal(1.0,1.0), 0.0, 1.0)
Ks ~ truncated(Normal(0.0,1.0), 0.0, 1.0)

Am ~ truncated(Normal(5.0, 2.0), 0.0, 10.0)
As ~ truncated(Normal(0.0, 1.0), 0.0, 1.0)

k ~ filldist(truncated(Normal(Km, Ks), 0.0, 1.0), N)
a ~ filldist(truncated(Normal(Am, As), 0.0, 10.0), N)

ensemble_prob = EnsembleProblem(prob,prob_func=make_prob_func(k, a), output_func=output_func)
predicted = solve(ensemble_prob, Tsit5(), trajectories=N, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

data ~ MvNormal(predicted[1], σ)
end

model = fitdata(esol[1], prob, 1)

model()

function gen_∂logπ∂θ(vi, spl, model)
function ∂logπ∂θ(x)
end
return ∂logπ∂θ
end

var_info = Turing.VarInfo(model)

spl = DynamicPPL.Sampler(NUTS(.65))

∂logπ∂θ = gen_∂logπ∂θ(var_info, spl, model)
∂logπ∂θ(var_info[spl])
``````

Changing `[0.0,1.0,2.0]` in `make_prob_func` to `[0.0,1.0,1.9]` gives the following error:

``````
ERROR: LoadError: BoundsError: attempt to access 1-element Vector{Float64} at index [0]
Stacktrace:
[1] getindex
@ ./array.jl:801 [inlined]
[2] split_states(du::Vector{Float64}, u::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float64}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Vector{Float64}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, DiffEqSensitivity.CheckpointSolution{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, 
Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Vector{Tuple{Float64, Float64}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}; update::Bool)
[3] split_states
[4] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float64}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Vector{Float64}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, DiffEqSensitivity.CheckpointSolution{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), 
Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Vector{Tuple{Float64, Float64}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64}, t::Float64)
[5] ODEFunction
@ ~/.julia/packages/SciMLBase/cU5k7/src/scimlfunctions.jl:334 [inlined]
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/9mxZY/src/perform_step/low_order_rk_perform_step.jl:633
``````

error continued…

``````
 [7] perform_step!
@ ~/.julia/packages/OrdinaryDiffEq/9mxZY/src/perform_step/low_order_rk_perform_step.jl:628 [inlined]
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/9mxZY/src/solve.jl:478
[9] #__solve#467
@ ~/.julia/packages/OrdinaryDiffEq/9mxZY/src/solve.jl:5 [inlined]
[10] #solve_call#58
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:61 [inlined]
@ DiffEqBase ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:85
[12] #solve#59
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:73 [inlined]
[13] _adjoint_sensitivities(sol::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, alg::Tsit5, g::DiffEqSensitivity.var"#df#211"{Matrix{Float64}, Colon}, t::Vector{Float64}, dg::Nothing; abstol::Float64, reltol::Float64, checkpoints::Vector{Float64}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/EA0mk/src/sensitivity_interface.jl:36
[14] adjoint_sensitivities(::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Float64}}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any, N} where N; sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, kwargs::Base.Iterators.Pairs{Symbol, Nothing, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{Nothing}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/EA0mk/src/sensitivity_interface.jl:6
[15] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#210"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Vector{Float64}, Vector{Float64}, Tuple{}, Colon, NamedTuple{(), Tuple{}}})(Δ::Matrix{Float64})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/EA0mk/src/concrete_solve.jl:227
[16] ZBack
@ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:91 [inlined]
[17] (::Zygote.var"#209#210"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#210"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Vector{Float64}, Vector{Float64}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}}})(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
[18] (::Zygote.var"#1786#back#211"{Zygote.var"#209#210"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#210"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Vector{Float64}, Vector{Float64}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}}}})(Δ::Matrix{Float64})
[19] Pullback
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:73 [inlined]
[20] (::typeof(∂(#solve#59)))(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[21] (::Zygote.var"#209#210"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#59))})(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
[22] (::Zygote.var"#1786#back#211"{Zygote.var"#209#210"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#59))}})(Δ::Matrix{Float64})
[23] Pullback
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:68 [inlined]
[24] (::typeof(∂(solve##kw)))(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[25] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:143 [inlined]
[26] (::typeof(∂(#batch_func#458)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[27] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:139 [inlined]
[28] (::typeof(∂(batch_func##kw)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[29] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:195 [inlined]
[30] (::typeof(∂(λ)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[31] (::DiffEqBase.var"#204#212")(f::typeof(∂(λ)), δ::Vector{Float64})
@ DiffEqBase ~/.julia/packages/DiffEqBase/ge0vq/src/init.jl:137
[32] responsible_map(::Function, ::Vector{typeof(∂(λ))}, ::Vararg{Any, N} where N)
@ SciMLBase ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:188
[33] (::DiffEqBase.var"#∇responsible_map_internal#211"{Vector{typeof(∂(λ))}})(Δ::EnsembleSolution{Float64, 2, Vector{Vector{Float64}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/ge0vq/src/init.jl:137
[34] #157#back
[35] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:194 [inlined]
[36] (::typeof(∂(#solve_batch#462)))(Δ::EnsembleSolution{Float64, 2, Vector{Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
--- the last 2 lines are repeated 1 more time ---
[39] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:203 [inlined]
[40] (::typeof(∂(#solve_batch#465)))(Δ::EnsembleSolution{Float64, 2, Vector{Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[41] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:201 [inlined]
[42] (::typeof(∂(solve_batch##kw)))(Δ::EnsembleSolution{Float64, 2, Vector{Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[43] macro expansion
@ ./timing.jl:287 [inlined]
[44] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
[45] (::typeof(∂(#__solve#457)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[46] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:103 [inlined]
[47] (::typeof(∂(__solve##kw)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[48] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
[49] (::typeof(∂(#__solve#456)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [50] Pullback
@ ~/.julia/packages/SciMLBase/cU5k7/src/ensemble/basic_ensemble_solve.jl:58 [inlined]
[51] (::typeof(∂(__solve##kw)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[52] #209
@ ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203 [inlined]
[53] (::Zygote.var"#1786#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(__solve##kw))}})(Δ::Vector{Vector{Float64}})
[54] Pullback
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:99 [inlined]
[55] (::typeof(∂(#solve#61)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[56] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#61))})(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
[57] #1786#back
[58] Pullback
@ ~/.julia/packages/DiffEqBase/ge0vq/src/solve.jl:96 [inlined]
[59] (::typeof(∂(solve##kw)))(Δ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[60] Pullback
@ ~/Projects/TauPet/tests/hierarchicalensemble.jl:61 [inlined]
[61] (::typeof(∂(#18)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[62] macro expansion
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:0 [inlined]
[63] Pullback
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:156 [inlined]
[64] (::typeof(∂(_evaluate)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[65] Pullback
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:146 [inlined]
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[67] Pullback
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:99 [inlined]
[68] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[69] Pullback
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:91 [inlined]
[70] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[71] #209
@ ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203 [inlined]
[72] #1786#back
[73] Pullback
@ ~/.julia/packages/DynamicPPL/F7F1M/src/model.jl:104 [inlined]
[74] Pullback
[75] (::typeof(∂(λ)))(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[76] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
[79] (::var"#∂logπ∂θ#20"{DynamicPPL.TypedVarInfo{NamedTuple{(:σ, :Km, :Ks, :Am, :As, :k, :a), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Tuple{}}, Int64}, Vector{InverseGamma{Float64}}, Vector{AbstractPPL.VarName{:σ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:Km, Tuple{}}, Int64}, Vector{Truncated{Normal{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:Km, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:Ks, Tuple{}}, Int64}, Vector{Truncated{Normal{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:Ks, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:Am, Tuple{}}, Int64}, Vector{Truncated{Normal{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:Am, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:As, Tuple{}}, Int64}, Vector{Truncated{Normal{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:As, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:k, Tuple{}}, Int64}, Vector{Product{Continuous, Truncated{Normal{Float64}, Continuous, Float64}, FillArrays.Fill{Truncated{Normal{Float64}, Continuous, Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:k, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:a, Tuple{}}, Int64}, Vector{Product{Continuous, Truncated{Normal{Float64}, Continuous, Float64}, FillArrays.Fill{Truncated{Normal{Float64}, Continuous, Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{AbstractPPL.VarName{:a, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Sampler{NUTS{Turing.Core.ZygoteAD, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.Model{var"#18#19", (:data, :prob, :N), (), (), 
Tuple{Vector{Float64}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, typeof(NetworkFKPP), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Int64}, Tuple{}}})(x::Vector{Float64})
@ Main ~/Projects/TauPet/tests/hierarchicalensemble.jl:73
in expression starting at /Users/pavanchaggar/Projects/TauPet/tests/hierarchicalensemble.jl:84
``````

Here’s what I’ve observed while playing with your code. Maybe it helps to find the source of this issue.

``````
Can't differentiate foreigncall expression
...
in Pullback at Distributions/HjzA0/src/univariate/continuous/normal.jl:175
in  at Zygote/TaBlo/src/compiler/interface2.jl
in Pullback at SpecialFunctions/LC8dm/src/erf.jl:8
``````

`BoundsError: attempt to access 1-element Vector{Float64} at index [0]`

`DimensionMismatch("Inconsistent array dimensions.")` # 2.0 in sol.t

`AssertionError: IndexStyle(value) === IndexLinear()`

`L = Array(L)`
`Can't differentiate foreigncall expression`

`Zygote.@nograd Distributions.erfc`
`Zygote.@nograd Distributions.logabsgamma`
`L = Array(L)`
`(-54.85694429029243, [-29.442063300563014, 0.8204800651344112, -0.9993931991679814, 4.759777195648575, -0.21835507381571337, -0.8427305609685282, -3.9747116088690357])`

In case this is not known yet:

``````
┌ Warning: `convert(::Type{<:NamedTuple}, t::Tangent{<:Any, <:NamedTuple})` is deprecated, use `backing(t)` instead.
│   caller = wrap_chainrules_output at chainrules.jl:66 [inlined]
└ @ Core ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:66
``````

I think I was facing the same issue in a different context; `EnsembleProblem` but with `saveat=1.0`.
I could fix my issue by using `save_start=true` and later returning everything but the first timestep to get the solution from `1.0:end`.

So with `InterpolatingAdjoint` and

``````julia
function make_prob_func(k, a)
function prob_func(prob,i,repeat)
remake(prob, p=[k[i], a[i]], saveat=0.5, save_start=false)
end
end

function output_func(sol,i)
@show sol.t
(vec(sol[:,end-2:end]),false)
end
``````

I get a `BoundsError`

`save_start=true` raises a `Can't differentiate foreigncall expression`.

Is an output of `sol.t = [0.0, 0.5, 1.0, 1.5, 2.0]` an expected output for `saveat=0.5, save_start=false`?

Edit: If `saveat` is an Array I don’t get the `Can't differentiate foreigncall expression` error

With `save_start = true` I still get the same error. For my problem, `saveat` still needs to be an array of values that includes the initial time.
It’s still unclear to me why Zygote breaks in these cases.

> Is an output of `sol.t = [0.0, 0.5, 1.0, 1.5, 2.0]` an expected output for `saveat=0.5, save_start=false`?

I don’t get this behaviour. For me, this results in the initial value not being saved.
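
To make the two reports above concrete, here is a small sketch of the time grids involved, written with plain ranges and arrays and no solver. It assumes a scalar `saveat` is treated as the range `tspan[1]:saveat:tspan[2]`, and that `save_start=false` simply drops the first point; the `states` matrix is a hypothetical stand-in for saved solution values.

```julia
tspan = (0.0, 2.0)
saveat = 0.5

# Grid including the initial time, as with save_start=true.
full_grid = collect(tspan[1]:saveat:tspan[2])

# Grid one would expect with save_start=false under the assumption above.
no_start_grid = full_grid[2:end]

# Hypothetical saved states: 3 state variables × 5 time points.
states = reshape(collect(1.0:15.0), 3, 5)

# Saving the start and dropping it afterwards yields values on the same grid.
trimmed = states[:, 2:end]

println(full_grid)
println(no_start_grid)
println(size(trimmed))
```

Seeing `sol.t` equal to `full_grid` despite `save_start=false` would suggest that the keyword is not reaching the solver, which may be worth checking against the installed package versions.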

Here’s my environment info, in case that’s helpful:

``````
(TauPet) pkg> status
Status `~/Projects/TauPet/Project.toml`
[6e4b80f9] BenchmarkTools v1.1.1
[41bf760c] DiffEqSensitivity v6.57.0
[0c46a032] DifferentialEquations v6.18.0
[31c24e10] Distributions v0.25.11
[634d3b9d] DrWatson v2.1.2
[5789e2e9] FileIO v1.10.1
[f6369f11] ForwardDiff v0.10.19
[7073ff75] IJulia v1.23.2
[824d6782] JSServe v1.2.3
[5078a376] LazyArrays v0.21.14
[093fc24a] LightGraphs v1.3.5
[bdcacae8] LoopVectorization v0.12.61
[23992714] MAT v0.10.1
[c7f686f2] MCMCChains v4.13.1
[6fafb56a] Memoization v0.1.13
[7269a6da] MeshIO v0.4.7
[91a5bcdd] Plots v1.20.0
[c3e4b0f8] Pluto v0.15.1
[7f904dfe] PlutoUI v0.7.9
[37e2e3b7] ReverseDiff v1.9.0
[47aef6b3] SimpleWeightedGraphs v1.1.1
[90137ffa] StaticArrays v1.2.9
[f3b207a7] StatsPlots v0.14.26
[fce5fe82] Turing v0.16.6
[ea0860ee] TuringCallbacks v0.1.3
[276b4fcb] WGLMakie v0.4.4
[e88e6eb3] Zygote v0.6.17
[8ba89e20] Distributed
``````

What is the actual example here? I am lost. Can you show me something that cuts out all of the random PPL stuff and just shows the ODE being differentiated and giving the error? I pasted what you had above in my REPL and it looked fine.

Hi,

Sorry for my late reply – I’ve been working toward some deadlines so had to put this on the back burner.

I’ve tested differentiating the ensemble ODE using the same setup as the original post, and that works fine. Code here:

``````using DiffEqSensitivity, DifferentialEquations, ForwardDiff, Zygote
using LightGraphs  # provides erdos_renyi and laplacian_matrix

N = 10
P = 0.5
L = erdos_renyi(N, P) |> laplacian_matrix

function NetworkFKPP(du, u, p, t)
du .= -p[1] * L * u .+ p[2] .* u .* (1 .- u)
end

prob = ODEProblem(NetworkFKPP, rand(N), (0.0,2.0), [0.1,1.5])
sol = solve(prob, Tsit5(), saveat=0.2)

function sum_of_solution(x)
_prob = remake(prob,u0=x[1:N],p=x[N+1:end])
sum(solve(_prob,Tsit5(),saveat=0.1))
end

# Testing ensemble problem. Works with ForwardDiff. Does not work with Zygote.
n = 3
eu0 = rand(n,N)
ep = rand(n,2)

ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:], p=ep[i,:],saveat=[0.0,1.0,2.0]))

esol = solve(ensemble_prob, Tsit5(), trajectories=n)

function sum_of_e_solution(p)
ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat)->remake(prob, u0=eu0[i,:],p=p[i,:],saveat=[0.0,1.0,1.9]))
sol = solve(ensemble_prob,Tsit5(),trajectories=n)
sum(Array(sol[1])) # just test for the first solutions, gradients should be zero for others
end

sum_of_e_solution(ep)

``````

The code below should fail. It’s the same as before but with a slightly simpler Turing model. As before, using Zygote this works for `saveat=[0.0,1.0,2.0]` but not `saveat=[0.0,1.0,1.9]`. Both work with ForwardDiff.

``````using SimpleWeightedGraphs, LightGraphs
using Turing, Zygote, ForwardDiff
using DifferentialEquations, DiffEqSensitivity
using Plots  # needed for the plot(sol) call below
using Random
Random.seed!(1)

N = 10
P = 0.5
L = erdos_renyi(N, P) |> laplacian_matrix

function NetworkFKPP(du, u, p, t)
du .= -p[1] * L * u .+ p[2] .* u .* (1 .- u)
end

prob = ODEProblem(NetworkFKPP, rand(N), (0.0,2.0), [0.1,1.5])
sol = solve(prob, Tsit5(), saveat=0.2)

plot(sol, labels=false)

function make_prob_func(k, a)
function prob_func(prob,i,repeat)
remake(prob, p=[k[i], a[i]], saveat=[0.0,1.0,1.9])
end
end

function output_func(sol,i)
(vec(sol),false)
end

ensemble_prob = EnsembleProblem(prob, prob_func=make_prob_func([1.0], [1.0]), output_func=output_func)

esol = solve(ensemble_prob, Tsit5(), trajectories=1, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

@model function fitdata(data, prob, N)
σ ~ InverseGamma(2, 3)

k ~ filldist(truncated(Normal(0.0, 5.0), 0.0, Inf), N)
a ~ filldist(truncated(Normal(0.0, 5.0), 0.0, Inf), N)

ensemble_prob = EnsembleProblem(prob,prob_func=make_prob_func(k, a), output_func=output_func)
predicted = solve(ensemble_prob, Tsit5(), trajectories=N, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

data ~ MvNormal(predicted[1], σ)
end

model = fitdata(esol[1], prob, 1)

model()

function gen_∂logπ∂θ(vi, spl, model)
function ∂logπ∂θ(x)
end
return ∂logπ∂θ
end

var_info = Turing.VarInfo(model)

spl = DynamicPPL.Sampler(NUTS(.65))

∂logπ∂θ = gen_∂logπ∂θ(var_info, spl, model)
∂logπ∂θ(var_info[spl])
``````

I’ll continue to try to debug now that some of my deadlines have passed, but I’m pretty stuck. Any tips are appreciated!

Thanks,
Pavan

Whatever happened with this? Was an issue opened and solved?

Hi Chris,

I wasn’t able to work it out and then got caught up again with PhD and paper deadlines. I’ll check again tomorrow and post an issue if it’s still not working. Where would be the best place for an issue? DiffEqSensitivity?

Yes, DiffEqSensitivity.

Issue posted here.
