Thanks for your reply. Here are the changes I made. I changed the training data from a 1 × 10,000 array to a 3 × 10,000 array. Its three rows hold the parameters mu and sigma and a value x sampled from the lognormal distribution with those parameters, so each column is one (mu, sigma, x) sample. I figured that I also needed to change the input layer from one node to three nodes.
Here is the full error:
ERROR: DimensionMismatch: A has dimensions (10,3) but B has dimensions (1,1000)
Stacktrace:
[1] gemm_wrapper!(C::Matrix{Float32}, tA::Char, tB::Char, A::Matrix{Float32}, B::Matrix{Float32}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:646
[2] mul!
@ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:161 [inlined]
[3] mul!
@ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *
@ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:148 [inlined]
[5] rrule
@ ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:64 [inlined]
[6] rrule
@ ~/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 [inlined]
[7] chain_rrule
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:218 [inlined]
[8] macro expansion
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
[9] _pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9 [inlined]
[10] _pullback
@ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:172 [inlined]
[11] macro expansion
@ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:53 [inlined]
[12] _pullback
@ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:53 [inlined]
[13] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), ::Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[14] _pullback
@ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:51 [inlined]
[15] pullback(f::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, cx::Zygote.Context{false}, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
[16] pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
[17] ffjord(u::Matrix{Float32}, p::Vector{Float32}, t::Float32, re::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, e::Matrix{Float32}, st::Nothing; regularize::Bool, monte_carlo::Bool)
@ DiffEqFlux ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:237
[18] ffjord_
@ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:282 [inlined]
[19] ODEFunction
@ ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:2096 [inlined]
[20] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{Float32}, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Float32, Vector{Matrix{Float32}}, ODESolution{Float32, 3, Vector{Matrix{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, 
DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Matrix{Float32}}, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, 
Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Matrix{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/perform_step/low_order_rk_perform_step.jl:672
[21] __init(prob::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, 
qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/solve.jl:493
[22] #__solve#566
@ ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/solve.jl:5 [inlined]
[23] #solve_call#21
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:494 [inlined]
[24] solve_up(prob::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{Float32}, p::Vector{Float32}, args::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end), Tuple{Bool, Bool, Bool}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:856
[25] #solve#26
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:823 [inlined]
[26] _concrete_solve_adjoint(::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, true, Val{:central}, Nothing}, ::Matrix{Float32}, ::Vector{Float32}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{Float32}, save_idxs::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/DInxI/src/concrete_solve.jl:279
[27] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/DInxI/src/concrete_solve.jl:231 [inlined]
[28] #_solve_adjoint#50
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1263 [inlined]
[29] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1232 [inlined]
[30] #rrule#48
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1216 [inlined]
[31] rrule
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1212 [inlined]
[32] rrule
@ ~/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 [inlined]
[33] chain_rrule
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:218 [inlined]
[34] macro expansion
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
[35] _pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9 [inlined]
[36] _apply
@ ./boot.jl:816 [inlined]
[37] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[38] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[39] _pullback
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:823 [inlined]
[40] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#26", ::InterpolatingAdjoint{0, true, Val{:central}, Nothing}, ::Nothing, ::Nothing, ::Val{true}, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[41] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[42] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[43] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[44] _pullback
@ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:813 [inlined]
[45] _pullback(::Zygote.Context{false}, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:sensealg,), Tuple{InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}, ::typeof(solve), ::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[46] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[47] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[48] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[49] _pullback
@ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:297 [inlined]
[50] _pullback(::Zygote.Context{false}, ::DiffEqFlux.var"##forward_ffjord#67", ::Bool, ::Bool, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[51] _pullback (repeats 2 times)
@ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:278 [inlined]
[52] _pullback(::Zygote.Context{false}, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[53] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[54] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[55] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[56] _pullback
@ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:276 [inlined]
--- the last 5 lines are repeated 1 more time ---
[62] _pullback(::Zygote.Context{false}, ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[63] _pullback
@ ~/.julia/dev/normalizing_flow/sandbox.jl:35 [inlined]
[64] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[65] _pullback
@ ~/.julia/dev/normalizing_flow/sandbox.jl:46 [inlined]
[66] _pullback(::Zygote.Context{false}, ::var"#73#74", ::Vector{Float32}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[67] _apply
@ ./boot.jl:816 [inlined]
[68] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[69] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[70] _pullback
@ ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3580 [inlined]
[71] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Vector{Float32}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[72] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[73] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[74] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[75] _pullback
@ ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:30 [inlined]
[76] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[77] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[78] adjoint
@ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
[79] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[80] _pullback
@ ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:34 [inlined]
[81] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#158#167"{Tuple{}, Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[82] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
[83] pullback
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
[84] gradient(f::Function, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:96
[85] (::Optimization.var"#157#166"{Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Vector{Float32}, ::Vector{Float32})
@ Optimization ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:32
[86] macro expansion
@ ~/.julia/packages/OptimizationFlux/zHcx5/src/OptimizationFlux.jl:32 [inlined]
[87] macro expansion
@ ~/.julia/packages/Optimization/o00ZS/src/utils.jl:37 [inlined]
[88] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Adam, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ OptimizationFlux ~/.julia/packages/OptimizationFlux/zHcx5/src/OptimizationFlux.jl:31
[89] #solve#540
@ ~/.julia/packages/SciMLBase/QqtZA/src/solve.jl:84 [inlined]
My guess is that the training data, the architecture, or both are not specified correctly. Reading the error itself, the first `Dense` layer's weight matrix (A) is 10×3, so it expects inputs with 3 rows. However, the matrix it actually receives (B) is 1×1000: a single row with 1,000 columns. Another possibility is that the optimizer needs to be changed to accommodate the new architecture, but honestly I don't have much of an idea.