# Problems with reverse-mode automatic differentiation (e.g. Zygote) with NeuralSDEs

Hello, I am having difficulty computing gradients of a neural SDE (either custom or from DiffEqFlux) using `Zygote` or `ReverseDiff`. I think it has to do with the shape of the noise during the backward pass, but I am not sure.

Here is a MWE:

```julia
using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays

rng = Random.default_rng()
tspan = (0.0, 1.0) .|> Float32
n_timepoints = 10
n_samples = 32
ts = range(tspan[1], tspan[2], length=n_timepoints)
y = rand32(2, n_timepoints, n_samples)  # (n_dim, n_timepoints, n_samples)
y0 = y[:, 1, :]

model = NeuralDSDE(
    Dense(2, 2, tanh),  # drift network
    Dense(2, 2, tanh),  # diffusion network
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
)

p, st = Lux.setup(rng, model)
p = ComponentArray{Float32}(p)

function predict(y0, p, st)
    permutedims(Array(model(y0, p, st)[1]), (1, 3, 2))
end

function loss(p)
    pred = predict(y0, p, st)
    return sum(abs2, pred .- y)
end
```

Now computing the gradients using `ForwardDiff` works, but using `Zygote` throws an error:

`Zygote.gradient(loss, p)`

```
DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 64 and 2
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:529 [inlined]
  [2] _bcs
    @ ./broadcast.jl:523 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:517 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:512 [inlined]
  [5] instantiate
    @ ./broadcast.jl:294 [inlined]
  [6] materialize
    @ ./broadcast.jl:873 [inlined]
  [7] (::NNlib.var"#145#148"{SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Matrix{Float32}})()
    @ NNlib ~/.julia/packages/NNlib/O0zGY/src/activations.jl:903
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:275 [inlined]
```

The problem only goes away if I change the sensealg to `TrackerAdjoint()` and the types to `Float64`.
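For reference, this is roughly how I am applying that workaround (a sketch, assuming the keyword arguments passed to `NeuralDSDE` are forwarded to `solve`, so `sensealg` can be set in the constructor; `TrackerAdjoint` comes from SciMLSensitivity):

```julia
using SciMLSensitivity  # provides TrackerAdjoint

# Same model as above, but in Float64 and with the discrete Tracker-based adjoint.
tspan64 = (0.0, 1.0)
ts64 = range(tspan64[1], tspan64[2], length = 10)

model = NeuralDSDE(
    Dense(2, 2, tanh),  # drift network
    Dense(2, 2, tanh),  # diffusion network
    tspan64,
    EulerHeun(),
    dt = 0.1,
    saveat = ts64,
    sensealg = TrackerAdjoint(),  # discrete adjoint instead of the default
)

p, st = Lux.setup(rng, model)
p = ComponentArray(p)  # Float64 by default, matching tspan64
```

With this setup (and `y`, `y0` converted to `Float64` accordingly), `Zygote.gradient(loss, p)` runs without the broadcast error for me.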

---

Open an issue on SciMLSensitivity.jl. There is some ongoing work on improving the SDE adjoints, though discrete adjoints (like `TrackerAdjoint`) do have some advantages on this kind of SDE.