Using the noise values sol.W.W from an SDE solution throws an error when taking gradients

I need to work with individual Wiener increments from an SDE solution in my loss function. Whenever I try to get these values from sol.W.W, computing the gradients fails with ERROR: type Nothing has no field W. Could this be a bug in the DiffEqSensitivity package?


using DifferentialEquations, Flux, DiffEqFlux
using DiffEqSensitivity

# Lotka-Volterra drift
function dt!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

# Multiplicative diffusion
function dW!(du, u, p, t)
  du[1] = 0.1u[1]
  du[2] = 0.1u[2]
end

u0 = [1.0,1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]
W = WienerProcess(0.0,0.0,0.0)
prob_sde = SDEProblem(dt!, dW!, u0, tspan,p,noise=W)

function some_loss(p) # some loss function
  # remake so that the solve actually depends on the argument p
  sol = solve(remake(prob_sde, p = p), EM(), saveat = 0.1, sensealg = ForwardDiffSensitivity(), dt = 0.001)
  Wsum = sum(sol.W.W) # problematic part
  sum(abs2, x-1 for x in Array(sol))
end

ps = Flux.params(p)
@time gs = gradient(ps) do
  some_loss(p)
end # ERROR: type Nothing has no field W


I think the problem traces back to sol.W being nothing in the backward pass. Using ForwardDiff instead of Zygote to compute the gradients

using ForwardDiff
@time gs = ForwardDiff.gradient(some_loss, p)

removes the error. If reverse-mode AD is needed, a possible workaround is to use pre-defined noise values and NoiseGrid from the DiffEqNoiseProcess package:

using DiffEqNoiseProcess, Zygote

dt = 0.001
t = Array(tspan[1]:dt:tspan[2])
Z = randn(length(t))
Z1 = cumsum([0; sqrt(dt)*Z[1:end-1]])  # Brownian path on the grid t, with W(0) = 0
NG = Zygote.ignore() do
   NoiseGrid(t, Z1)
end
tmp_prob = remake(prob_sde, p = p, noise = NG)
Wsum = sum(diff(Z1))  # the individual increments are available as diff(Z1)

@ChrisRackauckas, do you have a better idea for using the noise values in a loss function with Zygote?
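A dependency-light way to see why supplying the noise yourself helps: the sketch below does not use DifferentialEquations at all. It hand-rolls the Euler-Maruyama step for the Lotka-Volterra model above, driven by a fixed array of scalar increments dW (matching the scalar WienerProcess in the question). Because the increments are ordinary inputs, the loss can use them directly, and ForwardDiff can differentiate with respect to both p and dW. The helper names (drift, diffusion, em_path, loss_with_noise), the step size dt = 0.01 and the Wsum^2 term are illustrative assumptions, not part of any package API.

```julia
using Random, ForwardDiff

# Hypothetical helpers for illustration only; this is a hand-rolled EM loop, not the DiffEq solver.
# Out-of-place drift and diffusion of the Lotka-Volterra SDE from the question.
drift(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]
diffusion(u, p) = 0.1 .* u

# Euler-Maruyama driven by externally supplied scalar Wiener increments dW.
function em_path(u0, p, dt, dW)
    T = promote_type(eltype(u0), eltype(p), eltype(dW))  # becomes a Dual type under ForwardDiff
    u = convert(Vector{T}, u0)
    path = [u]
    for k in eachindex(dW)
        u = u .+ drift(u, p) .* dt .+ diffusion(u, p) .* dW[k]
        push!(path, u)
    end
    return path
end

# Loss using both the path and the individual increments (Wsum^2 is an arbitrary example term).
function loss_with_noise(p, dW; u0 = [1.0, 1.0], dt = 0.01)
    path = em_path(u0, p, dt, dW)
    Wsum = sum(dW)
    return sum(sum(abs2, x .- 1) for x in path) + Wsum^2
end

Random.seed!(1)
dt = 0.01
n = 1000                          # 1000 steps of 0.01 cover tspan = (0.0, 10.0)
dW = sqrt(dt) .* randn(n)         # pre-defined scalar increments
p = [2.2, 1.0, 2.0, 0.4]

gp = ForwardDiff.gradient(q -> loss_with_noise(q, dW; dt = dt), p)   # d loss / d p
gW = ForwardDiff.gradient(w -> loss_with_noise(p, w; dt = dt), dW)   # d loss / d increments
```

Differentiating with respect to dW is only possible here because the increments enter the computation as variables of the AD system, which is the same idea as building a NoiseGrid from those values.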

So we would need to calculate the derivative w.r.t. the Brownian motion? I think that's defined by the Malliavin calculus, but I don't think we've implemented it in the adjoint expressions. I think @frankschae is right that the best way to do this right now is to define the AbstractNoiseProcess directly as a NoiseGrid built from variables of the AD system, so that you supply them and AD differentiates them. That would at least be a way to do it today.
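At the level of the discretized scheme, the suggestion reduces to the ordinary chain rule. With the Euler-Maruyama method and scalar noise used in the question (drift f from dt!, diffusion g(u) = 0.1u from dW!), each step and its derivative with respect to a supplied increment read as below. This is a sketch for the EM discretization only, not the continuous-time Malliavin derivative.

```latex
u_{n+1} = u_n + f(u_n, p)\,\Delta t + g(u_n)\,\Delta W_n, \qquad g(u) = 0.1\,u,
\\[4pt]
\frac{\partial u_{n+1}}{\partial \Delta W_n} = g(u_n), \qquad
\frac{\partial u_{m+1}}{\partial \Delta W_n}
  = \Big(I + \frac{\partial f}{\partial u}(u_m, p)\,\Delta t + 0.1\,\Delta W_m\, I\Big)
    \frac{\partial u_m}{\partial \Delta W_n}, \quad m > n,
\\[4pt]
\frac{\partial u_m}{\partial \Delta W_n} = 0 \quad \text{for } m \le n.
```

Once the increments are AD inputs, reverse- or forward-mode AD evaluates exactly these products, which is why supplying them through a NoiseGrid works without a dedicated adjoint rule.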

Thank you both for comments and suggestions!