I’m trying to train an SDE with DiffEqFlux. I had issues with stack overflows when using TrackerAdjoint (probably because of the size of the problem), so I tried using BacksolveAdjoint, but this errors when calculating the gradients. Here is a small MWE (the actual problem is much larger) of what I did:
```julia
using Flux
using DiffEqFlux
using DifferentialEquations

f = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
)
g = FastChain(
    FastDense(3, 2, tanh),
    FastDense(2, 3),
)

x = rand(Float32, 3, 4)
y = rand(Float32, 3, 4)

neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=BacksolveAdjoint())
#neuralsde = NeuralDSDE(f, g, (0f0, 1f0), sensealg=TrackerAdjoint()) <- defining the problem like this works

ps = Flux.params(neuralsde.p)
Flux.gradient(ps) do
    ŷ = neuralsde(x)[:,:,end]
    Flux.Losses.mse(ŷ, y)
end
```
It gives the following error:
```
ERROR: UndefVarError: xs not defined
Stacktrace:
  [1] (::Zygote.var"#442#443")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:79
  [2] (::Zygote.var"#2383#back#444"{Zygote.var"#442#443"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] Pullback
    @ ./broadcast.jl:894 [inlined]
  [4] Pullback
    @ ./broadcast.jl:891 [inlined]
  [5] Pullback
    @ ./broadcast.jl:887 [inlined]
  [6] (::typeof(∂(materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:44 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/DiffEqFlux/N7blG/src/neural_de.jl:147 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:73 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/sde_tools.jl:75 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [19] #204#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ~/.julia/packages/DiffEqSensitivity/cLl5o/src/derivative_wrappers.jl:454 [inlined]
 [21] (::typeof(∂(λ)))(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
...
```
According to the docs, the BacksolveAdjoint method can be used for SDEs. Am I using it incorrectly? Is there a way to make it work?