I am trying to differentiate the solution of a differential equation computed with `DifferentialEquations.jl` with respect to some parameters. I want to skip parts of the gradient computation because they are redundant or computationally expensive. However, I cannot get `Zygote` to ignore parts of the forward model when computing the gradient. Below is a MWE.
We can compute the gradient of the solution of a simple ODE with respect to the parameter vector `p` as follows:
```julia
using DifferentialEquations
using Zygote, SciMLSensitivity
using Plots
using DiffEqFlux
using ChainRulesCore
using Zygote: @ignore

p = [0.1, 0.2]

# In-place right-hand side: du/dt = -p[1] * u + p[2]
function dynamics(du, u, p, t)
    du[1] = - p[1] * u[1] + p[2]
    return nothing
end

# Gradient of the final state u(10) with respect to p
dp = Zygote.gradient(p -> solve(ODEProblem(dynamics,
                                           [10.0],
                                           (0.0, 10.0),
                                           p;
                                           tstops = [4.0]), Tsit5()).u[end][1], p)
```
This gives `dp = ([-42.072752200991175, 6.321205292676615],)`. Now I would like to consider a case in which the dependency of the solution on one of the parameters, say `p[2]`, is ignored. Zygote can ignore certain parts of the gradient computation with the `@ignore` macro, as in the following example:
```julia
using Zygote: @ignore

function foo(x)
    y = @ignore x   # y has the value of x but is treated as a constant by AD
    return y * x
end
```
Here the computed derivative is f'(x) = x instead of f'(x) = 2x. However, running the previous example with the `@ignore` macro inside `dynamics()` gives the same numerical value of the gradient:
```julia
function dynamics2(du, u, p, t)
    offset = @ignore p[2]   # intended: treat p[2] as a constant for AD
    du[1] = - p[1] * u[1] + offset
    return nothing
end

dp2 = Zygote.gradient(p -> solve(ODEProblem(dynamics2,
                                            [10.0],
                                            (0.0, 10.0),
                                            p), Tsit5()).u[end][1], p)
```
This gives `dp2 = ([-42.072752200991175, 6.321205292676615],)`, identical to `dp`.
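For reference, this ODE has a closed-form solution, so the expected gradients can be checked by hand. The sketch below is plain Julia with no packages. The helper `analytic_gradient` and its `freeze_offset` keyword are hypothetical names introduced only for this check. It differentiates u(T) = p2/p1 + (u0 - p2/p1) exp(-p1 T) analytically, assuming p1 ≠ 0. It should reproduce `dp` up to solver tolerance. It also shows that treating `p[2]` as a constant should leave the first component unchanged and set the second to zero.

```julia
# Closed-form solution of du/dt = -p1*u + p2 with u(0) = u0 (assumes p1 != 0):
#   u(t) = p2/p1 + (u0 - p2/p1) * exp(-p1*t)
# Returns the partial derivatives of u(T) with respect to p1 and p2.
# `freeze_offset = true` treats the offset p2 as a constant, mimicking @ignore.
function analytic_gradient(p1, p2; u0 = 10.0, T = 10.0, freeze_offset = false)
    e = exp(-p1 * T)
    # d u(T) / d p1: p1 appears in p2/p1 and in exp(-p1*T)
    du_dp1 = -p2 / p1^2 + (p2 / p1^2) * e - T * (u0 - p2 / p1) * e
    # d u(T) / d p2: zero when p2 is treated as a constant
    du_dp2 = freeze_offset ? 0.0 : (1 - e) / p1
    return [du_dp1, du_dp2]
end

full   = analytic_gradient(0.1, 0.2)                         # ≈ [-42.0728, 6.3212]
frozen = analytic_gradient(0.1, 0.2; freeze_offset = true)   # ≈ [-42.0728, 0.0]
println(full)
println(frozen)
```

The "frozen" vector is what `dynamics2` would be expected to produce if `@ignore` were honored inside the ODE right-hand side.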
Does anyone know whether `@ignore` is supported for differential equations? I may also be missing something about how `@ignore` behaves. My understanding is that it should make AD ignore the dependency on the wrapped part of the code.
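The expected effect of `@ignore` on `foo` can be illustrated without any package using a toy forward-mode dual-number sketch. In this sketch, the hypothetical function `stopgrad` plays the role of `@ignore`: it keeps the value but zeroes the derivative. This is not how Zygote, a reverse-mode tool, implements `@ignore`. The `Dual`, `stopgrad`, and `derivative` names are assumptions made only for illustration.

```julia
# Toy forward-mode AD: a number carrying its value and its derivative w.r.t. x.
struct Dual
    val::Float64
    der::Float64
end

# Product rule, the only operation foo needs.
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# Hypothetical stand-in for @ignore: keep the value, drop the derivative.
stopgrad(a::Dual) = Dual(a.val, 0.0)

foo_plain(x)   = x * x            # f(x) = x^2, so f'(x) = 2x
foo_ignored(x) = stopgrad(x) * x  # mirrors `y = @ignore x; y * x`, so f'(x) = x

# Seed the input with derivative 1 and read off the derivative of the output.
derivative(f, x) = f(Dual(x, 1.0)).der

println(derivative(foo_plain, 3.0))    # 6.0
println(derivative(foo_ignored, 3.0))  # 3.0
```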
Thank you!