In Flux I’m trying to chain together a NeuralODE and some other layers. Evaluating the network and the loss function works fine, but when I try to calculate the gradients of the loss w.r.t. the parameters there is a dimension mismatch coming from the NeuralODE. I’m not sure why the output dimension changes, but it might be due to the solve: the error below reports an array size of 512, which is exactly twice the expected 4 × 64 = 256, as if the solution carries two saved time points during the pullback. I’ve tried a number of things to fix this, but to no avail.
Below is a simplified bit of code that reproduces the error, along with the output/error message I get. Does anyone have any advice on how to solve this problem?
```julia
using DiffEqFlux: NeuralODE
using Flux: Chain, Dense, gradient, params
using NNlib: leakyrelu
using OrdinaryDiffEq: Tsit5
using Statistics: mean

function loss_func(y0, yt)
    return mean(abs.(y0 - yt))
end

n_batch = 64
d_l = 4
d_i = 256

t_start = 0.
t_end = 1.
t_span = (t_start, t_end)

# MLP for latent space NeuralODE
l_mlp = Chain(Dense(d_l => d_l, leakyrelu), Dense(d_l => d_l))
n_ode = NeuralODE(l_mlp, t_span, Tsit5(), saveat=t_end, save_start=false, save_everystep=false)

encoder = Chain(Dense(d_i => d_l, leakyrelu))
decoder = Chain(Dense(d_l => d_i, leakyrelu))
ae = Chain(encoder, n_ode, (x) -> reshape(Array(x), (d_l, n_batch)), decoder)

y0 = rand(d_i, n_batch) # some random data
yt = rand(d_i, n_batch) # some random data

# This line works fine
loss = loss_func(ae(y0), yt)
println("Loss: $(loss)")

# This line gives an error
grad = gradient(params(ae)) do
    loss_func(ae(y0), yt)
end
```
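For reference (before the output below), this is how I’d check the shape of the NeuralODE output outside of the gradient call. My assumption, which I haven’t been able to verify inside the reverse pass, is that with `saveat=t_end`, `save_start=false`, `save_everystep=false` the solution holds a single time point, which would explain why the reshape to `(4, 64)` succeeds in the forward pass:

```julia
# Diagnostic only: inspect the raw NeuralODE output before the reshape layer.
# With saveat=t_end, save_start=false, save_everystep=false I expect a single
# saved time point, i.e. 4 * 64 = 256 elements in total.
sol = n_ode(encoder(y0))
println(size(Array(sol)))   # expecting a single time slice, e.g. (4, 64, 1)
println(length(Array(sol))) # expecting 256; the error reports 512 in the pullback
```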
Output and error message:

```
Loss: 0.49221336310842245
ERROR: DimensionMismatch("new dimensions (4, 64) must be consistent with array size 512")
Stacktrace:
[1] (::Base.var"#throw_dmrsa#272")(dims::Tuple{Int64, Int64}, len::Int64)
@ Base ./reshapedarray.jl:41
[2] reshape(a::Array{Float64, 3}, dims::Tuple{Int64, Int64})
@ Base ./reshapedarray.jl:45
[3] adjoint
@ ~/.julia/packages/Zygote/IoW2g/src/lib/array.jl:106 [inlined]
[4] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[5] _pullback
@ ~/Documents/code/ae_test/error_example.jl:29 [inlined]
[6] _pullback(ctx::Zygote.Context, f::var"#9#10", args::SciMLBase.ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Nothing, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#133"{NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLBase.SensitivityInterpolation{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Matrix{Float64}}}, DiffEqBase.DEStats})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[7] macro expansion
@ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:53 [inlined]
[8] _pullback
@ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:53 [inlined]
[9] _pullback(::Zygote.Context, ::typeof(Flux._applychain), ::Tuple{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}, NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}, var"#9#10", Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}}, ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[10] _pullback
@ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:51 [inlined]
[11] _pullback(ctx::Zygote.Context, f::Chain{Tuple{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}, NeuralODE{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Tuple{Float64, Float64}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :save_start, :save_everystep), Tuple{Float64, Bool, Bool}}}}, var"#9#10", Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}}}}, args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[12] _pullback
@ ~/Documents/code/ae_test/error_example.jl:40 [inlined]
[13] _pullback(::Zygote.Context, ::var"#11#12")
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[14] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:352
[15] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:75
[16] top-level scope
@ ~/Documents/code/ae_test/error_example.jl:39
```
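One direction I’ve been considering, though I’m not sure it’s the right fix (hence the question): make the glue layer insensitive to the number of saved time points by indexing the last slice instead of reshaping. A minimal sketch, under the assumption that the extra 256 elements come from a second time slice that only appears under Zygote (`ae_idx` is just a hypothetical variant of `ae` above):

```julia
# Sketch of a possible workaround (assumption, not a confirmed fix): select the
# state at the last saved time point by indexing, so the layer keeps working
# even if the pullback sees (4, 64, 2) instead of (4, 64, 1).
ae_idx = Chain(encoder, n_ode, (x) -> Array(x)[:, :, end], decoder)
grad = gradient(params(ae_idx)) do
    loss_func(ae_idx(y0), yt)
end
```

If the extra slice really is the initial state being saved despite `save_start=false`, indexing with `[:, :, end]` should still pick out the state at `t_end`, but I’d appreciate confirmation on whether that’s what is actually happening here.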