Problems using BacksolveAdjoint for an SDE with DiffEqFlux

I’m trying to train an SDE with DiffEqFlux. I had issues with stack overflows when using TrackerAdjoint (probably because of the size of the problem), so I tried BacksolveAdjoint instead, but it errors when calculating the gradients. Here is a small MWE (the actual problem is much larger) of what I did:

using Flux
using DiffEqFlux
using DifferentialEquations

f = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
) 
g = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
)

x = rand(Float32, 3, 4)
y = rand(Float32, 3, 4)

neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=BacksolveAdjoint())
#neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=TrackerAdjoint()) <- defining the problem like this works

ps = Flux.params(neuralsde.p)
Flux.gradient(ps) do
    ŷ = neuralsde(x)[:,:,end]
    Flux.Losses.mse(ŷ, y)
end

It gives the following error:

ERROR: UndefVarError: xs not defined
Stacktrace:
  [1] (::Zygote.var"#442#443")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:79
  [2] (::Zygote.var"#2383#back#444"{Zygote.var"#442#443"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] Pullback
    @ ./broadcast.jl:894 [inlined]
  [4] Pullback
    @ ./broadcast.jl:891 [inlined]
  [5] Pullback
    @ ./broadcast.jl:887 [inlined]
  [6] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:44 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/DiffEqFlux/N7blG/src/neural_de.jl:147 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:73 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:75 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [19] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/derivative_wrappers.jl:454 [inlined]
 [21] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
...

According to the docs, BacksolveAdjoint can be used for SDEs. Am I using it incorrectly? Is there a way to make it work?

Looks like you’re on an older version of Zygote.

The error seems to be that something is mutating? Which package versions are you using?

Thanks for your help! I was on Zygote 0.6.17; updating to 0.6.21 (which in turn downgraded DiffEqFlux) does indeed reveal a mutating-arrays error.

versions now:

  [aae7a2af] DiffEqFlux v1.39.0
  [0c46a032] DifferentialEquations v6.19.0
  [587475ba] Flux v0.12.6
  [0bca4576] SciMLBase v1.18.7
  [e88e6eb3] Zygote v0.6.21

new error:

ERROR: Mutating arrays is not supported -- called copyto!(::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:74
  [3] (::Zygote.var"#2374#back#442"{Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:39 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/Projects/Markets2/src/model.jl:52 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:73 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:75 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [20] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [21] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/derivative_wrappers.jl:428 [inlined]
 [22] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#50#51"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
...

Might it have something to do with all the Nothing types in the stack trace? Like, the gradient does not have a type?
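For context on the `Δ::Nothing` entries: in Zygote, `nothing` is the conventional value for a gradient that is identically zero (for example, with respect to an argument the output does not depend on), so these entries are not errors by themselves. A minimal illustration, assuming a Zygote 0.6.x release like the ones discussed here:

```julia
using Zygote

# A constant function does not depend on its input, so Zygote returns `nothing`
# (a structural zero) rather than a numeric zero.
g_const = Zygote.gradient(x -> 1.0, 2.0)

# Only `a` is used, so the gradient with respect to `b` is `nothing`.
g_partial = Zygote.gradient((a, b) -> 3a, 2.0, 5.0)

println(g_const)    # (nothing,)
println(g_partial)  # (3.0, nothing)
```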

@frankschae is this just the mutation in the Strat->Ito transformation?

yeah that could be the issue… I’ll need to dig into this. I made a note.
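The "Mutating arrays is not supported" error comes from Zygote's reverse mode, which cannot differentiate code that writes into an existing array; in the trace above, a broadcasted assignment ends in `copyto!`. The sketch below, with made-up function names and not taken from DiffEqSensitivity, reproduces the limitation and shows two common ways around it: an out-of-place rewrite, or `Zygote.Buffer`.

```julia
using Zygote

# In-place: broadcasting into a preallocated array goes through copyto!,
# which Zygote's reverse mode rejects with "Mutating arrays is not supported".
function double_sum_mutating(x)
    y = similar(x)
    y .= 2 .* x
    return sum(y)
end

# Out-of-place: allocate a new array instead of writing into an existing one.
double_sum_pure(x) = sum(2 .* x)

# Zygote.Buffer: an array-like container Zygote can differentiate writes into;
# copy(buf) turns it back into an ordinary array.
function double_sum_buffer(x)
    buf = Zygote.Buffer(x)
    for i in eachindex(x)
        buf[i] = 2 * x[i]
    end
    return sum(copy(buf))
end

x = rand(Float32, 3)
Zygote.gradient(double_sum_pure, x)        # gradient is 2 in every entry
Zygote.gradient(double_sum_buffer, x)      # same result via Buffer
# Zygote.gradient(double_sum_mutating, x)  # errors: Mutating arrays is not supported
```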

I think there are two issues. First, batches are not handled correctly, i.e.,

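using Random, Zygote

# ff and g_ are not defined in the post; assumed here to be the drift and diffusion
# that NeuralDSDE builds, splitting neuralsde.p into the parameters of f and of g
len = length(initial_params(f))
ff(u, p, t) = f(u, p[1:len])
g_(u, p, t) = g(u, p[(len + 1):end])

x1 = rand(Float32, 3, 4)
y1 = rand(Float32, 3, 4)
prob1 = SDEProblem{false}(ff, g_, x1, neuralsde.tspan, neuralsde.p)

function loss(p; prob = prob1, y = y1, alg = EulerHeun(), sensealg = BacksolveAdjoint())
    _prob = remake(prob, p = p)
    ŷ = solve(_prob, alg, dt = 0.01f0, sensealg = sensealg)[end]
    sum(abs2, ŷ - y)
end

Zygote.gradient(p -> loss(p, alg = EulerHeun(), sensealg = BacksolveAdjoint()), neuralsde.p)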

fails with

ERROR: MethodError: no method matching mul!(::Vector{Float32}, ::Matrix{Float32}, ::Matrix{Float32}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::StridedVecOrMat{T} where T, ::LinearAlgebra.SymTridiagonal, ::StridedVecOrMat{T} where T, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/tridiag.jl:214
  mul!(::StridedVecOrMat{T} where T, ::SparseArrays.AbstractSparseMatrixCSC, ::Union{LinearAlgebra.Adjoint{var"#s832", var"#s831"} where {var"#s832", var"#s831"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.LowerTriangular, LinearAlgebra.Transpose{var"#s830", var"#s829"} where {var"#s830", var"#s829"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, StridedVector{T} where T, BitMatrix, BitVector}, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/linalg.jl:30
  mul!(::StridedVecOrMat{T} where T, ::Union{LinearAlgebra.Adjoint{var"#s832", var"#s831"} where {var"#s832", var"#s831"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.LowerTriangular, LinearAlgebra.Transpose{var"#s830", var"#s829"} where {var"#s830", var"#s829"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}, ::SparseArrays.AbstractSparseMatrixCSC, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/linalg.jl:87
  ...

Stacktrace:
  [1] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [2] perform_step!(integrator::StochasticDiffEq.SDEIntegrator{EulerHeun, true, Vector{Float32},

inside the diffusion function. I am not sure about the best way to fix this. Maybe we could define the batching via an EnsembleProblem inside DiffEqFlux, or we could make the adjoint state definition inside DiffEqSensitivity more general?
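A simpler stand-in for the EnsembleProblem idea, sketched under assumptions: solve each column of the batch as its own vector-valued SDE and sum the per-sample losses, since the vector-valued case works with `BacksolveAdjoint` and `EulerHeun()` (as shown later in the thread). This is a plain loop, not an `EnsembleProblem`; it is written against the DiffEqFlux 1.x `FastChain` API used here, and the names `drift`, `diffusion`, `len`, and `batch_loss` are hypothetical.

```julia
using DiffEqFlux, DifferentialEquations, Zygote

# Same architecture as the MWE at the top of the thread (DiffEqFlux 1.x FastChain API)
f = FastChain(FastDense(3, 2, tanh), FastDense(2, 3))
g = FastChain(FastDense(3, 2, tanh), FastDense(2, 3))
p0 = [initial_params(f); initial_params(g)]
len = length(initial_params(f))

# Hypothetical drift/diffusion pair sharing one flat parameter vector
drift(u, p, t) = f(u, p[1:len])
diffusion(u, p, t) = g(u, p[(len + 1):end])

x = rand(Float32, 3, 4)   # 4 samples, one per column
y = rand(Float32, 3, 4)
tspan = (0f0, 1f0)

# Hypothetical batch loss: one vector-valued SDE per column instead of one matrix-valued SDE
function batch_loss(p; alg = EulerHeun(), sensealg = BacksolveAdjoint())
    total = 0f0
    for i in axes(x, 2)
        prob = SDEProblem{false}(drift, diffusion, x[:, i], tspan, p)
        ŷ = solve(prob, alg, dt = 0.01f0, sensealg = sensealg)[end]
        total += sum(abs2, ŷ - y[:, i])
    end
    return total
end

grad = Zygote.gradient(batch_loss, p0)[1]
```

Each column gets its own Brownian path and the loop runs serially, so this trades the speed of a batched solve for a gradient that avoids the matrix-state `mul!` error.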

Second, without batches, it’s the Zygote-over-Zygote issue: the Ito-to-Stratonovich conversion requires additional (nested) differentiation. The issue to track here is SciML/DiffEqSensitivity.jl issue #385 on GitHub, "Ito on SDEs with adjoints: nested differentiation of Zygote". To cope with it for the moment, you’d have to either rewrite the SDE in the Stratonovich sense or use an `autojacvec` (VJP) option for which the nesting works, such as TrackerVJP().

x2 = rand(Float32, 3)
y2 = rand(Float32, 3)
prob2 = SDEProblem{false}(ff, g_, x2, neuralsde.tspan, neuralsde.p)

Zygote.gradient(p -> loss(p, prob = prob2, y = y2, alg = EulerHeun(), sensealg = BacksolveAdjoint()), neuralsde.p)  # works (Stratonovich solver)
Zygote.gradient(p -> loss(p, prob = prob2, y = y2, alg = EM(), sensealg = BacksolveAdjoint()), neuralsde.p)  # fails: Zygote over Zygote
Zygote.gradient(p -> loss(p, prob = prob2, y = y2, alg = EM(), sensealg = BacksolveAdjoint(autojacvec = TrackerVJP())), neuralsde.p)  # works
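Why the Ito case needs nested differentiation, shown with the standard Ito-to-Stratonovich identity for diagonal noise (included for illustration only; it does not show how DiffEqSensitivity implements the conversion internally). The equivalent Stratonovich drift contains a derivative of the diffusion $g$:

$$
\mathrm{d}X_t = f(X_t,t)\,\mathrm{d}t + g(X_t,t)\,\mathrm{d}W_t
\quad\Longleftrightarrow\quad
\mathrm{d}X_t = \Big(f(X_t,t) - \tfrac{1}{2}\,c(X_t,t)\Big)\,\mathrm{d}t + g(X_t,t)\circ\mathrm{d}W_t,
\qquad
c_i(x,t) = g_i(x,t)\,\frac{\partial g_i}{\partial x_i}(x,t).
$$

Because the corrected drift already involves $\partial g_i/\partial x_i$, computing its vector-Jacobian products in the backward pass means differentiating code that itself calls the AD backend, which is the Zygote-over-Zygote nesting when the VJPs use Zygote. `EM()` solves in the Ito sense and `EulerHeun()` in the Stratonovich sense, which matches which of the calls above succeed.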