Problems with reverse-mode automatic differentiation (e.g. Zygote) for neural SDEs

Hello, I am having difficulty computing gradients of neural SDEs (either custom ones or those from DiffEqFlux) with Zygote or ReverseDiff. I suspect it has to do with the shape of the noise during the backward pass, but I am not sure.

Here is an MWE:

```julia
using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays

rng = Random.default_rng()
tspan = (0.0, 1.0) .|> Float32
n_timepoints = 10
n_samples = 32
ts = range(tspan[1], tspan[2], length=n_timepoints)
y = rand32(2, n_timepoints, n_samples)  # (n_dim, n_timepoints, n_samples)
y0 = y[:, 1, :]                         # (n_dim, n_samples): a batch of 32 initial states

# Neural SDE with diagonal noise: first network is the drift, second the diffusion
model = NeuralDSDE(
    Dense(2, 2, tanh),
    Dense(2, 2, tanh),
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
    sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
)

p, st = Lux.setup(rng, model)
p = ComponentArray{Float32}(p)

function predict(y0, p, st)
    permutedims(Array(model(y0, p, st)[1]), (1, 3, 2))
end

function loss(p)
    pred = predict(y0, p, st)
    return sum(abs2, pred .- y)
end
```
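To rule out the loss itself as the source of the mismatch, here is a quick shape check of what `predict` does. It is a sketch that uses a plain random array in place of the solver output and assumes that `Array(sol)` for a 2×32 initial state saved at 10 time points has size (2, 32, 10), i.e. time along the last axis:

```julia
# Stand-in for Array(model(y0, p, st)[1]): (n_dim, n_samples, n_timepoints)
n_dim, n_samples, n_timepoints = 2, 32, 10
sol_array = rand(Float32, n_dim, n_samples, n_timepoints)

# Swap the sample and time axes, as `predict` does
pred = permutedims(sol_array, (1, 3, 2))

@show size(pred)  # (2, 10, 32), the same layout as y
```

Since the forward pass (and ForwardDiff) runs fine, the shapes in the loss line up; the mismatch appears only inside the reverse pass.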

Computing the gradients with ForwardDiff works, but Zygote throws an error:

```julia
Zygote.gradient(loss, p)
```

```
DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 64 and 2
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:529 [inlined]
  [2] _bcs
    @ ./broadcast.jl:523 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:517 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:512 [inlined]
  [5] instantiate
    @ ./broadcast.jl:294 [inlined]
  [6] materialize
    @ ./broadcast.jl:873 [inlined]
  [7] (::NNlib.var"#145#148"{SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Matrix{Float32}})()
    @ NNlib ~/.julia/packages/NNlib/O0zGY/src/activations.jl:903
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:275 [inlined]
```
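One possible reading of this trace (a hypothesis, not a confirmed diagnosis): 64 is exactly `length(y0)` = 2 × 32, and frame [7] shows the `tanh` pullback combining a 1-D column view of a `Diagonal` matrix with a `Matrix{Float32}`. That would fit a backward pass that works on the batched state flattened to a 64-element vector, while the diffusion network's output stays 2×32. The snippet below only reproduces an error of the same form with plain arrays of those shapes; the `Diagonal` cotangent and the `1 - Ω²` expression are illustrative stand-ins, not the actual SciMLSensitivity or NNlib internals.

```julia
using LinearAlgebra

n_dim, n_samples = 2, 32
u0 = rand(Float32, n_dim, n_samples)       # batched state, like y0 in the MWE
Ω = tanh.(u0)                              # stand-in for the 2×32 diffusion output

# Illustrative cotangent: one column of a Diagonal over the flattened state
D = Diagonal(ones(Float32, length(u0)))    # 64×64
Δ = view(D, :, 1)                          # length-64 vector (a SubArray of a Diagonal)

err = try
    Δ .* (1 .- Ω .^ 2)                     # tanh-style pullback: Δ .* (1 - Ω²)
    nothing
catch e
    e
end

println(typeof(err))                       # DimensionMismatch
showerror(stdout, err)                     # should print a message of the same form as in the trace
```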

The problem only goes away if I change the sensealg to `TrackerAdjoint()` and the types to Float64.
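For reference, that workaround looks roughly like the sketch below: the same MWE with `sensealg = TrackerAdjoint()` and everything in Float64. It is a reconstruction from the sentence above rather than a separately posted script, and it sidesteps the issue rather than explaining why `InterpolatingAdjoint(autojacvec=ZygoteVJP())` fails on the batched 2×32 state.

```julia
using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays

rng = Random.default_rng()
tspan = (0.0, 1.0)                              # Float64 time span
n_timepoints = 10
n_samples = 32
ts = range(tspan[1], tspan[2], length=n_timepoints)
y = rand(Float64, 2, n_timepoints, n_samples)   # (n_dim, n_timepoints, n_samples)
y0 = y[:, 1, :]                                 # (n_dim, n_samples)

model = NeuralDSDE(
    Dense(2, 2, tanh),
    Dense(2, 2, tanh),
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
    sensealg = TrackerAdjoint()                 # the workaround: Tracker-based adjoint
)

p, st = Lux.setup(rng, model)
p = ComponentArray{Float64}(p)                  # promote the Float32 Lux parameters to Float64

predict(y0, p, st) = permutedims(Array(model(y0, p, st)[1]), (1, 3, 2))
loss(p) = sum(abs2, predict(y0, p, st) .- y)

grad = Zygote.gradient(loss, p)[1]
```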