DiffEqSensitivity's Manual VJPs documentation seems to be missing information

Hello

I am trying to add manual VJPs to a project I am working on that uses adjoint_sensitivities (from DiffEqSensitivity.jl), but the documentation seems to be missing information:

https://docs.juliahub.com/DiffEqSensitivity/02xYn/6.78.4/manual/differential_equation_sensitivities/#Manual-VJPs

I have two questions:

  1. The function vjp(u,p,t) should return a tuple (f(u,p,t), v->J*v), but the documentation never says what f(u,p,t) is.

  2. Isn’t J*v a jvp, not a vjp? Am I missing or misunderstanding something here?
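
For reference, this is the distinction question 2 relies on, written in the usual automatic-differentiation terminology (not taken from the linked docs). Here J is assumed to be the Jacobian of the ODE right-hand side f with respect to u, and v is a vector of matching length:

```latex
J = \frac{\partial f}{\partial u}(u, p, t), \qquad
\mathrm{jvp}(v) = J\,v, \qquad
\mathrm{vjp}(v) = v^{\top} J = \left(J^{\top} v\right)^{\top}
```

Under these definitions the two products coincide only when J is symmetric.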

I have an example here. Could someone please help me add a manual vjp to this toy example:

using DifferentialEquations
using ModelingToolkit
using ForwardDiff
using LinearAlgebra
using SciMLSensitivity
using Calculus
using Zygote
using ReverseDiff
using BenchmarkTools

function getSys()

    ### Define independent and dependent variables
    ModelingToolkit.@variables t u1(t) u2(t) u_force(t)

    ### Define parameters
    ModelingToolkit.@parameters p1, p2, p3, p4

    ### Define an operator for the differentiation w.r.t. time
    D = Differential(t)

    ### Derivatives ###
    eqs = [
    D(u1) ~ p1 - p2*u1*u2 - p2*u1
    D(u2) ~  -p3*u2*u_force + p4*u1*u2
    u_force ~ 2.0 #ifelse(t > 2.0, p1, p2)
    ]

    @named sys = ODESystem(eqs)

    ### Initial species concentrations ###
    stateMap = [
    u1 => 8,
    u2 => 4]

    ### SBML file parameter values ###
    paramMap = [
    p1 => 5.0,
    p2 => 1.0,
    p3 => 1.0,
    p4 => 3.0]

    return sys, stateMap, paramMap
end


# Cost function G taking ODE sol (solution) as argument
function G(sol)
    cost = 0.0
    for i in 1:10
        cost += ((sol[1, i] - 1.0)^2 + (sol[2, i] - 1.0)^2) 
    end

    return cost
end

# Compute simple cost by solving ODE system (compatible with ForwardDiff)
function calc_cost(vecEst, prob)
    pDyn = vecEst[:]
    probUse = remake(prob, u0=convert.(eltype(vecEst), [1.0, 1.0]), tspan=(0.0, 11.0), p=pDyn)
    probUse.u0[1] = sin(vecEst[1] * vecEst[2])
    sol = solve(probUse, Rodas4P(), abstol=1e-8, reltol=1e-8, saveat=1:1:10)
    cost1 = G(sol)
    return cost1
end


# For the lower level interface 
function my_dgdu(out, u, p, t, i)
    out[1] = 2.0 * (u[1] - 1.0)
    out[2] = 2.0 * (u[2] - 1.0)    
end


# To compute the sensitivities at time zero
function setU0(θ)

    u0 = zeros(eltype(θ), 2)
    u0[1] = sin(θ[1] * θ[2])
    return u0
end

# As a challenge let us use the initial conditions u0[1] = sin(p[1]*p[2]), u0[2] = 1.0
sys, stateMap, paramMap = getSys()
sysSimple = structural_simplify(sys)
odeProb = ODEProblem(sysSimple, stateMap, (0.0, 11.0), paramMap, jac=true)
calc_cost_use = (pArg) -> calc_cost(pArg, odeProb)
p = [4.0, 2.0, 3.0, 1.0]

# Via ForwardMode
cost = calc_cost_use(p)
grad_forward = ForwardDiff.gradient(calc_cost_use, p)

# Via adjoint sensitivity analysis
S_MAT = ForwardDiff.jacobian(setU0, p)
probUse = remake(odeProb, u0=[0.9893582466233818, 1.0], tspan=(0.0, 11.0), p=p)
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
sol = solve(probUse, Rodas4P(), abstol=1e-9, reltol=1e-9)
du, dp = adjoint_sensitivities(sol, Rodas4P(), dgdu_discrete=my_dgdu, t=collect(1:1:10), 
                               sensealg=sensealg, abstol=1e-9, reltol=1e-9)

dp .+= reshape(du, (1, 2)) * S_MAT # Account for sensitivities at time zero

println("Grad_forward = ", grad_forward')
println("Grad_adjoint = ", dp)
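
The correction applied to dp in the script is the chain rule through the parameter-dependent initial condition. As a sketch, assuming that adjoint_sensitivities returns du as the gradient of the cost G with respect to u0 and dp as the gradient with respect to p at fixed u0 (which is how the script uses them), and writing θ for the parameter vector passed to setU0:

```latex
\frac{dG}{d\theta}
  = \left.\frac{\partial G}{\partial p}\right|_{u_0}
  + \frac{\partial G}{\partial u_0}\,\frac{\partial u_0}{\partial \theta},
\qquad
\frac{\partial u_0}{\partial \theta} = \texttt{S\_MAT} \in \mathbb{R}^{2 \times 4}
```

Here the first term corresponds to dp (a 1×4 row), the factor ∂G/∂u0 to reshape(du, (1, 2)), and the product is the 1×4 row that is added to dp.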

Yeah it’s still in need of docs since it’s not quite done. @ArnoStrouwen we should duo this and finally set it up.

@ChrisRackauckas Just checking in on how this is progressing. And is there any “unofficial way” to add the manual VJP until the documentation is in place (or do you mean that the feature is not in place yet)?

Yes, you’d define a ChainRulesCore rule on your ODE f and use ZygoteVJP.
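
A minimal sketch of what that could look like for the toy system above. It assumes a hand-written out-of-place right-hand side (here called toy_rhs, a hypothetical name, with u_force fixed to 2.0 as in the example) instead of the function generated by ModelingToolkit, and it hand-codes the Jacobian-transpose products in the pullback. Only the rule itself is exercised here; how a given SciMLSensitivity version picks it up through ZygoteVJP is not tested in this sketch.

```julia
using ChainRulesCore

# Hypothetical out-of-place RHS of the toy system (u_force = 2.0)
function toy_rhs(u, p, t)
    du1 = p[1] - p[2] * u[1] * u[2] - p[2] * u[1]
    du2 = -p[3] * u[2] * 2.0 + p[4] * u[1] * u[2]
    return [du1, du2]
end

# Manual reverse rule: returns the primal value and a pullback that
# maps a cotangent v to (J_u' * v, J_p' * v), i.e. the VJPs.
function ChainRulesCore.rrule(::typeof(toy_rhs), u, p, t)
    y = toy_rhs(u, p, t)
    function toy_rhs_pullback(ybar)
        v = unthunk(ybar)
        # J_u' * v, with J_u = d(toy_rhs)/du
        u_bar = [(-p[2] * u[2] - p[2]) * v[1] + p[4] * u[2] * v[2],
                 -p[2] * u[1] * v[1] + (-2.0 * p[3] + p[4] * u[1]) * v[2]]
        # J_p' * v, with J_p = d(toy_rhs)/dp
        p_bar = [v[1],
                 (-u[1] * u[2] - u[1]) * v[1],
                 -2.0 * u[2] * v[2],
                 u[1] * u[2] * v[2]]
        # No tangent for the function itself; toy_rhs does not depend on t
        return NoTangent(), u_bar, p_bar, ZeroTangent()
    end
    return y, toy_rhs_pullback
end

# Call the rule directly to inspect the VJP for one cotangent vector
y, pullback = ChainRulesCore.rrule(toy_rhs, [0.5, 1.0], [4.0, 2.0, 3.0, 1.0], 0.0)
_, u_bar, p_bar, _ = pullback([1.0, 0.0])
println("J_u' * v = ", u_bar)
println("J_p' * v = ", p_bar)
```

Following the reply above, the ODEProblem would then be built from toy_rhs directly and the adjoint run with sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()), so that Zygote uses this rule instead of differentiating through the function body.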