Using an ODE solution in a Neural ODE

using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots,OptimizationOptimisers
# Random number generator; used further down to initialise the network.
rng = Random.default_rng()

# --- Setup of the ODE to optimize ---
# Number of samples, integration window and initial state (all Float32).
datasize = 70
tspan = (0.0f0, 4.0f0)
u0 = Float32[1.0, 1.0]

# `datasize` equally spaced sample times spanning `tspan`, endpoints included.
tsteps = range(tspan[1], tspan[2]; length = datasize)
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra equations.

Reads the state `u = (x, y)` and writes its time derivative into `du`:

    du[1] =  α*x - β*x*y
    du[2] = -δ*y + γ*x*y

The coefficients are fixed inside the function to `α = 1.5`, `β = 1.0`,
`δ = 3.0`, `γ = 1.0` (as `Float32`). The arguments `p` and `t` are accepted so
the function has the `(du, u, p, t)` signature but are not used.

Returns the value computed for `du[2]`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    # Fixed coefficients as a tuple of literals: unlike building a
    # `Float32[...]` vector, this allocates nothing per call and does not
    # overwrite the `p` argument.
    α, β, δ, γ = (1.5f0, 1.0f0, 3.0f0, 1.0f0)
    dx = α * x - β * x * y
    dy = -δ * y + γ * x * y
    du[1] = dx
    du[2] = dy
    return dy
end

# Reference problem: `lotka_volterra` started at `u0` and integrated over `tspan`.
prob = ODEProblem(lotka_volterra, u0, tspan)

# Solve it with Tsit5, saving at `tsteps`, and convert the result to a plain
# array; this is the data the network is fitted against.
ode_data = Array(solve(prob, Tsit5(); saveat = tsteps))

#add random 

# Build a neural network that sets the cost as the difference from the
# generated data and 1



# Network used as the Neural ODE right-hand side: an element-wise `sin`
# of the input followed by two dense layers (2 -> 50 with tanh, 50 -> 2).
dudt2 = Lux.Chain(x -> sin.(x),
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))

# Initial parameters `p` and layer states `st`, generated with `rng`.
p, st = Lux.setup(rng, dudt2)

# Neural ODE over `tspan`, solved with Tsit5 and saved at `tsteps`.
# (A stray Markdown code fence that was fused onto this statement has been
# removed so the line parses as Julia.)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

Instead of `sin.(x)`, I would like to use the solution of the ODE.

You are using ‘’’ to delimit your code blocks, but you need ```, which are different ticks.

# Solve the reference problem without `saveat`; the resulting `sol` is called
# like a function in the first layer below.
sol = solve(prob, Tsit5())

# NOTE(review): the first layer calls `sol(x)` on the network input `x`. As the
# reply at the end of this thread points out, when `x` is a vector `sol(x)`
# returns a matrix, so the whole network returns a matrix instead of a
# vector. Confirm what input `sol` is meant to be evaluated at.
dudt2 = Lux.Chain(x -> sol(x),
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, DifferentialEquations, Plots
using DiffEqFlux: group_ranges

using Random
# Random number generator; used further down to initialise the network.
rng = Random.default_rng()

# Initial condition, integration window and sampling grid (all Float32).
u0 = Float32[1.0, 1.0]
datasize = 70
tspan = (0.0f0, 4.0f0)

# `datasize` equally spaced sample times spanning `tspan`, endpoints included.
tsteps = range(tspan[1], tspan[2]; length = datasize)

# Short alias for the sample times.
t = tsteps
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra equations.

Reads the state `u = (x, y)` and writes its time derivative into `du`:

    du[1] =  α*x - β*x*y
    du[2] = -δ*y + γ*x*y

The coefficients are fixed inside the function to `α = 1.5`, `β = 1.0`,
`δ = 3.0`, `γ = 1.0` (as `Float32`). The arguments `p` and `t` are accepted so
the function has the `(du, u, p, t)` signature but are not used.

Returns the value computed for `du[2]`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    # Fixed coefficients as a tuple of literals: unlike building a
    # `Float32[...]` vector, this allocates nothing per call and does not
    # overwrite the `p` argument.
    α, β, δ, γ = (1.5f0, 1.0f0, 3.0f0, 1.0f0)
    dx = α * x - β * x * y
    dy = -δ * y + γ * x * y
    du[1] = dx
    du[2] = dy
    return dy
end

# Reference problem: `lotka_volterra` started at `u0` and integrated over `tspan`.
prob = ODEProblem(lotka_volterra, u0, tspan)

# Reference trajectory saved at `tsteps` and converted to a plain array; this
# is the data used for training and plotting.
ode_data = Array(solve(prob, Tsit5(); saveat = tsteps))

# Second solve of the same problem without `saveat`. The solution object is
# kept so it can be called as a function (see `sol(x)` in the network below)
# instead of going through the `ode_data` array.
sol = solve(prob, Tsit5())

# Animation object; frames are appended to it from the training callback.
anim = Plots.Animation()

# Define the Neural Network
# NOTE(review): the first layer calls `sol(x)` on the network input. In
# `prob_node` below that input is the state `u` (a vector), and, as noted in
# the reply at the end of this thread, `sol(x)` with a vector `x` returns a
# matrix. The network output is then a matrix, which matches the reported
# `MethodError: no method matching Vector{Float32}(::Matrix{Float32})`.
# Confirm what `sol` is supposed to be evaluated at before training.
nn = Lux.Chain(x -> sol(x),
                Lux.Dense(2, 84, tanh),
                # Optional extra hidden layers, currently disabled:
                #Lux.Dense(84, 44, swish),
                #Lux.Dense(44, 22, swish),
                #Lux.Dense(22, 12, swish),
                Lux.Dense(84,2))
# Initial parameters `p_init` and layer states `st`, generated with `rng`.
p_init, st = Lux.setup(rng, nn)

# Neural ODE wrapper around `nn` over `tspan`, saved at `tsteps`.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)
# Plain ODEProblem whose right-hand side is the network output (first element
# of the `nn(u, p, st)` result); this is the problem handed to `multiple_shoot`.
prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, ComponentArray(p_init))


"""
    plot_multiple_shoot(plt, preds, group_size)

Add one series per group of predictions to the plot `plt`.

The index ranges of the groups come from `group_ranges(datasize, group_size)`.
For the `i`-th range `rg`, the first row of `preds[i]` is plotted against
`tsteps[rg]` with circle markers.

# Arguments
- `plt`: plot that is modified in place through `plot!`.
- `preds`: collection indexed by group; each `preds[i]` is indexed as
  `preds[i][1, :]`.
- `group_size`: group size forwarded to `group_ranges`.

Uses the globals `datasize` and `tsteps`. Returns `nothing`.
"""
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)

    for (i, rg) in enumerate(ranges)
        plot!(plt, tsteps[rg], preds[i][1, :], markershape = :circle)
    end
    return nothing
end

# Training callback and its call counter.
#
# On every call the callback displays the loss `l` and increments the global
# `iter`. When `doplot` is true it also redraws the first row of `ode_data`
# together with the per-group predictions `preds`, appends a frame to the
# global `anim` (created earlier in the script) and displays the plot.
# `p` is not used. The callback always returns `false`.
iter = 0
callback = function (p, l, preds; doplot = true)
    display(l)
    global iter
    iter = iter + 1

    # `iter % 1 == 0` holds for every integer, so this plots on every call
    # whenever `doplot` is set.
    wants_plot = doplot && iter % 1 == 0
    if !wants_plot
        return false
    end

    # Original data first ...
    plt = scatter(tsteps, ode_data[1, :]; label = "Data")

    # ... then the prediction of each individual shooting group on top.
    plot_multiple_shoot(plt, preds, group_size)

    frame(anim)
    display(plot(plt))
    return false
end

# Settings for multiple shooting:
#   continuity_term - forwarded to `multiple_shoot` as the `continuity_term` keyword
#   group_size      - forwarded to `multiple_shoot` and `plot_multiple_shoot`
continuity_term = 100
group_size = 3

"""
    loss_function(data, pred)

Return the sum of squared entries of `data - pred`.
"""
loss_function(data, pred) = sum(abs2, data - pred)

"""
    loss_multiple_shooting(p)

Evaluate the multiple-shooting objective for the parameters `p`.

Forwards `p` to `multiple_shoot` together with the globals `ode_data`,
`tsteps`, `prob_node`, `loss_function`, `group_size` and `continuity_term`,
using `Tsit5()` as the solver, and returns its result unchanged.
"""
function loss_multiple_shooting(p)
    solver = Tsit5()
    return multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, solver, group_size;
        continuity_term = continuity_term,
    )
end

# Differentiate the objective with Zygote.
adtype = Optimization.AutoZygote()

# Objective in the `(x, p)` form expected by OptimizationFunction; the second
# argument is not used.
optf = Optimization.OptimizationFunction(
    (θ, _) -> loss_multiple_shooting(θ),
    adtype,
)

# Start from the initial network parameters and train with PolyOpt, calling
# `callback` during the run.
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_init))
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)


# Write the frames collected during training to a GIF.
gif(anim, "multiple_shooting.gif"; fps = 15)

# Trained parameters.
optimized_params = res_ms.u
print("the optimized params are: ", optimized_params)

# Extrapolation check: start from the last reference sample and continue for
# another 10 time units, sampled at `datasize` points.
u1 = Float32[ode_data[1, end], ode_data[2, end]]
tspan1 = (tspan[2], tspan[2] + 10.0f0)
tsteps1 = range(tspan1[1], tspan1[2]; length = datasize)

# Lotka–Volterra reference solution on the new window ...
true_odePRED = ODEProblem(lotka_volterra, u1, tspan1)
ode_data_pred = Array(solve(true_odePRED, Tsit5(); saveat = tsteps1))

# ... and a Neural ODE built from the same network `nn` on that window.
prob_neuralode2 = NeuralODE(nn, tspan1, Tsit5(), saveat = tsteps1)

"""
    predict_neuralode2(p)

Call `prob_neuralode2` with the initial state `u1`, parameters `p` and the
global states `st`, and return the first element of its result converted
with `Array`.
"""
function predict_neuralode2(p)
    result = prob_neuralode2(u1, p, st)
    return Array(result[1])
end
# Prediction of the trained network on the extrapolation window.
prediction = predict_neuralode2(optimized_params)

# First state component: training data, reference continuation and the
# network prediction on the continuation window.
plot(tsteps, ode_data[1, :]; label = "data")
plot!(tsteps1, ode_data_pred[1, :]; label = "data to check prediction")
scatter!(tsteps1, prediction[1, :]; label = "prediction to check prediction")

ERROR: MethodError: no method matching Vector{Float32}(::Matrix{Float32})

What’s the stacktrace you get and what’s it pointing to?

  [1] convert(#unused#::Type{Vector{Float32}}, a::Matrix{Float32})
    @ Base .\array.jl:613
  [2] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Vector{Float32}, Nothing, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32, Float32, Float32, Float32, Vector{Vector{Float32}}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Vector{Float32}}, 
Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, Vector{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, f::Symbol, v::Matrix{Float32})
    @ Base .\Base.jl:38
  [3] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Vector{Float32}, Nothing, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32, Float32, Float32, Float32, Vector{Vector{Float32}}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Vector{Float32}}, 
Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, Vector{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
    @ OrdinaryDiffEq C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\perform_step\low_order_rk_perform_step.jl:672
  [4] __init(prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::StepRangeLen{Float32, Float64, Float64, Int64}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, 
progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OrdinaryDiffEq C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:493
  [5] __init (repeats 5 times)
    @ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:10 [inlined]
  [6] #__solve#561
    @ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:5 [inlined]
  [7] __solve
    @ C:\Users\marco\.julia\packages\OrdinaryDiffEq\P7HJO\src\solve.jl:1 [inlined]
  [8] #solve_call#26
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:473 [inlined]
  [9] solve_call
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:443 [inlined]
 [10] #solve_up#32
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:835 [inlined]
 [11] solve_up
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:808 [inlined]
 [12] #solve#31
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:802 [inlined]
 [13] solve
    @ C:\Users\marco\.julia\packages\DiffEqBase\WXn2i\src\solve.jl:792 [inlined]
 [14] (::DiffEqFlux.var"#198#200"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Matrix{Float32}, StepRangeLen{Float32, Float64, Float64, Int64}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}})(rg::UnitRange{Int64})
    @ DiffEqFlux .\none:0
 [15] iterate
    @ .\generator.jl:47 [inlined]
 [16] collect(itr::Base.Generator{Vector{UnitRange{Int64}}, DiffEqFlux.var"#198#200"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, Matrix{Float32}, StepRangeLen{Float32, Float64, Float64, Int64}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}})
    @ Base .\array.jl:782
 [17] multiple_shoot(p::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ode_data::Matrix{Float32}, tsteps::StepRangeLen{Float32, Float64, Float64, Int64}, prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#81#82", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, loss_function::typeof(loss_function), continuity_loss::typeof(DiffEqFlux._default_continuity_loss), solver::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, group_size::Int64; continuity_term::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqFlux C:\Users\marco\.julia\packages\DiffEqFlux\Ckhh7\src\multiple_shooting.jl:60
 [18] #multiple_shoot#202
    @ C:\Users\marco\.julia\packages\DiffEqFlux\Ckhh7\src\multiple_shooting.jl:110 [inlined]
 [19] loss_multiple_shooting(p::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Main c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:82
 [20] #86
    @ c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:87 [inlined]
 [21] OptimizationFunction
    @ C:\Users\marco\.julia\packages\SciMLBase\QqtZA\src\scimlfunctions.jl:3580 [inlined]
 [22] #2
    @ C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:14 [inlined]
 [23] __solve(::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#86#87", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:252, Axis(weight = ViewAxis(1:168, ShapedAxis((84, 2), NamedTuple())), bias = ViewAxis(169:252, ShapedAxis((84, 1), NamedTuple())))), layer_3 = ViewAxis(253:422, Axis(weight = ViewAxis(1:168, ShapedAxis((2, 84), NamedTuple())), bias = ViewAxis(169:170, ShapedAxis((2, 1), NamedTuple())))))}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::PolyOpt; maxiters::Nothing, kwargs::Base.Pairs{Symbol, var"#83#85", Tuple{Symbol}, NamedTuple{(:callback,), Tuple{var"#83#85"}}})
    @ OptimizationPolyalgorithms C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:15
 [24] __solve
    @ C:\Users\marco\.julia\packages\OptimizationPolyalgorithms\hCJel\src\OptimizationPolyalgorithms.jl:9 [inlined]
 [25] #solve#540
    @ C:\Users\marco\.julia\packages\SciMLBase\QqtZA\src\solve.jl:84 [inlined]
 [26] top-level scope
    @ c:\Users\marco\OneDrive\Desktop\Tesi\Julia_files\multipleshooting.jl:89```

If `x` is a vector, then `sol(x)` returns a matrix, so the `nn` returns a matrix. Is that what you wanted?