Lux Neural ODE Example does not work with Metal

I directly copied the Lux example for creating a neural ODE on MNIST (MNIST Classification using NeuralODE - Lux.jl) and simply replaced CUDA with Metal so it runs on my Mac M1 GPU. This single change is enough to break the code:

cd(@__DIR__)

using Pkg
Pkg.activate(".")
Pkg.instantiate()


using Lux
using ComponentArrays,
    SciMLSensitivity,
    Metal,
    Optimisers,
    OrdinaryDiffEq,
    Random,
    Statistics,
    Zygote,
    OneHotArrays
import MLDatasets: MNIST
import MLUtils: DataLoader, splitobs
Metal.functional()  # sanity check: the Metal backend should be available

function loadmnist(batchsize, train_split)
    # Load MNIST: Only 1500 for demonstration purposes
    N = 1500
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize=batchsize, shuffle=false))
end

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
        solver=Tsit5(),
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
        tspan=(0.0f0, 1.0f0),
        kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, kwargs)
end

function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(u, p, st)
        return u_
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end

function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:MtlArray}}) where {T, N}
    dev = gpu_device()
    return dropdims(dev(x); dims=3)
end
diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)

function create_model()
    # Construct the Neural ODE Model
    model = Chain(FlattenLayer(),
        Dense(784, 20, tanh),
        NeuralODE(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false),
        diffeqsol_to_array,
        Dense(20, 10))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    dev = gpu_device()
    ps = ComponentArray(ps) |> dev
    st = st |> dev

    return model, ps, st
end

logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return logitcrossentropy(y_pred, y), st
end

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    dev = gpu_device()
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        # Move each batch to the GPU by hand; CUDA's CuIterator has no Metal counterpart
        x = dev(x)
        target_class = onecold(cpu_dev(y))
        predicted_class = onecold(cpu_dev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

function train()
    model, ps, st = create_model()

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9)

    opt = Optimisers.Adam(0.001f0)
    st_opt = Optimisers.setup(opt, ps)

    dev = gpu_device()

    ### Warmup the Model
    img = dev(train_dataloader.data[1][:, :, :, 1:1])
    lab = dev(train_dataloader.data[2][:, 1:1])
    loss(img, lab, model, ps, st)
    (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
    back((one(l), nothing))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            x = dev(x)
            y = dev(y)
            (l, st), back = pullback(p -> loss(x, y, model, p, st), ps)
            ### We need to add `nothing`s equal to the number of returned values - 1
            gs = back((one(l), nothing))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)
        end
        ttime = time() - stime

        println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
                "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " *
                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%")
    end
end

train()

and it fails with the following error:

ERROR: ArgumentError: cannot take the CPU address of a MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}
Stacktrace:
  [1] unsafe_convert(#unused#::Type{Ptr{Float32}}, x::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/array.jl:139
  [2] unsafe_convert(#unused#::Type{Ptr{Float32}}, V::SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true})
    @ Base ./subarray.jl:437
  [3] unsafe_convert(#unused#::Type{Ptr{Float32}}, a::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Base ./reshapedarray.jl:283
  [4] gemm!(transA::Char, transB::Char, alpha::Float32, A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, beta::Float32, C::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
    @ LinearAlgebra.BLAS /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
  [5] gemm_wrapper!(C::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, tA::Char, tB::Char, A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:674
  [6] mul!
    @ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:161 [inlined]
  [7] mul!
    @ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
  [8] *(A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
    @ LinearAlgebra /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:148
  [9] Dense
    @ ~/.julia/packages/Lux/5YzHA/src/layers/basic.jl:223 [inlined]
 [10] apply(model::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, ps::ComponentVector{Float32, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))}}}, st::NamedTuple{(), Tuple{}})
    @ LuxCore ~/.julia/packages/LuxCore/yC3wg/src/LuxCore.jl:100
 [11] macro expansion
    @ ./abstractarray.jl:0 [inlined]
 [12] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{FlattenLayer, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, NeuralODE{Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float32, Float32, Bool}}}}, WrappedFunction{typeof(diffeqsol_to_array)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, x::MtlArray{Float32, 4, Metal.MTL.MTLResourceStorageModePrivate}, ps::ComponentVector{Float32, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
    @ Lux ~/.julia/packages/Lux/5YzHA/src/layers/containers.jl:493
 [13] Chain
    @ ~/.julia/packages/Lux/5YzHA/src/layers/containers.jl:491 [inlined]
 [14] loss(x::MtlArray{Float32, 4, Metal.MTL.MTLResourceStorageModePrivate}, y::MtlMatrix{Bool, Metal.MTL.MTLResourceStorageModePrivate}, model::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{FlattenLayer, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, NeuralODE{Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float32, Float32, Bool}}}}, WrappedFunction{typeof(diffeqsol_to_array)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, ps::ComponentVector{Float32, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
    @ Main ~/Research/constitutive_history/examples/mnist.jl:99
 [15] train()
    @ Main ~/Research/constitutive_history/examples/mnist.jl:131
 [16] top-level scope
    @ ~/Research/constitutive_history/examples/mnist.jl:155
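
From the stack trace, my guess (unverified) is that the failure has nothing to do with the ODE solve itself: frames [1]–[8] show a `Base.ReshapedArray` wrapping a `SubArray` of an `MtlVector` (i.e. the `ComponentArray` view of a `Dense` layer's weight) being handed to BLAS `gemm!`, which needs a CPU pointer. Here is a minimal sketch of what I believe is the same problem, stripped of Lux and the ODE machinery:

```julia
# Sketch of the suspected root cause (my assumption, not verified):
# ComponentArrays hands Dense its weight as a reshaped view of the flat
# parameter MtlVector. That wrapper is a Base.ReshapedArray, not an MtlArray,
# so `*` falls through to the generic BLAS path, which wants a CPU pointer.
using Metal

v = MtlArray(rand(Float32, 6))
W = reshape(view(v, 1:6), 2, 3)  # ReshapedArray over a SubArray of an MtlVector
x = MtlArray(rand(Float32, 3, 4))

W * x  # should hit the same "cannot take the CPU address of a MtlVector" error
```

If that is right, the example presumably works with CUDA only because CUDA.jl/NNlib provide dispatches for these wrapped arrays that Metal.jl currently lacks. Is this a known limitation, or am I doing something wrong?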