Problems using BacksolveAdjoint for an SDE with DiffEqFlux

I’m trying to train an SDE with DiffEqFlux. I ran into stack overflows when using TrackerAdjoint (probably because of the size of the problem), so I tried BacksolveAdjoint instead, but that errors when calculating the gradients. Here is a small MWE (the actual problem is much larger) of what I did:

using Flux
using DiffEqFlux
using DifferentialEquations

f = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
) 
g = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
)

x = rand(Float32, 3, 4)
y = rand(Float32, 3, 4)

neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=BacksolveAdjoint())
#neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=TrackerAdjoint()) <- defining the problem like this works

ps = Flux.params(neuralsde.p)
Flux.gradient(ps) do
    ŷ = neuralsde(x)[:,:,end]
    Flux.Losses.mse(ŷ, y)
end

It gives the following error:

ERROR: UndefVarError: xs not defined
Stacktrace:
  [1] (::Zygote.var"#442#443")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:79
  [2] (::Zygote.var"#2383#back#444"{Zygote.var"#442#443"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] Pullback
    @ ./broadcast.jl:894 [inlined]
  [4] Pullback
    @ ./broadcast.jl:891 [inlined]
  [5] Pullback
    @ ./broadcast.jl:887 [inlined]
  [6] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:44 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/DiffEqFlux/N7blG/src/neural_de.jl:147 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:73 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:75 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [19] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/derivative_wrappers.jl:454 [inlined]
 [21] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
...

According to the docs, BacksolveAdjoint can be used for SDEs. Am I using it incorrectly? Is there a way to make it work?

Looks like you’re on an older version of Zygote.

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L78

The error seems to be that something is mutating. Which package versions are you on?

Thanks for your help! I was on Zygote 0.6.17; updating to 0.6.21 (and with that downgrading DiffEqFlux) does indeed reveal a mutating-arrays error.

versions now:

  [aae7a2af] DiffEqFlux v1.39.0
  [0c46a032] DifferentialEquations v6.19.0
  [587475ba] Flux v0.12.6
  [0bca4576] SciMLBase v1.18.7
  [e88e6eb3] Zygote v0.6.21

new error:

ERROR: Mutating arrays is not supported -- called copyto!(::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:74
  [3] (::Zygote.var"#2374#back#442"{Zygote.var"#440#441"{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:39 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/Projects/Markets2/src/model.jl:52 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:73 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/sde_tools.jl:75 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [20] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [21] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/EpC0d/src/derivative_wrappers.jl:428 [inlined]
 [22] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#50#51"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
...

Might it have something to do with all the Nothing types in the stacktrace? Like the gradient does not have a type?

@frankschae is this just the mutation in the Strat->Ito transformation?

yeah that could be the issue… I’ll need to dig into this. I made a note.

I think there are two issues. The first is that batches are not handled correctly, i.e.,

using Random
x1 = rand(Float32, 3, 4)
y1 = rand(Float32, 3, 4)
# ff and g_ (definitions not shown) are the out-of-place drift and diffusion functions wrapping f and g
prob1 = SDEProblem{false}(ff,g_,x1,neuralsde.tspan,neuralsde.p)

function loss(p;prob = prob1, y=y1, alg=EulerHeun(),sensealg=BacksolveAdjoint())
    _prob = remake(prob,p=p)
    ŷ = solve(_prob,alg,dt=0.01f0,sensealg=sensealg)[end]
    sum(abs2,ŷ - y)
end

Zygote.gradient(p->loss(p,alg=EulerHeun(),sensealg=BacksolveAdjoint()), neuralsde.p)

fails with

ERROR: MethodError: no method matching mul!(::Vector{Float32}, ::Matrix{Float32}, ::Matrix{Float32}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::StridedVecOrMat{T} where T, ::LinearAlgebra.SymTridiagonal, ::StridedVecOrMat{T} where T, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/tridiag.jl:214
  mul!(::StridedVecOrMat{T} where T, ::SparseArrays.AbstractSparseMatrixCSC, ::Union{LinearAlgebra.Adjoint{var"#s832", var"#s831"} where {var"#s832", var"#s831"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.LowerTriangular, LinearAlgebra.Transpose{var"#s830", var"#s829"} where {var"#s830", var"#s829"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, StridedVector{T} where T, BitMatrix, BitVector}, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/linalg.jl:30
  mul!(::StridedVecOrMat{T} where T, ::Union{LinearAlgebra.Adjoint{var"#s832", var"#s831"} where {var"#s832", var"#s831"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.LowerTriangular, LinearAlgebra.Transpose{var"#s830", var"#s829"} where {var"#s830", var"#s829"<:Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}}, LinearAlgebra.UnitLowerTriangular, LinearAlgebra.UnitUpperTriangular, LinearAlgebra.UpperTriangular, StridedMatrix{T} where T, BitMatrix}, ::SparseArrays.AbstractSparseMatrixCSC, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/linalg.jl:87
  ...

Stacktrace:
  [1] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [2] perform_step!(integrator::StochasticDiffEq.SDEIntegrator{EulerHeun, true, Vector{Float32},

inside the diffusion function. I am not sure about the best way to fix this. Maybe we could define
https://github.com/SciML/DiffEqFlux.jl/blob/ab04993e85d4cc2311b6ed4b6f2c36e41102ad36/src/neural_de.jl#L143
via an EnsembleProblem inside DiffEqFlux, or we could make the adjoint state definition inside DiffEqSensitivity more general?
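To illustrate the first option, here is a rough, untested sketch of what the EnsembleProblem formulation could look like: each column of the batch becomes its own trajectory, so the adjoint only ever sees vector-valued states. It reuses ff, g_, x1, and neuralsde from the snippet above; the names and keyword choices here are assumptions, not a tested fix.

```julia
using DifferentialEquations

# Untested sketch: one vector-state SDE per batch column instead of a
# single matrix-valued SDE.
prob_col = SDEProblem{false}(ff, g_, x1[:, 1], neuralsde.tspan, neuralsde.p)
ensemble = EnsembleProblem(prob_col,
    # swap in column i of the batch as the initial condition
    prob_func = (prob, i, repeat) -> remake(prob, u0 = x1[:, i]))
sol = solve(ensemble, EulerHeun(), EnsembleSerial();
            trajectories = size(x1, 2), dt = 0.01f0,
            sensealg = BacksolveAdjoint())
```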

The second issue appears without batches: it’s the Zygote-over-Zygote problem, because the Ito–Stratonovich conversion requires additional differentiation. The issue to track here is https://github.com/SciML/DiffEqSensitivity.jl/issues/385 . To cope with it for the moment, you’d have to either rewrite the SDE in the Stratonovich sense or use an AD-vjp option for which the nesting works, like TrackerVJP().

x2 = rand(Float32, 3)
y2 = rand(Float32, 3)
prob2 = SDEProblem{false}(ff,g_,x2,neuralsde.tspan,neuralsde.p)

Zygote.gradient(p->loss(p,prob = prob2,y=y2,alg=EulerHeun(),sensealg=BacksolveAdjoint()), neuralsde.p) # works
Zygote.gradient(p->loss(p,prob = prob2,y=y2,alg=EM(),sensealg=BacksolveAdjoint()), neuralsde.p) # fails: Zygote over Zygote
Zygote.gradient(p->loss(p,prob = prob2,y=y2,alg=EM(),sensealg=BacksolveAdjoint(autojacvec=TrackerVJP())), neuralsde.p) # works
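Back at the level of the original MWE, the TrackerVJP workaround would presumably look like this (untested, and only addresses the nesting issue, not the batching one):

```julia
# Untested: use Tracker for the nested vjp inside BacksolveAdjoint,
# which avoids the Zygote-over-Zygote nesting from the Ito-Stratonovich conversion.
neuralsde = NeuralDSDE(f, g, (0f0, 1f0),
                       sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()))
```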