I need to work with the individual Wiener increments of an SDE solution inside my loss function. Whenever I try to read these values via sol.W.W, the error "type Nothing has no field W" is raised when computing the gradients. Could this be a bug in the DiffEqSensitivity package?
MWE:
using DifferentialEquations, Flux, DiffEqFlux
using DiffEqSensitivity
# Drift of the Lotka–Volterra system
function dt!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end
# Multiplicative diffusion, 10% of each state
function dW!(du, u, p, t)
    du[1] = 0.1u[1]
    du[2] = 0.1u[2]
end
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]
W = WienerProcess(0.0, 0.0, 0.0)   # scalar noise, shared by both components
prob_sde = SDEProblem(dt!, dW!, u0, tspan, p, noise = W)
function some_loss(p)   # some loss function
    # remake so that the solution actually depends on the argument p
    sol = solve(remake(prob_sde, p = p), EM(), saveat = 0.1, sensealg = ForwardDiffSensitivity(), dt = 0.001)
    Wsum = sum(sol.W.W)   # problematic part
    println("Wsum=", Wsum)
    sum(abs2, x - 1 for x in Array(sol))
end
ps = Flux.params(p)
@time gs = gradient(ps) do
    some_loss(p)
end   # ERROR: type Nothing has no field W
I think the problem traces back to sol.W being nothing in the backward pass. Using ForwardDiff instead of Zygote to compute the gradients
using ForwardDiff
@time gs = ForwardDiff.gradient(p) do p
    some_loss(p)
end
removes the error. If reverse-mode AD is needed, you could use pre-defined noise values together with NoiseGrid from the DiffEqNoiseProcess package
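using DiffEqNoiseProcess, Zygote
dt = 0.001
t = collect(tspan[1]:dt:tspan[2])       # Julia is 1-indexed: tspan[1] is the start time
Z = randn(length(t))
Z1 = cumsum([0; sqrt(dt)*Z[1:end-1]])   # Brownian path on the grid, starting at 0
NG = Zygote.ignore() do                  # return the grid so that NG is defined outside the do block
    NoiseGrid(t, Z1)
end
tmp_prob = remake(prob_sde, p = p, noise = NG)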
as a possible workaround. @ChrisRackauckas, do you have a better idea for using the noise values in a loss function with Zygote?
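Note that sol.W.W (like Z1 in the snippet above) stores the sampled values of the Brownian path W(t), not its increments. The plain-Julia sketch below reuses the same grid construction (only Base and the Random standard library; the seed is an arbitrary choice for reproducibility) to show how the individual increments can be recovered with diff. With pre-generated noise they are simply sqrt(dt) .* Z[1:end-1], so they should not need to be read back from sol.W inside the differentiated code.

```julia
using Random
Random.seed!(42)                         # reproducible draws for this illustration
dt = 0.001
tspan = (0.0, 10.0)
t = collect(tspan[1]:dt:tspan[2])        # 10001 grid points
Z = randn(length(t))                     # standard normal draws
Z1 = cumsum([0; sqrt(dt) * Z[1:end-1]])  # Brownian path values W(t_k), with W(t_1) = 0
# Individual Wiener increments ΔW_k = W(t_{k+1}) - W(t_k)
dW = diff(Z1)                            # equals sqrt(dt) * Z[1:end-1] up to rounding
# Summing the increments telescopes to the final path value W(T),
# which is not the same as summing the path values themselves (sum(Z1))
W_T = sum(dW)
```

So a term like sum(sol.W.W) in the loss sums path values; if the sum of increments is what is wanted, it is just the endpoint W(T).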
So we would need to calculate the derivative with respect to the Brownian motion? I think that is defined by Malliavin calculus, but I don't think we've implemented it in the adjoint expressions. I believe @frankschae is right that the best way to do this right now would be to define the AbstractNoiseProcess directly as a NoiseGrid built from variables of the AD system, so that you supply them and AD will differentiate them. That would at least be a way to do it today.
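To make the idea of supplying the noise yourself concrete, here is a minimal sketch that deliberately swaps out the DifferentialEquations.jl machinery: a hand-written Euler–Maruyama loop (an assumption standing in for solve(prob_sde, EM(), ...)) for the Lotka–Volterra SDE from the MWE, where the vector of increments dW is an explicit argument. The name em_loss is hypothetical, the loss is accumulated at every step rather than at saveat = 0.1, and the sum(dW) term is only a stand-in for whatever use of the increments the real loss needs. Because the loss is an ordinary function of p and dW, an AD tool can differentiate it with respect to either.

```julia
# Hypothetical hand-rolled Euler–Maruyama loss (not a DifferentialEquations.jl API).
# Scalar noise: the same increment drives both components, as with
# WienerProcess(0.0, 0.0, 0.0) and the diagonal diffusion dW! in the MWE.
function em_loss(p, dW; u0 = (1.0, 1.0), dt = 0.001)
    α, β, δ, γ = p
    x, y = u0
    loss = zero(promote_type(eltype(p), eltype(dW)))
    for ΔW in dW
        dx = α*x - β*x*y                 # drift, same as dt! in the MWE
        dy = -δ*y + γ*x*y
        # Euler–Maruyama step with diffusion 0.1u, same as dW! in the MWE
        x, y = x + dx*dt + 0.1*x*ΔW, y + dy*dt + 0.1*y*ΔW
        loss += abs2(x - 1) + abs2(y - 1)
    end
    return loss + sum(dW)                # increments enter the loss directly
end
```

With ForwardDiff loaded, something like ForwardDiff.gradient(q -> em_loss(q, dW), p) should give the gradient with respect to the parameters, and ForwardDiff.gradient(w -> em_loss(p, w), dW) the pathwise gradient with respect to the increments. For a long noise vector a reverse-mode tool is likely cheaper, but this loop has not been checked against Zygote.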
Thank you both for the comments and suggestions!