Here is my example:
using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
Metal, SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE
gdev = gpu_device()
#rng for Lux.setup
rng = Random.default_rng()
# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)
true_A = Matrix([-0.1f0 2.0f0; -2.0f0 -0.1f0]')
function trueODEfunc(du, u, A, t)
du .= A * u .^ 3
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan, true_A)
# Make the data into a GPU-based array if the user has a GPU
ode_data = gdev(Array(solve(prob_trueode, Tsit5(); saveat = tsteps)))
##
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2)
p = p |> ComponentArray |> gdev
st = st |> gdev
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
function predict_neuralode(p)
gdev(Array(first(prob_neuralode(u0, p, st))))
end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, ode_data - pred)
return loss, pred
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)
Results:
julia> optf(p, [])
(309.3814f0, Float32[2.0 2.0186965 … 1.9060048 1.8881334; 0.0 0.09428286 … 2.486669 2.5848868])
julia> result_neuralode = Optimization.solve(optprob, Adam(0.05); maxiters = 100)
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] power_by_squaring
@ ./intfuncs.jl:0
[2] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] *
@ ./float.jl:411
[2] power_by_squaring
@ ./intfuncs.jl:282
[3] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] ==
@ ./float.jl:534
[2] isone
@ ./number.jl:62
[3] power_by_squaring
@ ./intfuncs.jl:284
[4] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] ==
@ ./float.jl:534
[2] isone
@ ./number.jl:62
[3] power_by_squaring
@ ./intfuncs.jl:285
[4] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] *
@ ./float.jl:411
[2] power_by_squaring
@ ./intfuncs.jl:291
[3] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] *
@ ./float.jl:411
[2] power_by_squaring
@ ./intfuncs.jl:298
[3] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] *
@ ./float.jl:411
[2] power_by_squaring
@ ./intfuncs.jl:300
[3] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] power_by_squaring
@ ./intfuncs.jl:285
[2] multiple call sites
@ unknown:0
Reason: unsupported use of double value
Stacktrace:
[1] ^
@ ./math.jl:0
[2] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[3] _broadcast_getindex
@ ./broadcast.jl:682
[4] _getindex
@ ./broadcast.jl:706
[5] _broadcast_getindex
@ ./broadcast.jl:681
[6] _getindex
@ ./broadcast.jl:706
[7] _getindex
@ ./broadcast.jl:705
[8] _broadcast_getindex
@ ./broadcast.jl:681
[9] getindex
@ ./broadcast.jl:636
[10] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
[1] /
@ ./float.jl:412
[2] inv
@ ./number.jl:255
[3] ^
@ ./math.jl:1280
[4] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[5] _broadcast_getindex
@ ./broadcast.jl:682
[6] _getindex
@ ./broadcast.jl:706
[7] _broadcast_getindex
@ ./broadcast.jl:681
[8] _getindex
@ ./broadcast.jl:706
[9] _getindex
@ ./broadcast.jl:705
[10] _broadcast_getindex
@ ./broadcast.jl:681
[11] getindex
@ ./broadcast.jl:636
[12] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
[1] ^
@ ./math.jl:1280
[2] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[3] _broadcast_getindex
@ ./broadcast.jl:682
[4] _getindex
@ ./broadcast.jl:706
[5] _broadcast_getindex
@ ./broadcast.jl:681
[6] _getindex
@ ./broadcast.jl:706
[7] _getindex
@ ./broadcast.jl:705
[8] _broadcast_getindex
@ ./broadcast.jl:681
[9] getindex
@ ./broadcast.jl:636
[10] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
[1] Float32
@ ./float.jl:258
[2] ^
@ ./math.jl:1280
[3] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[4] _broadcast_getindex
@ ./broadcast.jl:682
[5] _getindex
@ ./broadcast.jl:706
[6] _broadcast_getindex
@ ./broadcast.jl:681
[7] _getindex
@ ./broadcast.jl:706
[8] _getindex
@ ./broadcast.jl:705
[9] _broadcast_getindex
@ ./broadcast.jl:681
[10] getindex
@ ./broadcast.jl:636
[11] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
[1] ^
@ ./math.jl:1281
[2] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[3] _broadcast_getindex
@ ./broadcast.jl:682
[4] _getindex
@ ./broadcast.jl:706
[5] _broadcast_getindex
@ ./broadcast.jl:681
[6] _getindex
@ ./broadcast.jl:706
[7] _getindex
@ ./broadcast.jl:705
[8] _broadcast_getindex
@ ./broadcast.jl:681
[9] getindex
@ ./broadcast.jl:636
[10] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Reason: unsupported use of double value
Stacktrace:
[1] Float32
@ ./float.jl:258
[2] ^
@ ./math.jl:1281
[3] _broadcast_getindex_evalf
@ ./broadcast.jl:709
[4] _broadcast_getindex
@ ./broadcast.jl:682
[5] _getindex
@ ./broadcast.jl:706
[6] _broadcast_getindex
@ ./broadcast.jl:681
[7] _getindex
@ ./broadcast.jl:706
[8] _getindex
@ ./broadcast.jl:705
[9] _broadcast_getindex
@ ./broadcast.jl:681
[10] getindex
@ ./broadcast.jl:636
[11] broadcast_kernel
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
[1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
[4] macro expansion
@ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
[5]
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
[6] emit_llvm
@ ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
[7]
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
[8]
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
[9] compile
@ ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
[10] #45
@ ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:57 [inlined]
[11] JuliaContext(f::Metal.var"#45#46"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
[12] compile(job::GPUCompiler.CompilerJob)
@ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/compilation.jl:56
[13] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
[14] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
[15] macro expansion
@ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:162 [inlined]
[16] macro expansion
@ ./lock.jl:267 [inlined]
[17] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
@ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:157
[18] mtlfunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{…}})
@ Metal ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:155
[19] macro expansion
@ ~/.julia/packages/Metal/lnkVP/src/compiler/execution.jl:77 [inlined]
[20] #launch_heuristic#90
@ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:14 [inlined]
[21] launch_heuristic
@ ~/.julia/packages/Metal/lnkVP/src/gpuarrays.jl:12 [inlined]
[22] _copyto!
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:70 [inlined]
[23] copyto!
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:51 [inlined]
[24] copy
@ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:42 [inlined]
[25] materialize
@ ./broadcast.jl:903 [inlined]
[26] (::Zygote.var"#1235#1238"{3, MtlVector{…}})(ȳ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/broadcast.jl:108
[27] #3878#back
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[28] #5
@ Main ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:24 [inlined]
[29] WrappedFunction
@ Lux ~/.julia/packages/Lux/VFyfk/src/layers/basic.jl:118 [inlined]
[30] apply
@ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
[31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[32] macro expansion
@ ./tuple.jl:0 [inlined]
[33] applychain
@ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:480 [inlined]
[34] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[35] Chain
@ ~/.julia/packages/Lux/VFyfk/src/layers/containers.jl:478 [inlined]
[36] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[37] apply
@ ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
[38] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{MtlVector{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[39] StatefulLuxLayer
@ ~/.julia/packages/Lux/VFyfk/src/contrib/stateful.jl:47 [inlined]
[40] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[41] dudt
@ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:49 [inlined]
[42] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[43] #291
@ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[44] #2169#back
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[45] ODEFunction
@ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
[46] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[47] #38
@ ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:573 [inlined]
[48] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[49] (::Zygote.var"#75#76"{Zygote.Pullback{…}})(Δ::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
[50] _vecjacobian!(dλ::MtlVector{…}, y::MtlVector{…}, λ::MtlVector{…}, p::ComponentVector{…}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::ZygoteVJP, dgrad::MtlVector{…}, dy::Nothing, W::Nothing)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:584
[51] #vecjacobian!#18
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/derivative_wrappers.jl:224 [inlined]
[52] (::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…})(du::MtlVector{…}, u::MtlVector{…}, p::ComponentVector{…}, t::Float32)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/interpolating_adjoint.jl:0
[53] ODEFunction
@ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:2355 [inlined]
[54] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/perform_step/low_order_rk_perform_step.jl:792
[55] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:513
[56] __init (repeats 5 times)
@ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:10 [inlined]
[57] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:5
[58] __solve
@ ~/.julia/packages/OrdinaryDiffEq/ym8TQ/src/solve.jl:1 [inlined]
[59] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:605
[60] solve_call
@ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:566 [inlined]
[61] solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Any, u0::Any, p::Any, args::Vararg{Any}; kwargs...)
@ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1054 [inlined]
[62] solve_up
@ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:1040 [inlined]
[63] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977
[64] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:432
[65] _adjoint_sensitivities
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:390 [inlined]
[66] #adjoint_sensitivities#63
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/sensitivity_interface.jl:386 [inlined]
[67] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Eqao8/src/concrete_solve.jl:535
[68] ZBack
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
[69] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:237
[70] #291
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[71] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
[72] #solve#40
@ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977 [inlined]
[73] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[74] #291
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[75] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
[76] solve
@ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967 [inlined]
[77] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[78] #291
@ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[79] #2169#back
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[80] NeuralODE
@ ~/.julia/packages/DiffEqFlux/7OfDv/src/neural_de.jl:53 [inlined]
[81] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ODESolution{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[82] predict_neuralode
@ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:33 [inlined]
[83] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[84] loss_neuralode
@ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:36 [inlined]
[85] #7
@ ~/Dropbox/research/tearfilm/repos/uode/scripts/node-metal.jl:42 [inlined]
[86] #291
@ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[87] #2169#back
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[88] OptimizationFunction
@ ~/.julia/packages/SciMLBase/MMAmp/src/scimlfunctions.jl:3811 [inlined]
[89] #291
@ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[90] #2169#back
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[91] #37
@ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:88 [inlined]
[92] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[93] #291
@ ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:206 [inlined]
[94] #2169#back
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71 [inlined]
[95] #39
@ ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91 [inlined]
[96] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[97] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
[98] gradient(f::Function, args::ComponentVector{Float32, MtlVector{…}, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
[99] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
@ OptimizationZygoteExt ~/.julia/packages/Optimization/sIARu/ext/OptimizationZygoteExt.jl:91
[100] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:65 [inlined]
[101] macro expansion
@ ~/.julia/packages/Optimization/sIARu/src/utils.jl:41 [inlined]
[102] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/hZdKg/src/OptimizationOptimisers.jl:63
[103] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:177
[104] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/MMAmp/src/solve.jl:94
Some type information was truncated. Use `show(err)` to see complete types.