Hello
I am trying to add manual VJPs to a project I am working on that uses `adjoint_sensitivities` (from DiffEqSensitivity.jl, now SciMLSensitivity.jl), but the documentation seems to be missing some information:
I have two questions:
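1. The function `vjp(u,p,t)` should return a tuple `(f(u,p,t), v -> J*v)`, but the documentation never says what `f(u,p,t)` is.
2. Isn't `J*v` a JVP (Jacobian–vector product) rather than a VJP (vector–Jacobian product, `v'*J`, i.e. `J'*v`)? Am I missing or misunderstanding something here?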
Here is an example. Could someone please help me add a manual VJP to this toy example:
using DifferentialEquations
using ModelingToolkit
using ForwardDiff
using LinearAlgebra
using SciMLSensitivity
using Calculus
using Zygote
using ReverseDiff
using BenchmarkTools
"""
    getSys()

Build the toy two-state ModelingToolkit model together with its default values.

Returns a tuple `(sys, stateMap, paramMap)`:
- `sys`: an `ODESystem` (named `sys` by `@named`) with differential states `u1`, `u2`,
  the variable `u_force` fixed by the algebraic equation `u_force ~ 2.0`, and
  parameters `p1`–`p4`;
- `stateMap`: default initial values `u1 => 8`, `u2 => 4`;
- `paramMap`: default parameter values `p1 => 5.0`, `p2 => 1.0`, `p3 => 1.0`, `p4 => 3.0`.
"""
function getSys()
### Define independent and dependent variables
ModelingToolkit.@variables t u1(t) u2(t) u_force(t)
### Define parameters
ModelingToolkit.@parameters p1, p2, p3, p4
### Define an operator for the differentiation w.r.t. time
D = Differential(t)
### Derivatives ###
# `u_force ~ 2.0` is algebraic, not a derivative; presumably `structural_simplify`
# turns `u_force` into an observed variable — confirm. The trailing comment shows a
# time-dependent forcing that was tried instead.
# NOTE(review): the last term of D(u1) reuses `p2` — confirm that this is intended.
eqs = [
D(u1) ~ p1 - p2*u1*u2 - p2*u1
D(u2) ~ -p3*u2*u_force + p4*u1*u2
u_force ~ 2.0 #ifelse(t > 2.0, p1, p2)
]
# `@named` names the system after the left-hand variable, i.e. `sys`.
@named sys = ODESystem(eqs)
### Initial species concentrations ###
stateMap = [
u1 => 8,
u2 => 4]
### SBML file parameter values ###
paramMap = [
p1 => 5.0,
p2 => 1.0,
p3 => 1.0,
p4 => 3.0]
return sys, stateMap, paramMap
end
"""
    G(sol)

Sum-of-squares cost of the first two state components at the first 10 saved time points
of `sol`, measured against the target value `1.0`:

    G = ∑_{i=1}^{10} (sol[1, i] - 1)^2 + (sol[2, i] - 1)^2

`sol` must support `sol[component, index]` with at least 2 components and 10 saved
points, e.g. an ODE solution saved at `saveat = 1:1:10` or a 2×10 matrix.

The result type follows the element type of `sol`. The function therefore stays
type-stable for plain floats and for `ForwardDiff.Dual` entries.
"""
function G(sol)
    # Summing a generator lets the accumulator take the type of the terms themselves.
    # A `0.0` start value would change type mid-loop when `sol` holds Dual numbers.
    return sum(abs2(sol[1, i] - 1.0) + abs2(sol[2, i] - 1.0) for i in 1:10)
end
"""
    calc_cost(vecEst, prob)

Solve `prob` with the parameter vector `vecEst` and return the cost `G` of the solution.
This works with ForwardDiff: all problem data is built with `eltype(vecEst)`.

The initial state is `[sin(vecEst[1] * vecEst[2]), 1.0]`. The problem is integrated over
`(0.0, 11.0)` with `Rodas4P()` (abstol = reltol = 1e-8) and saved at `t = 1, 2, …, 10`.
"""
function calc_cost(vecEst, prob)
    T = eltype(vecEst)
    # The first initial value depends on the parameters, so u0 must carry their
    # (possibly Dual) element type.
    u0 = T[sin(vecEst[1] * vecEst[2]), 1.0]
    probUse = remake(prob; u0=u0, tspan=(0.0, 11.0), p=vecEst[:])
    sol = solve(probUse, Rodas4P(); abstol=1e-8, reltol=1e-8, saveat=1:1:10)
    return G(sol)
end
"""
    my_dgdu(out, u, p, t, i)

In-place state gradient of one summand of `G`, passed as `dgdu_discrete` to the
lower-level `adjoint_sensitivities` interface.

Writes `∂/∂u [(u[1] - 1)^2 + (u[2] - 1)^2] = [2(u[1] - 1), 2(u[2] - 1)]` into
`out[1]` and `out[2]`. The arguments `p`, `t` and `i` belong to the callback signature
and are not used.
"""
function my_dgdu(out, u, p, t, i)
    residual1 = u[1] - 1.0
    residual2 = u[2] - 1.0
    out[1] = 2.0 * residual1
    # The assignment evaluates to the written value 2(u[2] - 1).
    return out[2] = 2.0 * residual2
end
"""
    setU0(θ)

Return the parameter-dependent initial state `[sin(θ[1] * θ[2]), 1]` with element type
`eltype(θ)`.

This matches the initial condition built in `calc_cost` (`u0[2] = 1.0`). Therefore
`ForwardDiff.jacobian(setU0, θ)` gives `∂u0/∂θ`, which is used to add the time-zero
contribution to the adjoint gradient.
"""
function setU0(θ)
    T = eltype(θ)
    # Typed literal: with Dual inputs both entries become Duals, so the Jacobian is
    # well-defined. Only u0[1] depends on θ.
    return T[sin(θ[1] * θ[2]), one(T)]
end
# As a challenge, the initial conditions depend on the parameters:
# u0[1] = sin(p[1] * p[2]), u0[2] = 1.0
sys, stateMap, paramMap = getSys()
sysSimple = structural_simplify(sys)
odeProb = ODEProblem(sysSimple, stateMap, (0.0, 11.0), paramMap, jac=true)
calc_cost_use = (pArg) -> calc_cost(pArg, odeProb)
p = [4.0, 2.0, 3.0, 1.0]

# Via forward mode
cost = calc_cost_use(p)
grad_forward = ForwardDiff.gradient(calc_cost_use, p)

# Via adjoint sensitivity analysis
# ∂u0/∂p: the initial state depends on p, so its sensitivity is added separately below.
S_MAT = ForwardDiff.jacobian(setU0, p)
# Derive u0 from `p` instead of hard-coding sin(8.0), so it stays in sync when `p` changes.
probUse = remake(odeProb, u0=[sin(p[1] * p[2]), 1.0], tspan=(0.0, 11.0), p=p)
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
sol = solve(probUse, Rodas4P(), abstol=1e-9, reltol=1e-9)
# Discrete cost contributions at t = 1, …, 10, matching `saveat = 1:1:10` in `calc_cost`.
du, dp = adjoint_sensitivities(sol, Rodas4P(), dgdu_discrete=my_dgdu, t=collect(1:1:10),
                               sensealg=sensealg, abstol=1e-9, reltol=1e-9)
# Account for the sensitivities at time zero: dG/dp += (dG/du0) * (∂u0/∂p).
# `reshape(du, 1, :)` makes du a row vector for any number of states.
dp .+= reshape(du, 1, :) * S_MAT
println("Grad_forward = ", grad_forward')
println("Grad_adjoint = ", dp)