Hello, I am having difficulty computing gradients of a neural SDE (either custom or from DiffEqFlux) using Zygote or ReverseDiff. I think it has to do with the shape of the noise during the backward pass, but I am not sure.
Here is an MWE:
using Zygote, Random, DiffEqFlux, DifferentialEquations, Lux, ComponentArrays
rng = Random.default_rng()
tspan = (0.0, 1.0) .|> Float32
n_timepoints = 10
n_samples = 32
ts = range(tspan[1], tspan[2], length=n_timepoints)
y = rand32(2, n_timepoints, n_samples) # (n_dim, n_timepoints, n_samples)
y0 = y[:, 1, :]
model = NeuralDSDE(
    Dense(2, 2, tanh),  # drift network
    Dense(2, 2, tanh),  # diffusion network (NeuralDSDE uses diagonal noise)
    tspan,
    EulerHeun(),
    dt = 0.1,
    saveat = ts,
    sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())
)
p, st = Lux.setup(rng, model)
p = ComponentArray{Float32}(p)  # flat Float32 parameter vector for AD
function predict(y0, p, st)
    # solve the SDE, then permute (n_dim, n_samples, n_timepoints) -> (n_dim, n_timepoints, n_samples)
    permutedims(Array(model(y0, p, st)[1]), (1, 3, 2))
end
function loss(p)
    pred = predict(y0, p, st)
    return sum(abs2, pred .- y)
end
Now, computing the gradients using ForwardDiff works, e.g. with the obvious forward-mode call:
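using ForwardDiff
ForwardDiff.gradient(loss, p)  # runs without error

But using Zygote throws an error: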
Zygote.gradient(loss, p)
DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 64 and 2
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:529 [inlined]
  [2] _bcs
    @ ./broadcast.jl:523 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:517 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:512 [inlined]
  [5] instantiate
    @ ./broadcast.jl:294 [inlined]
  [6] materialize
    @ ./broadcast.jl:873 [inlined]
  [7] (::NNlib.var"#145#148"{SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Matrix{Float32}})()
    @ NNlib ~/.julia/packages/NNlib/O0zGY/src/activations.jl:903
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:275 [inlined]
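Note that the mismatched lengths are 64 and 2, and 64 = 2 × 32 = n_dim × n_samples, i.e. the batched state flattened into a single vector, which is why I suspect the noise/state shape in the reverse pass. As a quick (not conclusive) sanity check, a standalone Dense layer of the same shape as the diffusion network handles the batched state fine in the forward direction:

# Quick shape check: a Dense layer like the diffusion network maps the
# batched (n_dim, n_samples) state to the same shape, no flattening.
g = Dense(2, 2, tanh)
pg, stg = Lux.setup(rng, g)
gout, _ = g(y0, pg, stg)
size(gout)  # (2, 32)

So the flattening to length 64 seems to happen only inside the adjoint.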