Problems using BacksolveAdjoint for an SDE with DiffEqFlux

I’m trying to train an SDE with DiffEqFlux. I had issues with stack overflows when using TrackerAdjoint (probably because of the size of the problem), so I tried using BacksolveAdjoint, but this errors when calculating the gradients. This is a small MWE (the actual problem is much larger) of what I did:

using Flux
using DiffEqFlux
using DifferentialEquations

f = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
) 
g = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
)

x = rand(Float32, 3, 4)
y = rand(Float32, 3, 4)

neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=BacksolveAdjoint())
#neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=TrackerAdjoint()) <- defining the problem like this works

ps = Flux.params(neuralsde.p)
Flux.gradient(ps) do
    ŷ = neuralsde(x)[:,:,end]
    Flux.Losses.mse(ŷ, y)
end

It gives the following error:

ERROR: UndefVarError: xs not defined
Stacktrace:
  [1] (::Zygote.var"#442#443")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:79
  [2] (::Zygote.var"#2383#back#444"{Zygote.var"#442#443"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] Pullback
    @ ./broadcast.jl:894 [inlined]
  [4] Pullback
    @ ./broadcast.jl:891 [inlined]
  [5] Pullback
    @ ./broadcast.jl:887 [inlined]
  [6] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:44 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/DiffEqFlux/N7blG/src/neural_de.jl:147 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:73 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:75 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [19] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/derivative_wrappers.jl:454 [inlined]
 [21] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
...

According to the docs, the BacksolveAdjoint method can be used for SDEs. Am I using it incorrectly? Is there a way to make it work?

Looks like you’re on an older version of Zygote.

The error seems to be that something is mutating an array? Which package versions are you on?

Thanks for your help! I was on Zygote 0.6.17; updating to 0.6.21 (and with that downgrading DiffEqFlux) does indeed reveal a mutating-arrays error.

Versions now:

  [aae7a2af] DiffEqFlux v1.39.0
  [0c46a032] DifferentialEquations v6.19.0
  [587475ba] Flux v0.12.6
  [0bca4576] SciMLBase v1.18.7
  [e88e6eb3] Zygote v0.6.21

New error:

ERROR: Mutating arrays is not supported -- called copyto!(::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:74
  [3] (::Zygote.var"#2374#back#442"{Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:39 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/Projects/Markets2/src/model.jl:52 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:73 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:75 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [20] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [21] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/derivative_wrappers.jl:428 [inlined]
 [22] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#50#51"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
...

Might it have something to do with all the Nothing types in the stacktrace? Like the gradient does not have a type?
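
For what it's worth, the message names an in-place `copyto!` into a `SubArray`, which suggests the trigger is array mutation rather than the `Nothing` values (Zygote uses `nothing` to represent a gradient that is absent). Below is a minimal sketch, independent of DiffEqFlux and the SDE code, that should raise the same class of Zygote error and shows two mutation-free variants. The function names `loss_mutating`, `loss_pure`, and `loss_buffer` are made up for illustration, and this is not a fix for the BacksolveAdjoint problem itself.

```julia
using Zygote

# Hypothetical toy loss that writes into a preallocated array.
# The in-place broadcast goes through materialize!/copyto!,
# which Zygote refuses to differentiate.
function loss_mutating(x)
    buf = zeros(eltype(x), length(x))
    buf .= 2 .* x
    return sum(buf)
end

# Same computation without any mutation.
loss_pure(x) = sum(2 .* x)

# Same computation using Zygote.Buffer, which allows element-wise writes
# as long as the buffer is turned back into an array with copy().
function loss_buffer(x)
    buf = Zygote.Buffer(x)
    for i in eachindex(x)
        buf[i] = 2 * x[i]
    end
    return sum(copy(buf))
end

x = rand(Float32, 3)

# Expected to throw; the error text is printed rather than assumed here.
mutating_failed = try
    Zygote.gradient(loss_mutating, x)
    false
catch err
    println("mutating version failed: ", sprint(showerror, err))
    true
end

grad_pure = Zygote.gradient(loss_pure, x)[1]
grad_buffer = Zygote.gradient(loss_buffer, x)[1]
```

In the thread's case the in-place write happens inside library code rather than in the user's loss function, so rewriting the loss like this would not be expected to help; the sketch only illustrates what the error message is reporting.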

@frankschae is this just the mutation in the Stratonovich->Ito (Strat->Ito) transformation?

Yeah, that could be the issue… I’ll need to dig into this. I made a note.