Trouble with NeuralODE and Metal

I’m trying to run the stock basic neural ODE demo on Metal.

First off, I was getting errors even with gdev = cpu_device(). I guessed that I should wrap the outputs of solve() and prob_neuralode() in Array(), and with that change it seems to run correctly on the CPU.

When I switch to gpu_device(), I get the following error from within Zygote:

ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) resulted in invalid LLVM IR

The error was triggered by multiple cases of “unsupported use of double value.”

I can call the optimization objective function (called optf in the demo) just fine. I have no idea where to go from here.

With LuxMetal.jl? Can you share what you ran?

For the Metal backend, just using Metal should be enough.

However, the full stack trace would be helpful.

Here is my example:

using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
    Metal, SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE

gdev = gpu_device()
# rng for Lux.setup
rng = Random.default_rng()
# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)
true_A = Matrix([-0.1f0 2.0f0; -2.0f0 -0.1f0]')
function trueODEfunc(du, u, A, t)
    du .= A * u .^ 3
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan, true_A)
# Make the data into a GPU-based array if the user has a GPU
ode_data = gdev(Array(solve(prob_trueode, Tsit5(); saveat = tsteps)))

##

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2)
p = p |> ComponentArray |> gdev
st = st |> gdev

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

function predict_neuralode(p)
    gdev(Array(first(prob_neuralode(u0, p, st))))
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data - pred)
    return loss, pred
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)

Results:

julia> optf(p, [])
(309.3814f0, Float32[2.0 2.0186965 … 1.9060048 1.8881334; 0.0 0.09428286 … 2.486669 2.5848868])

julia> result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] power_by_squaring
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:282
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:534
 [2] isone
   @ ./number.jl:62
 [3] power_by_squaring
   @ ./intfuncs.jl:284
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:534
 [2] isone
   @ ./number.jl:62
 [3] power_by_squaring
   @ ./intfuncs.jl:285
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:291
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:298
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:411
 [2] power_by_squaring
   @ ./intfuncs.jl:300
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] power_by_squaring
   @ ./intfuncs.jl:285
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:0
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] /
    @ ./float.jl:412
  [2] inv
    @ ./number.jl:255
  [3] ^
    @ ./math.jl:1280
  [4] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [5] _broadcast_getindex
    @ ./broadcast.jl:682
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _broadcast_getindex
    @ ./broadcast.jl:681
  [8] _getindex
    @ ./broadcast.jl:706
  [9] _getindex
    @ ./broadcast.jl:705
 [10] _broadcast_getindex
    @ ./broadcast.jl:681
 [11] getindex
    @ ./broadcast.jl:636
 [12] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1280
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] Float32
    @ ./float.jl:258
  [2] ^
    @ ./math.jl:1280
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [4] _broadcast_getindex
    @ ./broadcast.jl:682
  [5] _getindex
    @ ./broadcast.jl:706
  [6] _broadcast_getindex
    @ ./broadcast.jl:681
  [7] _getindex
    @ ./broadcast.jl:706
  [8] _getindex
    @ ./broadcast.jl:705
  [9] _broadcast_getindex
    @ ./broadcast.jl:681
 [10] getindex
    @ ./broadcast.jl:636
 [11] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1281
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [3] _broadcast_getindex
    @ ./broadcast.jl:682
  [4] _getindex
    @ ./broadcast.jl:706
  [5] _broadcast_getindex
    @ ./broadcast.jl:681
  [6] _getindex
    @ ./broadcast.jl:706
  [7] _getindex
    @ ./broadcast.jl:705
  [8] _broadcast_getindex
    @ ./broadcast.jl:681
  [9] getindex
    @ ./broadcast.jl:636
 [10] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
  [1] Float32
    @ ./float.jl:258
  [2] ^
    @ ./math.jl:1281
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
  [4] _broadcast_getindex
    @ ./broadcast.jl:682
  [5] _getindex
    @ ./broadcast.jl:706
  [6] _broadcast_getindex
    @ ./broadcast.jl:681
  [7] _getindex
    @ ./broadcast.jl:706
  [8] _getindex
    @ ./broadcast.jl:705
  [9] _broadcast_getindex
    @ ./broadcast.jl:681
 [10] getindex
    @ ./broadcast.jl:636
 [11] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
   [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
   [2] macro expansion
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
   [3] macro expansion
     @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
   [4] macro expansion
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
   [5] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
   [6] emit_llvm
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
   [7] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
   [8] 
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
   [9] compile
     @ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
  [10] #45
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:57 [inlined]
  [11] JuliaContext(f::Metal.var"#45#46"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
  [12] compile(job::GPUCompiler.CompilerJob)
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:56
  [13] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
  [14] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
     @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
  [15] macro expansion
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:162 [inlined]
  [16] macro expansion
     @ ./lock.jl:267 [inlined]
  [17] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:157
  [18] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}})
     @ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:155
  [19] macro expansion
     @ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:77 [inlined]
  [20] #launch_heuristic#90
     @ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:14 [inlined]
  [21] launch_heuristic
     @ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:12 [inlined]
  [22] _copyto!
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:70 [inlined]
  [23] copyto!
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:51 [inlined]
  [24] copy
     @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:42 [inlined]
  [25] materialize
     @ ./broadcast.jl:903 [inlined]
  [26] (::Zygote.var"#1235#1238"{3, MtlVector{…}})(ȳ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/broadcast.jl:108
  [27] #3878#back
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [28] #5
     @ Main ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:24 [inlined]
  [29] WrappedFunction
     @ Lux ~/.julia/packages/Lux/VFyfk/src/layers/basic.jl:118 [inlined]
  [30] apply
     @ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
  [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [32] macro expansion
     @ ./tuple.jl:0 [inlined]
  [33] applychain
     @ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:480 [inlined]
  [34] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [35] Chain
     @ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:478 [inlined]
  [36] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [37] apply
     @ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
  [38] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [39] StatefulLuxLayer
     @ ~/.julia/packages/Lux/VFyfk/src/contrib/stateful.jl:47 [inlined]
  [40] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [41] dudt
     @ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:49 [inlined]
  [42] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [43] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [44] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [45] ODEFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
  [46] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [47] #38
     @ ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:573 [inlined]
  [48] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [49] (::Zygote.var"#75#76"{Zygote.Pullback{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
  [50] _vecjacobian!(dλ::MtlVector{…}, y::MtlVector{…}, λ::MtlVector{…}, p::ComponentVector{…}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::ZygoteVJP, dgrad::MtlVector{…}, dy::Nothing, W::Nothing)
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:584
  [51] #vecjacobian!#18
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:224 [inlined]
  [52] (::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…})(du::MtlVector{…}, u::MtlVector{…}, p::ComponentVector{…}, t::Float32)
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/interpolating_adjoint.jl:0
  [53] ODEFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
  [54] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/perform_step/low_order_rk_perform_step.jl:792
  [55] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:513
  [56] __init (repeats 5 times)
     @ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:10 [inlined]
  [57] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
     @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:5
  [58] __solve
     @ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:1 [inlined]
  [59] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:605
  [60] solve_call
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:566 [inlined]
  [61] solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Any, u0::Any, p::Any, args::Vararg{Any}; kwargs...)
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1054 [inlined]
  [62] solve_up
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1040 [inlined]
  [63] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
     @ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977
  [64] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{…})
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:432
  [65] _adjoint_sensitivities
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:390 [inlined]
  [66] #adjoint_sensitivities#63
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:386 [inlined]
  [67] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…})
     @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/concrete_solve.jl:535
  [68] ZBack
     @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
  [69] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:237
  [70] #291
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [71] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [72] #solve#40
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977 [inlined]
  [73] (::Zygote.Pullback{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [74] #291
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [75] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [76] solve
     @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967 [inlined]
  [77] (::Zygote.Pullback{…})(Δ::ODESolution{…})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [78] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [79] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [80] NeuralODE
     @ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:53 [inlined]
  [81] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ODESolution{…}, Nothing})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [82] predict_neuralode
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:33 [inlined]
  [83] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [84] loss_neuralode
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:36 [inlined]
  [85] #7
     @ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:42 [inlined]
  [86] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [87] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [88] OptimizationFunction
     @ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:3811 [inlined]
  [89] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [90] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [91] #37
     @ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:88 [inlined]
  [92] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [93] #291
     @ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
  [94] #2169#back
     @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
  [95] #39
     @ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91 [inlined]
  [96] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [97] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
  [98] gradient(f::Function, args::ComponentVector{Float32, MtlVector{…}, Tuple{…}})
     @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
  [99] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
     @ OptimizationZygoteExt ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91
 [100] macro expansion
     @ ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:65 [inlined]
 [101] macro expansion
     @ ~/.julia/packages/Optimization/sIARu/src/utils.jl:41 [inlined]
 [102] __solve(cache::OptimizationCache{…})
     @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:63
 [103] solve!(cache::OptimizationCache{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:177
 [104] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:94
Some type information was truncated. Use `show(err)` to see complete types.
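
Reading the trace above, the failing kernel seems to come from the gradient of the x -> x .^ 3 layer: frame [28] points at the anonymous function in the script, frame [26] is a Zygote broadcast pullback, and the "unsupported use of double value" entries go through ^ in math.jl, power_by_squaring, and a Float32 conversion. My guess is that the pullback calls ^ with an ordinary integer exponent, and that on this Julia version that method computes in Float64 internally, which Metal does not support. A minimal sketch of a possible workaround, assuming that diagnosis is right, is to write the cube as plain multiplications. The name cube is introduced here for illustration only and is not part of the demo; I have only checked the CPU equivalence below, not the Metal run.

```julia
# Hypothetical replacement for the first layer of dudt2.
# `cube` is an illustrative name, not something from the original demo.
cube(x) = x .* x .* x   # only elementwise multiplications, no integer power

x = Float32[2.0, 0.0, -1.5]
y = cube(x)

# On the CPU this matches the original x .^ 3 and stays in Float32.
@assert y ≈ x .^ 3
@assert eltype(y) == Float32

# Assumed usage in the script, with everything else unchanged:
# dudt2 = Chain(cube, Dense(2, 50, tanh), Dense(50, 2))
```

If the gradient still fails after this change, the remaining "double value" entries in the new trace should show which other operation is promoting to Float64.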