# Using ODE solution in NEURAL ODE

using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots,OptimizationOptimisers
# Shared RNG; `Lux.setup` draws the initial network weights from it.
rng = Random.default_rng()

## Setup ODE to optimize
# Initial state `[x, y]` and integration window, both Float32 so the reference
# solve runs in single precision.
u0 = Float32[1, 1]
tspan = (0.0f0, 4.0f0)
datasize = 70  # number of saved time points

# `datasize` equally spaced save times from tspan[1] to tspan[2], endpoints included.
tsteps = range(first(tspan), last(tspan); length = datasize)
# Lotka–Volterra coefficients (α, β, δ, γ): taken from `p` when it is a vector or tuple,
# otherwise (e.g. an `ODEProblem` built without parameters) the script's defaults.
# A tuple literal keeps the fallback allocation-free inside the RHS.
_lv_params(p::Union{AbstractVector,Tuple}) = p
_lv_params(p) = (1.5f0, 1.0f0, 3.0f0, 1.0f0)

"""
    lotka_volterra(du, u, p, t)

In-place Lotka–Volterra right-hand side for the state `u = [x, y]`:

    du[1] =  α x - β x y
    du[2] = -δ y + γ x y

`p` supplies `(α, β, δ, γ)` when it is a 4-element vector or tuple; any other value
falls back to `(1.5, 1.0, 3.0, 1.0)` (Float32), which is what this script always used.
`t` is unused (the system is autonomous). Writes into `du` and returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = _lv_params(p)
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# Reference Lotka–Volterra problem on `tspan`, built without an explicit parameter
# argument (the RHS supplies its own coefficients).
prob = ODEProblem(lotka_volterra,u0,tspan)

# Verify ODE solution: solve with Tsit5, keep only the states saved at `tsteps`, and
# convert to a plain array (one column per saved time point).
ode_data =Array(solve(prob, Tsit5(), saveat = tsteps))

# Neural network for the right-hand side of a neural ODE: an elementwise `sin`
# feature map followed by Dense 2 → 50 (tanh) and Dense 50 → 2, so the output has
# the same length as the 2-element state.

dudt2 = Lux.Chain(x -> sin.(x),
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))
# Initial parameters `p` and layer state `st`, drawn with `rng`.
p, st = Lux.setup(rng, dudt2)
# Neural ODE over `tspan`, integrated with Tsit5 and saved at `tsteps`.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

Instead of `sin.(x)`, I would like to use the solution of the ODE.

You are using ‘’’ (typographic quotes) to delimit your code blocks, but you need ``` (three backticks), which are different characters.

# Full reference solution (no `saveat`), evaluated below through its interpolant,
# `sol(τ)`, as a function of time.
sol = solve(prob, Tsit5())

# NOTE(review): the first layer passes the network input `x` (the 2-element state) to
# `sol` as if it were a vector of *time points*. For a vector argument `sol(x)` returns
# one state per entry, i.e. a matrix, so this chain ends up returning a matrix rather
# than a 2-element vector. As an ODE right-hand side that is what leads to
# `MethodError: no method matching Vector{Float32}(::Matrix{Float32})`. State values
# outside `tspan` would also be outside the solution's time range (extrapolation or an
# error — confirm). Decide what this layer should compute before training with it.
dudt2 = Lux.Chain(x -> sol(x),
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))

using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, DifferentialEquations, Plots
using DiffEqFlux: group_ranges

using Random
# Shared RNG; `Lux.setup` draws the initial network weights from it.
rng = Random.default_rng()

# Define initial conditions and time steps
tspan = (0.0f0, 4.0f0)   # Float32 integration window
datasize = 70            # number of saved time points
tsteps = range(first(tspan), last(tspan); length = datasize)
t = tsteps               # alias of the save times
u0 = Float32[1, 1]       # initial state [x, y]
# Lotka–Volterra coefficients (α, β, δ, γ): taken from `p` when it is a vector or tuple,
# otherwise (e.g. an `ODEProblem` built without parameters) the script's defaults.
# A tuple literal keeps the fallback allocation-free inside the RHS.
_lv_params(p::Union{AbstractVector,Tuple}) = p
_lv_params(p) = (1.5f0, 1.0f0, 3.0f0, 1.0f0)

"""
    lotka_volterra(du, u, p, t)

In-place Lotka–Volterra right-hand side for the state `u = [x, y]`:

    du[1] =  α x - β x y
    du[2] = -δ y + γ x y

`p` supplies `(α, β, δ, γ)` when it is a 4-element vector or tuple; any other value
falls back to `(1.5, 1.0, 3.0, 1.0)` (Float32), which is what this script always used.
`t` is unused (the system is autonomous). Writes into `du` and returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = _lv_params(p)
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# Reference problem (no explicit parameter argument; the RHS supplies its coefficients).
prob = ODEProblem(lotka_volterra,u0,tspan)

# Verify ODE solution: training data = states saved at `tsteps`, as a plain array with
# one column per time point.
ode_data =Array(solve(prob, Tsit5(), saveat = tsteps))
# Full solution (no `saveat`), kept so it can be evaluated as a function of time via
# its interpolant, `sol(τ)`.
sol=solve(prob, Tsit5())

# Animation that `callback` adds a frame to on each plotted iteration; written to disk
# with `gif` after training.
anim = Plots.Animation()

# Define the Neural Network
# NOTE(review): the first layer hands the network input `x` (the 2-element state `u`)
# to `sol` as if it were a vector of *time points*. For a vector argument `sol(x)`
# returns one state per entry, i.e. a matrix, so the chain's output is a matrix
# instead of a 2-element vector. Used as the RHS of `prob_node`, that is what raises
# `MethodError: no method matching Vector{Float32}(::Matrix{Float32})` when the solver
# initialises. State values outside `tspan` would also fall outside the solution's
# time range (extrapolation or an error — confirm). Decide what this layer is meant
# to compute before training.
nn = Lux.Chain(x -> sol(x),
Lux.Dense(2, 84, tanh),
#Lux.Dense(84, 44, swish),
#Lux.Dense(44, 22, swish),
#Lux.Dense(22, 12, swish),
Lux.Dense(84,2))
# Initial network parameters and Lux layer state.
p_init, st = Lux.setup(rng, nn)

# NOTE(review): `neuralode` is not used by the multiple-shooting code below, which
# works with `prob_node` instead.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)
# Out-of-place ODE whose RHS is the network. Lux models return `(output, state)`, so
# `[1]` keeps the output; `st` is captured from global scope and the parameters are
# wrapped in a ComponentArray.
prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, ComponentArray(p_init))

"""
    plot_multiple_shoot(plt, preds, group_size; npoints = datasize, ts = tsteps)

Overlay the prediction of each multiple-shooting segment on the existing plot `plt`.

`preds[i]` is the prediction for the `i`-th index range returned by
`group_ranges(npoints, group_size)`. Only its first row (first state component) is
drawn, against `ts[range]`, with circle markers. `npoints` and `ts` default to the
script-level `datasize` and `tsteps`.

Returns `plt`.
"""
function plot_multiple_shoot(plt, preds, group_size; npoints = datasize, ts = tsteps)
    for (i, rg) in enumerate(group_ranges(npoints, group_size))
        plot!(plt, ts[rg], preds[i][1, :], markershape = :circle)
    end
    return plt
end

# Animate training, cannot make animation on CI server
# anim = Plots.Animation()
# Number of callback invocations so far; incremented by `callback`.
iter = 0

"""
    callback(p, l, preds; doplot = true, plot_every = 1)

Optimization callback. Displays the loss `l` and counts iterations in the global `iter`.
When `doplot` is true and the iteration count is a multiple of `plot_every`, it:

- plots the first state component of `ode_data` against the per-segment predictions
  `preds`;
- adds that figure as a frame of the global `anim`;
- displays it.

Returns `false` so the optimizer continues.

Throws `ArgumentError` if `plot_every` is not positive.
"""
callback = function (p, l, preds; doplot = true, plot_every = 1)
    plot_every > 0 || throw(ArgumentError("plot_every must be positive, got $plot_every"))
    display(l)
    global iter
    iter += 1
    if doplot && iter % plot_every == 0
        # plot the original data (first state component)
        plt = scatter(tsteps, ode_data[1, :], label = "Data")

        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)

        frame(anim, plt)
        display(plot(plt))
    end
    return false
end

# Define parameters for Multiple Shooting
# Points per shooting segment (used by `group_ranges` / `multiple_shoot`) and the
# weight passed to `multiple_shoot` as `continuity_term`.
group_size, continuity_term = 3, 100

"""
    loss_function(data, pred)

Sum of squared element-wise differences between `data` and `pred`, which must have
the same shape.
"""
loss_function(data, pred) = sum(abs2, data - pred)

"""
    loss_multiple_shooting(p)

Multiple-shooting loss of the network parameters `p` against `ode_data` at `tsteps`.
It integrates `prob_node` with Tsit5 over segments of `group_size` points, scores them
with `loss_function`, and passes `continuity_term` through to `multiple_shoot`.

Returns whatever `multiple_shoot` returns. That is presumably `(loss, predictions)`,
which matches how `callback` receives `l` and `preds`; confirm against the DiffEqFlux
version in use.
"""
loss_multiple_shooting(p) = multiple_shoot(p, ode_data, tsteps, prob_node, loss_function,
    Tsit5(), group_size; continuity_term = continuity_term)

# AD backend used to differentiate the loss. `adtype` must exist before `optf` is
# built; previously it was referenced here without being defined anywhere.
adtype = Optimization.AutoZygote()
# The problem-parameter argument (second) is not used by the loss.
optf = Optimization.OptimizationFunction((x, _) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_init))
# `PolyOpt` comes from OptimizationPolyalgorithms; `callback` logs and plots progress.
res_ms = Optimization.solve(optprob, PolyOpt(), callback = callback)

# Write the frames collected by `callback` to a GIF.
gif(anim, "multiple_shooting.gif", fps=15)

# Parameters returned by the optimizer.
optimized_params= res_ms.u
print("the optimized params are: ",optimized_params)
# Start the extrapolation check from the last training data point (both state
# components of the final column of `ode_data`).
u1 = Float32[ode_data[1,end] ;  ode_data[2,end]]
# Check window: from the end of training to 10 time units later, `datasize` points.
tspan1 = (tspan[2], tspan[2]+10.0f0)
tsteps1=range(tspan1[1], tspan1[2], length = datasize)

# Ground truth on the check window, from the true Lotka–Volterra RHS.
true_odePRED = ODEProblem(lotka_volterra,u1,tspan1)
ode_data_pred = Array(solve(true_odePRED, Tsit5(), saveat = tsteps1))
# The same network `nn`, wrapped as a NeuralODE over the check window.
prob_neuralode2 = NeuralODE(nn, tspan1, Tsit5(), saveat = tsteps1)

"""
    predict_neuralode2(p)

Integrate `prob_neuralode2` from `u1` with network parameters `p` and the global Lux
state `st`, and return the saved trajectory as a plain `Array`.
"""
predict_neuralode2(p) = Array(first(prob_neuralode2(u1, p, st)))
# Trajectory predicted by the trained network on the check window.
prediction=predict_neuralode2(optimized_params)

# Compare the first state component: training data, true continuation, prediction.
plot(tsteps, ode_data[1,:], label = "data")
plot!(tsteps1, ode_data_pred[1,:], label = "data to check prediction")
scatter!(tsteps1, prediction[1,:], label = "prediction to check prediction")


ERROR: MethodError: no method matching Vector{Float32}(::Matrix{Float32})

What’s the stacktrace you get and what’s it pointing to?

  [1] convert(#unused#::Type{Vector{Float32}}, a::Matrix{Float32})
@ Base .\array.jl:613
[2] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Vector{Float32}, Nothing, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32, Float32, Float32, Float32, Vector{Vector{Float32}}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Vector{Float32}}, 
Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, Vector{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, f::Symbol, v::Matrix{Float32})
@ Base .\Base.jl:38
[3] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Vector{Float32}, Nothing, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32, Float32, Float32, Float32, Vector{Vector{Float32}}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Vector{Float32}}, 
Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, Vector{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
@ OrdinaryDiffEq C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\perform_step\low_order_rk_perform_step.jl:672
[4] __init(prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::StepRangeLen{Float32, Float64, Float64, Int64}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, 
progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ OrdinaryDiffEq C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:493
[5] __init (repeats 5 times)
@ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:10 [inlined]
[6] #__solve#561
@ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:5 [inlined]
[7] __solve
@ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:1 [inlined]
[8] #solve_call#26
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:473 [inlined]
[9] solve_call
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:443 [inlined]
[10] #solve_up#32
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:835 [inlined]
[11] solve_up
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:808 [inlined]
[12] #solve#31
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:802 [inlined]
[13] solve
@ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:792 [inlined]
[14] (::DiffEqFlux.var"#198#200"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Matrix{Float32}, StepRangeLen{Float32, Float64, Float64, Int64}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}})(rg::UnitRange{Int64})
@ DiffEqFlux .\none:0
[15] iterate
@ .\generator.jl:47 [inlined]
[16] collect(itr::Base.Generator{Vector{UnitRange{Int64}}, DiffEqFlux.var"#198#200"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Matrix{Float32}, StepRangeLen{Float32, Float64, Float64, Int64}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}})
@ Base .\array.jl:782
[17] multiple_shoot(p::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ode_data::Matrix{Float32}, tsteps::StepRangeLen{Float32, Float64, Float64, Int64}, prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, loss_function::typeof(loss_function), continuity_loss::typeof(DiffEqFlux._default_continuity_loss), solver::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, group_size::Int64; continuity_term::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqFlux C:\Users\marco\.julia\packages\DiffEqFlux\Ckhh7\src\multiple_shooting.jl:60
[18] #multiple_shoot#202
@ C:\Users\marco\.julia\packages\DiffEqFlux\Ckhh7\src\multiple_shooting.jl:110 [inlined]
[19] loss_multiple_shooting(p::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}})
@ Main c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:82
[20] #86
@ c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:87 [inlined]
[21] OptimizationFunction
@ C:\Users\marco\.julia\packages\SciMLBase\QqtZA\src\scimlfunctions.jl:3580 [inlined]
[22] #2
@ C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:14 [inlined]
[23] __solve(::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#86#87", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::PolyOpt; maxiters::Nothing, kwargs::Base.Pairs{Symbol, var"#83#85", Tuple{Symbol}, NamedTuple{(:callback,), Tuple{var"#83#85"}}})
@ OptimizationPolyalgorithms C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:15
[24] __solve
@ C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:9 [inlined]
[25] #solve#540
@ C:\Users\marco\.julia\packages\SciMLBase\QqtZA\src\solve.jl:84 [inlined]
[26] top-level scope
@ c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:89

If `x` is a vector, then `sol(x)` returns a matrix, so the `nn` returns a matrix. Is that what you wanted?