Dimension mismatch in Flux with NeuralODE

In Flux I'm trying to chain together a NeuralODE and some other layers. When I evaluate the network and the loss function, it works fine. But when I try to calculate the gradients of the loss with respect to the parameters, there is a dimension mismatch coming from the NeuralODE. I'm not sure why the output dimension changes, but it might be due to the solve. I've tried a number of things to fix this, but to no avail.
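
For context, the reshape in the chain is there because the NeuralODE returns an ODESolution rather than a plain matrix. With saveat=t_end, save_start=false and save_everystep=false I expect a single saved time point, so Array(...) should have a trailing singleton time dimension that the reshape drops. Here is a minimal sketch of how I expect the shapes to flow (same variable names as in the full code below; the sizes in the comments are my expectation from the forward pass, not something I have checked inside the gradient call):

z = encoder(y0)                           # (d_l, n_batch) = (4, 64)
sol = n_ode(z)                            # ODESolution with (I assume) a single saved time point
x = reshape(Array(sol), (d_l, n_batch))   # Array(sol) should be (4, 64, 1), so this gives (4, 64)
y_hat = decoder(x)                        # (d_i, n_batch) = (256, 64)

This all works when I just evaluate the loss; the mismatch only appears when Zygote runs the pullback.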

Below is a simplified piece of code that reproduces the error, along with the output/error message I get. Does anyone have any advice on how to solve this problem?

using DiffEqFlux: NeuralODE
using Flux: Chain, Dense, gradient, params
using NNlib: leakyrelu
using OrdinaryDiffEq: Tsit5
using Statistics: mean

function loss_func(y0, yt)
    return mean(abs.(y0 - yt))
end

n_batch = 64
d_l = 4
d_i = 256

t_start = 0.0
t_end = 1.0
t_span = (t_start, t_end)

# MLP for latent space NeuralODE
l_mlp = Chain(Dense(d_l => d_l, leakyrelu), Dense(d_l => d_l))

n_ode = NeuralODE(l_mlp, t_span, Tsit5(), saveat=t_end, save_start=false, save_everystep=false)

encoder = Chain(Dense(d_i => d_l, leakyrelu))
decoder = Chain(Dense(d_l => d_i, leakyrelu))

ae = Chain(encoder, n_ode, (x) -> reshape(Array(x), (d_l, n_batch)), decoder) # drop the singleton time dimension from the ODESolution

y0 = rand(d_i, n_batch) # some random data
yt = rand(d_i, n_batch) # some random data

# This line works fine
loss = loss_func(ae(y0), yt)
println("Loss: $(loss)")

# This line gives an error
grad = gradient(params(ae)) do 
    loss_func(ae(y0), yt)
end

Error message:

Loss: 0.49221336310842245
ERROR: DimensionMismatch("new dimensions (4, 64) must be consistent with array size 512")
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#272")(dims::Tuple{Int64, Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape(a::Array{Float64, 3}, dims::Tuple{Int64, Int64})
    @ Base ./reshapedarray.jl:45
  [3] adjoint
    @ ~/.julia/packages/Zygote/IoW2g/src/lib/array.jl:106 [inlined]
  [4] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [5] _pullback
    @ ~/Documents/code/ae_test/error_example.jl:29 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::var"#9#10", args::SciMLBase.ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Nothing, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#133"{NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLBase.SensitivityInterpolation{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Matrix{Float64}}}, DiffEqBase.DEStats})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:53 [inlined]
  [8] _pullback
    @ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:53 [inlined]
  [9] _pullback(::Zygote.Context, ::typeof(Flux._applychain), ::Tuple{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}, NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}, var"#9#10", Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}}, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:51 [inlined]
 [11] _pullback(ctx::Zygote.Context, f::Chain{Tuple{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}, NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}, var"#9#10", Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}}}, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/Documents/code/ae_test/error_example.jl:40 [inlined]
 [13] _pullback(::Zygote.Context, ::var"#11#12")
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [14] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:352
 [15] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:75
 [16] top-level scope
    @ ~/Documents/code/ae_test/error_example.jl:39