Problems with reverse-mode automatic differentiation (e.g. Zygote) with NeuralSDEs

Hello, I am having difficulty computing gradients of a neural SDE (either custom or from DiffEqFlux) using Zygote or ReverseDiff. I suspect it has to do with the shape of the noise during the backward pass, but I am not sure.

Here is an MWE:

using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays

rng = Random.default_rng()
tspan = (0.0, 1.0) .|> Float32
n_timepoints = 10
n_samples = 32
ts = range(tspan[1], tspan[2], length=n_timepoints)
y = rand32(2, n_timepoints, n_samples)  # (n_dim, n_timepoints, n_samples)
y0 = y[:, 1, :]  # initial condition per sample: (n_dim, n_samples)

model = NeuralDSDE(
    Dense(2, 2, tanh),  # drift network
    Dense(2, 2, tanh),  # diffusion network (diagonal noise)
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
    sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
)

p, st = Lux.setup(rng, model)
p = ComponentArray{Float32}(p)

function predict(y0, p, st)
    # reorder the solution from (n_dim, n_samples, n_timepoints) to (n_dim, n_timepoints, n_samples)
    permutedims(Array(model(y0, p, st)[1]), (1, 3, 2))
end

function loss(p)
    pred = predict(y0, p, st)
    return sum(abs2, pred .- y)
end

Now computing the gradients using ForwardDiff works, but using Zygote throws an error:

Zygote.gradient(loss, p)

DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 64 and 2
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:529 [inlined]
  [2] _bcs
    @ ./broadcast.jl:523 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:517 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:512 [inlined]
  [5] instantiate
    @ ./broadcast.jl:294 [inlined]
  [6] materialize
    @ ./broadcast.jl:873 [inlined]
  [7] (::NNlib.var"#145#148"{SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Matrix{Float32}})()
    @ NNlib ~/.julia/packages/NNlib/O0zGY/src/activations.jl:903
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:275 [inlined]

I can only get this to work if I change the sensealg to TrackerAdjoint() and the element types to Float64.
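
For reference, here is that workaround as a sketch: the same MWE with TrackerAdjoint() as the sensealg and Float64 element types throughout. Nothing else is changed; the comments mark the lines that differ from the MWE above.

using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays

rng = Random.default_rng()
tspan = (0.0, 1.0)                  # Float64 instead of Float32
ts = range(tspan[1], tspan[2], length=10)
y = rand(2, 10, 32)                 # Float64 data instead of rand32
y0 = y[:, 1, :]

model = NeuralDSDE(
    Dense(2, 2, tanh),
    Dense(2, 2, tanh),
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
    sensealg = TrackerAdjoint()     # discrete adjoint instead of InterpolatingAdjoint
)

p, st = Lux.setup(rng, model)
p = ComponentArray{Float64}(p)      # promote parameters to Float64

loss(p) = sum(abs2, permutedims(Array(model(y0, p, st)[1]), (1, 3, 2)) .- y)

Zygote.gradient(loss, p)            # no DimensionMismatch with this configuration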

Open an issue on SciMLSensitivity.jl. There’s some ongoing work on improving the SDE adjoints, though discrete adjoints (like TrackerAdjoint) do have some advantages on this kind of SDE.