Trouble with NeuralODE and Metal

Here is my example:

using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
    Metal, SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE

# Device handle used below to move arrays and parameters (`x |> gdev` / `gdev(x)`).
# NOTE(review): presumably resolves to the Metal backend since Metal is loaded — confirm.
gdev = gpu_device()
# RNG handed to Lux.setup when initializing the network parameters
rng = Random.default_rng()
# Generate Data
# Initial state, number of saved time points, and integration interval (Float32 throughout).
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
# `datasize` evenly spaced save points spanning `tspan`
tsteps = range(tspan[1], tspan[2]; length = datasize)
# Coefficient matrix of the reference dynamics; the trailing `'` takes the adjoint of the
# literal, and `Matrix` materializes it as a dense matrix.
true_A = Matrix([-0.1f0 2.0f0; -2.0f0 -0.1f0]')
"""
    trueODEfunc(du, u, A, t)

In-place right-hand side of the reference system: overwrite `du` with `A * (u .^ 3)`
and return `du`. The time argument `t` is accepted but not used.
"""
function trueODEfunc(du, u, A, t)
    cubed = u .^ 3
    du .= A * cubed
    return du
end
# ODE problem for the reference dynamics; `true_A` is passed as the problem parameter
# and therefore arrives as the `A` argument of `trueODEfunc`.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan, true_A)
# Make the data into a GPU-based array if the user has a GPU
# The solution is saved at `tsteps`, converted to a plain Array, and then passed to `gdev`.
ode_data = gdev(Array(solve(prob_trueode, Tsit5(); saveat = tsteps)))

##

# Network defining the right-hand side of the neural ODE: cube the state, then two
# dense layers (2 -> 50 with tanh, 50 -> 2).
#
# The cube is written as `x .* x .* x` instead of `x .^ 3`. With `x .^ 3`, the forward
# pass works, but Zygote's pullback for a broadcast literal power evaluates
# `x .^ (p - 1)` with a runtime integer exponent. For Float32 inputs that goes through
# Base's `^(::Float32, ::Integer)` / `power_by_squaring` path, which computes in
# Float64. Metal has no Float64 support, so compiling the backward broadcast kernel
# fails with "unsupported use of double value". Plain multiplication stays in Float32.
dudt2 = Chain(x -> x .* x .* x, Dense(2, 50, tanh), Dense(50, 2))
# Initial state, moved to the device
u0 = Float32[2.0; 0.0] |> gdev
# Initialize parameters and layer state; flatten the parameters into a ComponentArray
# and move both to the device.
p, st = Lux.setup(rng, dudt2)
p = p |> ComponentArray |> gdev
st = st |> gdev

# Neural ODE over `tspan`, solved with Tsit5 and saved at `tsteps`
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

"""
    predict_neuralode(p)

Run `prob_neuralode` from `u0` with parameters `p` and layer state `st`, take the first
element of what it returns, convert that to an `Array`, and pass the result to `gdev`.
"""
function predict_neuralode(p)
    solution = first(prob_neuralode(u0, p, st))
    host_pred = Array(solution)
    return gdev(host_pred)
end
"""
    loss_neuralode(p)

Return `(loss, pred)`, where `pred = predict_neuralode(p)` and `loss` is the sum of
squared differences between `ode_data` and `pred`.
"""
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    residual = ode_data - prediction
    return sum(abs2, residual), prediction
end

# Differentiate the objective with Zygote
adtype = Optimization.AutoZygote()
# The objective is called as (x, p); only the optimization variable `x` is forwarded to
# the loss, the second argument is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
# Optimization starts from the current network parameters `p`
optprob = Optimization.OptimizationProblem(optf, p)
# Run Adam(0.05) for at most 100 iterations
result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)

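Results (with `x .^ 3` as the first layer): evaluating the objective works, but the gradient-based solve fails while compiling a Metal broadcast kernel during the backward pass: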

julia> optf(p, [])
(309.3814f0, Float32[2.0 2.0186965 … 1.9060048 1.8881334; 0.0 0.09428286 … 2.486669 2.5848868])

julia> result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] power_by_squaring
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:282
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:534
 [2] isone
   @ ./number.jl:62
 [3] power_by_squaring
   @ ./intfuncs.jl:284
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:534
 [2] isone
   @ ./number.jl:62
 [3] power_by_squaring
   @ ./intfuncs.jl:285
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:291
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:298
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:300
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] power_by_squaring
   @ ./intfuncs.jl:285
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:0
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] /
    @ ./float.jl:412
  [2] inv
    @ ./number.jl:255
  [3] ^
    @ ./math.jl:1280
  [4] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [5] _broadcast_getindex
    @ ./broadcast.jl:682
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _broadcast_getindex
    @ ./broadcast.jl:681
  [8] _getindex
    @ ./broadcast.jl:706
  [9] _getindex
    @ ./broadcast.jl:705
 [10] _broadcast_getindex
    @ ./broadcast.jl:681
 [11] getindex
    @ ./broadcast.jl:636
 [12] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1280
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] Float32
    @ ./float.jl:258
  [2] ^
    @ ./math.jl:1280
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [4] _broadcast_getindex
    @ ./broadcast.jl:682
  [5] _getindex
    @ ./broadcast.jl:706
  [6] _broadcast_getindex
    @ ./broadcast.jl:681
  [7] _getindex
    @ ./broadcast.jl:706
  [8] _getindex
    @ ./broadcast.jl:705
  [9] _broadcast_getindex
    @ ./broadcast.jl:681
 [10] getindex
    @ ./broadcast.jl:636
 [11] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1281
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] Float32
    @ ./float.jl:258
  [2] ^
    @ ./math.jl:1281
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [4] _broadcast_getindex
    @ ./broadcast.jl:682
  [5] _getindex
    @ ./broadcast.jl:706
  [6] _broadcast_getindex
    @ ./broadcast.jl:681
  [7] _getindex
    @ ./broadcast.jl:706
  [8] _getindex
    @ ./broadcast.jl:705
  [9] _broadcast_getindex
    @ ./broadcast.jl:681
 [10] getindex
    @ ./broadcast.jl:636
 [11] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
   [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
   [2] macro expansion
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
   [3] macro expansion
     @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
   [4] macro expansion
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
   [5] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
   [6] emit_llvm
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
   [7] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
   [8] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
   [9] compile
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
  [10] #45
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:57 [inlined]
  [11] JuliaContext(f::Metal.var"#45#46"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
  [12] compile(job::GPUCompiler.CompilerJob)
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:56
  [13] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
  [14] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
  [15] macro expansion
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:162 [inlined]
  [16] macro expansion
     @ ./lock.jl:267 [inlined]
  [17] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:157
  [18] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}})
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:155
  [19] macro expansion
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:77 [inlined]
  [20] #launch_heuristic#90
     @ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:14 [inlined]
  [21] launch_heuristic
     @ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:12 [inlined]
  [22] _copyto!
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:70 [inlined]
  [23] copyto!
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:51 [inlined]
  [24] copy
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:42 [inlined]
  [25] materialize
     @ ./broadcast.jl:903 [inlined]
  [26] (::Zygote.var"#1235#1238"{3, MtlVector{…}})(ȳ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/broadcast.jl:108
  [27] #3878#back
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [28] #5
     @ Main ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:24 [inlined]
  [29] WrappedFunction
     @ Lux ~/.julia/packages/Lux/VFyfk/src/layers/basic.jl:118 [inlined]
  [30] apply
     @ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
  [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [32] macro expansion
     @ ./tuple.jl:0 [inlined]
  [33] applychain
     @ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:480 [inlined]
  [34] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [35] Chain
     @ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:478 [inlined]
  [36] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [37] apply
     @ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
  [38] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [39] StatefulLuxLayer
     @ ~/.julia/packages/Lux/VFyfk/src/contrib/stateful.jl:47 [inlined]
  [40] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [41] dudt
     @ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:49 [inlined]
  [42] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [43] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [44] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [45] ODEFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
  [46] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [47] #38
     @ ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:573 [inlined]
  [48] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [49] (::Zygote.var"#75#76"{Zygote.Pullback{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
  [50] _vecjacobian!(dλ::MtlVector{…}, y::MtlVector{…}, λ::MtlVector{…}, p::ComponentVector{…}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::ZygoteVJP, dgrad::MtlVector{…}, dy::Nothing, W::Nothing)
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:584
  [51] #vecjacobian!#18
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:224 [inlined]
  [52] (::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…})(du::MtlVector{…}, u::MtlVector{…}, p::ComponentVector{…}, t::Float32)
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/interpolating_adjoint.jl:0
  [53] ODEFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
  [54] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/perform_step/low_order_rk_perform_step.jl:792
  [55] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:513
  [56] __init (repeats 5 times)
     @ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:10 [inlined]
  [57] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:5
  [58] __solve
     @ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:1 [inlined]
  [59] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:605
  [60] solve_call
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:566 [inlined]
  [61] solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Any, u0::Any, p::Any, args::Vararg{Any}; kwargs...)
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1054 [inlined]
  [62] solve_up
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1040 [inlined]
  [63] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977
  [64] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{…})
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:432
  [65] _adjoint_sensitivities
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:390 [inlined]
  [66] #adjoint_sensitivities#63
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:386 [inlined]
  [67] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…})
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/concrete_solve.jl:535
  [68] ZBack
     @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
  [69] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:237
  [70] #291
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [71] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [72] #solve#40
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977 [inlined]
  [73] (::Zygote.Pullback{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [74] #291
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [75] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [76] solve
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967 [inlined]
  [77] (::Zygote.Pullback{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [78] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [79] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [80] NeuralODE
     @ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:53 [inlined]
  [81] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ODESolution{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [82] predict_neuralode
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:33 [inlined]
  [83] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [84] loss_neuralode
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:36 [inlined]
  [85] #7
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:42 [inlined]
  [86] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [87] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [88] OptimizationFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:3811 [inlined]
  [89] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [90] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [91] #37
     @ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:88 [inlined]
  [92] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [93] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [94] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [95] #39
     @ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91 [inlined]
  [96] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [97] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
  [98] gradient(f::Function, args::ComponentVector{Float32, MtlVector{…}, Tuple{…}})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
  [99] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
     @ OptimizationZygoteExt ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91
 [100] macro expansion
     @ ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:65 [inlined]
 [101] macro expansion
     @ ~/.julia/packages/Optimization/sIARu/src/utils.jl:41 [inlined]
 [102] __solve(cache::OptimizationCache{…})
     @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:63
 [103] solve!(cache::OptimizationCache{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:177
 [104] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:94
Some type information was truncated. Use `show(err)` to see complete types.