NeuralODE training failed on GPU with Enzyme

Hi, when I try to train a NeuralODE with a DiscreteCallback using `sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true)`, I get:

```
Enzyme execution failed.
Enzyme: unhandled augmented forward for jl_f_finalizer
Stacktrace:
  [1] finalizer
    @ ./gcutils.jl:87
  [2] _
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83
  [3] CuArray
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79
  [4] derive
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799
  [5] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319
  [6] unsafe_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314
  [7] view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310
  [8] maybeview
    @ ./views.jl:148
  [9] macro expansion
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0
 [10] _getindex
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119
 [11] getproperty
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14
 [12] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0
 [13] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520


Stacktrace:
  [1] finalizer
    @ ./gcutils.jl:87 [inlined]
  [2] _
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83 [inlined]
  [3] CuArray
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79 [inlined]
  [4] derive
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799 [inlined]
  [5] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319 [inlined]
  [6] unsafe_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314 [inlined]
  [7] view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310 [inlined]
  [8] maybeview
    @ ./views.jl:148 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0 [inlined]
 [10] _getindex
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119 [inlined]
 [11] getproperty
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [13] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [14] Chain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [15] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [16] dudt
    @ ./In[92]:24 [inlined]
 [17] dudt
    @ ./In[92]:20 [inlined]
 [18] ODEFunction
    @ ~/.julia/packages/SciMLBase/Q1klk/src/scimlfunctions.jl:2335 [inlined]
 [19] #138
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [20] diffejulia__138_128700_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [21] macro expansion
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:7049 [inlined]
 [22] enzyme_call
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6658 [inlined]
 [23] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6535 [inlined]
 [24] autodiff
    @ ~/.julia/packages/Enzyme/XGb4o/src/Enzyme.jl:320 [inlined]
 [25] _vecjacobian!(dλ::CuArray{Float32, 1, CUDA.DeviceMemory}, y::CuArray{Float32, 1, CUDA.DeviceMemory}, λ::CuArray{Float32, 1, CUDA.DeviceMemory}, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{Nothing, SciMLSensitivity.var"#138#142"{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, CuArray{Float32, 1, CUDA.DeviceMemory}, CuArray{Float32, 1, CUDA.DeviceMemory}, CuArray{Float32, 1, CUDA.DeviceMemory}, SciMLSensitivity.var"#138#142"{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, CuArray{Float32, 1, CUDA.DeviceMemory}, ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Nothing, Nothing, 
Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, 
Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing}, SciMLSensitivity.CheckpointSolution{ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, 
Nothing, Nothing}, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing}, Vector{Tuple{Float32, Float32}}, @NamedTuple{reltol::Float64, abstol::Float64}, Nothing}, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}, isautojacvec::EnzymeVJP, dgrad::CuArray{Float32, 1, 
CUDA.DeviceMemory}, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/se3y4/src/derivative_wrappers.jl:728
...

```

Has anyone run into this before? :frowning:

Isolate it to just the Enzyme call on the ODE function:
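Something along these lines, perhaps (a rough sketch of that kind of isolation; it mirrors how SciMLSensitivity's `EnzymeVJP` seeds Enzyme, with placeholder array names and sizes matching the 4-state model below, not code from the actual failing run):

```julia
# Hypothetical isolation sketch: differentiate just the RHS with Enzyme,
# no ODE solve involved. Names and sizes here are placeholders.
using Enzyme, Lux, LuxCUDA, ComponentArrays, Random

gdev = gpu_device()
nn = Chain(Dense(4, 16, tanh), Dense(16, 16, tanh), Dense(16, 4))
ps, st = Lux.setup(Xoshiro(0), nn)
ps = ComponentArray(ps) |> gdev

# In-place wrapper around the network, like the adjoint's RHS closure
function f!(du, u, p)
    du .= first(Lux.apply(nn, u, p, st))
    return nothing
end

y  = gdev(rand(Float32, 4))   # state
dy = gdev(zeros(Float32, 4))  # output buffer
λ  = gdev(ones(Float32, 4))   # adjoint seed for the output
du = gdev(zeros(Float32, 4))  # VJP w.r.t. u accumulates here
dp = zero(ps)                 # VJP w.r.t. p accumulates here

Enzyme.autodiff(Enzyme.Reverse, f!, Enzyme.Const,
                Enzyme.Duplicated(dy, λ),
                Enzyme.Duplicated(y, du),
                Enzyme.Duplicated(ps, dp))
```

If that already throws the same `jl_f_finalizer` error, the problem is independent of the solver and the callback.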


Thanks Chris!

I've tried to isolate it, but I still get the same error :frowning:

Here is my main code:

```julia

# Packages assumed from context; the post omits its `using` lines:
using Lux, LuxCUDA, CUDA, OrdinaryDiffEq, SciMLSensitivity, DiffEqCallbacks
using ComponentArrays, Optimisers, Zygote, ParameterSchedulers, ProgressMeter
using Random, Statistics

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, D, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    tspan::T
    device::D
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), gpu=nothing, kwargs...)
    device = DetermineDevice(gpu=gpu)
    NeuralODE{typeof(model), typeof(solver), typeof(tspan), typeof(device), typeof(kwargs)}(model, solver, tspan, device, kwargs)
end


function (n::NeuralODE)(u0, ps, st, cb)

    function dudt(u, p, t; st=st)
        u_, st = Lux.apply(n.model, u, p, st)
        return u_
    end
    
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    
    sensealg = get(n.kwargs, :sensealg, InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))
    
    tsteps = n.tspan[1]:n.tspan[2]
    
    sol = solve(prob, n.solver, saveat=tsteps, callback = cb, sensealg = sensealg)
    
    return DeviceArray(n.device, Array(sol)), st
end

function loss_neuralode(model, u0, p, st, cb)

    pred, st = model(u0, p, st, cb)
    
    # observed_data is assumed to be defined elsewhere, on the same device as pred
    loss = mean((pred .- observed_data).^2)
    
    return loss
end


function train_neuralode!(model, u0, p, st, cb, loss_func, opt_state, η_schedule; N_epochs=1, verbose=true, compute_initial_error::Bool=true, scheduler_offset::Int=0)
    
    best_p = copy(p)
    results = (i_epoch = Int[], train_loss=Float32[], learning_rate=Float32[], duration=Float32[], valid_loss=Float32[], test_loss=Float32[], loss_min=[Inf32], i_epoch_min=[1])
    
    progress = Progress(N_epochs, 1)
    
    # initial error 
    lowest_train_err = compute_initial_error ? loss_func(model, u0, p, st, cb) : Inf
    
    if verbose 
        println("______________________________")
        println("starting training epoch")
    end

    for i_epoch in 1:N_epochs

        Optimisers.adjust!(opt_state, η_schedule(i_epoch + scheduler_offset)) 

        epoch_start_time = time()

        loss_p(p) = loss_func(model, u0, p, st, cb)
        println("training loss: ", loss_p(p)) # note: costs an extra forward solve

        l, gs = Zygote.withgradient(loss_p, p)

        opt_state, p = Optimisers.update(opt_state, p, gs[1])


        train_err = l
        epoch_time = time() - epoch_start_time

        push!(results[:i_epoch], i_epoch)
        push!(results[:train_loss], train_err)
        push!(results[:learning_rate], η_schedule(i_epoch))
        push!(results[:duration], epoch_time)

        if train_err < lowest_train_err
            lowest_train_err = train_err
            best_p = deepcopy(p)
            results[:loss_min] .= lowest_train_err
            results[:i_epoch_min] .= i_epoch
        end

    end


    return model, best_p, st, results
    
end


# Callback: reset the four states to zero at t = 274
times = [274.0f0]
affect!(integrator) = integrator.u[1:4] .= 0.0f0
cb = PresetTimeCallback(times, affect!; save_positions = (false, false)) # alternative: save_positions = (true, true)

# DetermineDevice and DeviceArray are the poster's own helpers (not shown in the post)
const device = DetermineDevice()
CUDA.allowscalar(false) # make sure no slow scalar indexing happens on the GPU


nn = Chain(
          Dense(4, 16, tanh),
          Dense(16, 16, tanh), 
          Dense(16, 4)
)

rng = Xoshiro(0)
p, st = Lux.setup(rng, nn)
gdev = gpu_device() # assumed definition; `gdev` is used but never defined in the post
p = ComponentArray(p) |> gdev
st = st |> gdev

u0 = Float32[0.0, 8.0, 0.0, 12.0] |> gdev
tspan = (0.0f0, 365.0f0) 

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan = tspan, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true))

loss = loss_neuralode

opt = Optimisers.AdamW(1f-3, (9f-1, 9.99f-1), 1f-6)
opt_state = Optimisers.setup(opt, p)
η_schedule = SinExp(λ0=1f-3,λ1=1f-5,period=20,decay=0.975f0)

println("starting training...")
neural_de, ps, st, results_ad = train_neuralode!(neural_ode, u0, p, st, cb, loss, opt_state, η_schedule; N_epochs=5, verbose=true)

```

The code works on CPU, but not on GPU.

That's not isolated. There should be no ODE solve involved at all.

That looks like a bug in the CUDA.jl Enzyme extension not handling some CuArray constructor properly. Can you file an issue on CUDA.jl and cc me (wsmoses)?
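Judging from the failing frames (`view` → `unsafe_contiguous_view` → `derive` → the `CuArray` constructor's `finalizer`), a repro for that issue could plausibly be as small as differentiating a function that takes a `view` of a `CuArray`. An untested sketch of what such an issue report might contain:

```julia
# Hypothetical minimal repro (untested, just a guess at the failing path):
# view() on a CuArray goes through derive() -> CuArray constructor -> finalizer,
# which is the path the stack trace points at.
using CUDA, Enzyme

f(x) = sum(view(x, 1:2))

x  = CUDA.rand(Float32, 4)
dx = CUDA.zeros(Float32, 4)
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
```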


Thanks, sure!