SciML: Issues using `GaussAdjoint` on an `EnsembleProblem` with callbacks

Hello there,

I am rather new to Julia, and I am trying to implement a neural ODE model that relies on event callbacks to terminate integration. I am using `GaussAdjoint`, since it supports callbacks and is recommended in the docs.

This is my current forward pass, which I am running with `use_gpu = false`:


"""
    AITNODELayer

Abstract supertype for the `AITNODE` layer family. As an `AbstractLuxContainerLayer`
over `(:vector_field, :halting_unit)`, Lux builds its parameters and states as
NamedTuples keyed by those two sublayer names (read back as `θ.vector_field`,
`st.halting_unit`, etc. in the forward pass).
"""
abstract type AITNODELayer <: AbstractLuxContainerLayer{(:vector_field, :halting_unit)} end

# Neural ODE layer whose integration is stopped by a learned halting signal.
#
# Fields:
# - `vector_field`: Lux layer giving dx/dt for the unaugmented state `x`.
# - `halting_unit`: Lux layer whose output `h` is integrated alongside `x`;
#   integration stops once its running integral reaches `1 - eps`.
# - `dim`: size `D` of the unaugmented state (rows of the input batch).
# - `eps`: halting tolerance used by the terminating event condition.
# - `use_gpu`: selects `EnsembleGPUArray` (true) or `EnsembleThreads` (false).
# - `tspan`: time span passed to the `ODEProblem`.
# - `args`: extra positional arguments forwarded to `solve` (presumably the ODE
#   solver algorithm — confirm at the construction site).
# - `kwargs`: extra keyword arguments forwarded to `solve`.
#
# `tspan`, `args` and `kwargs` are deliberately left untyped: under `@concrete`
# each untyped field gets its own concrete type parameter, whereas `::Any` forced
# them to be abstract, which made `solve(…, n.args…; n.kwargs…)` uninferable.
@concrete struct AITNODE <: AITNODELayer
    vector_field <: AbstractLuxLayer
    halting_unit <: AbstractLuxLayer
    dim::Int
    eps::Float32
    use_gpu::Bool
    tspan
    args
    kwargs
end

"""
    (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)

Run the forward pass on a batch `x` with one sample per column (`D = n.dim` rows,
`B` columns).

Each column is augmented to the state `u = [x; A; x̄]` of length `2D + 1`, with
`A` and `x̄` starting at zero. The augmented right-hand side is
`[f(x); h(x); h(x) .* x]`, where `f` is `n.vector_field` and `h` is
`n.halting_unit`. Every column is solved as one trajectory of an
`EnsembleProblem`. A `ContinuousCallback` terminates a trajectory when
`A` reaches `1 - n.eps`.

# Returns
- `(x_hats, T_stars)`:
  - `x_hats` is a `D × B` matrix. Column `i` is `x̄ .+ (1 - A) .* x`, evaluated
    at the last saved state of trajectory `i`.
  - `T_stars` holds the last saved time of each trajectory. It is wrapped in
    `@ignore_derivatives`, so no gradient flows through it.
- `(vector_field = …, halting_unit = …)`: the input states, returned unchanged.
"""
function (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)
    st_vf = st.vector_field
    st_hu = st.halting_unit
    D = n.dim
    B = size(x, 2)
    use_gpu = n.use_gpu

    # Augmented dynamics: d/dt [x; A; x̄] = [f(x); h(x); h(x) .* x].
    # The sublayer states returned by `Lux.apply` are discarded.
    function dudt(u, θ, t)
        x_state = u[1:D]
        dx, _ = Lux.apply(n.vector_field, x_state, θ.vector_field, st_vf)
        h, _ = Lux.apply(n.halting_unit, x_state, θ.halting_unit, st_hu)
        return vcat(dx, h, h .* x_state)
    end

    # A and x̄ start at zero; the zero padding carries no gradient.
    pad_matrix = @ignore_derivatives fill!(similar(x, D + 1, B), 0)
    u0_batch = vcat(x, pad_matrix)

    ff = ODEFunction{false}(dudt)
    base_prob = ODEProblem{false}(ff, u0_batch[:, 1], n.tspan, θ)

    # Trajectory `ctx.sim_id` starts from the matching column of the augmented batch.
    function prob_func(prob, ctx)
        u0_i = u0_batch[:, ctx.sim_id]
        return remake(prob; u0 = u0_i)
    end

    # The event fires when the halting accumulator A = u[D + 1] reaches 1 - eps.
    # The condition captures only plain bits values, not the whole layer `n`,
    # whose `args`/`kwargs` need not be bits types. It also reads A as a scalar
    # instead of allocating and summing a one-element slice on every
    # root-finding evaluation.
    halt_idx = D + 1
    halt_eps = n.eps
    condition(u, t, integrator) = (one(eltype(u)) - halt_eps) - u[halt_idx]

    # NOTE(review): the reported `BoundsError: attempt to access Tuple{} at index [0]`
    # is raised inside SciMLSensitivity's adjoint loss callback (`df_iip` via
    # `ReverseLossCallback`). Suspect a mismatch between the saved time points and
    # the adjoint's loss indices caused by the terminating event's extra save
    # (`save_positions = (false, true)`). To narrow it down, try
    # `save_positions = (false, false)` and/or `save_everystep = false` in
    # `n.kwargs` — unverified.
    cb = ContinuousCallback(condition, terminate!; save_positions = (false, true))

    ensemble_prob = EnsembleProblem(
        base_prob;
        prob_func = prob_func,
        safetycopy = false
    )

    # NOTE(review): `EnsembleGPUArray` compiles the ODE function into a GPU kernel.
    # A closure over Lux layers (`dudt`) is likely not kernel-compatible; confirm
    # before enabling `use_gpu`.
    ensemblealg = use_gpu ? EnsembleGPUArray(CUDA.CUDABackend()) : EnsembleThreads()

    ensemble_sol = solve(ensemble_prob, n.args..., ensemblealg;
        sensealg = GaussAdjoint(autojacvec = EnzymeVJP()),
        callback = cb,
        trajectories = B,
        n.kwargs...
    )

    # Read out x̂ = x̄ + (1 - A) .* x from each trajectory's last saved state.
    # A is kept as a 1-element slice so the update stays a broadcast (no scalar
    # indexing into the saved state).
    columns = map(ensemble_sol.u) do sol
        z = sol.u[end]
        x_state = z[1:D]
        A_val = z[(D + 1):(D + 1)]
        x_bar = z[(D + 2):(2D + 1)]
        return x_bar .+ (one(eltype(z)) .- A_val) .* x_state
    end
    # `reduce(hcat, …)` concatenates without splatting B vectors into a varargs
    # call, which would specialize `hcat` on the batch size.
    x_hats = reduce(hcat, columns)
    T_stars = @ignore_derivatives [sol.t[end] for sol in ensemble_sol.u]

    return (x_hats, T_stars), (vector_field = st_vf, halting_unit = st_hu)
end

And I'm getting the following error:

ERROR: LoadError: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base ./tuple.jl:31
  [2] (::SciMLSensitivity.var"#df_iip#338"{Float32, Colon})(_out::Vector{Float32}, u::Vector{Float32}, p::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(vector_field = ViewAxis(1:8320, Axis(layer_1 = ViewAxis(1:4160, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))), layer_2 = ViewAxis(4161:8320, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))))), halting_unit = ViewAxis(8321:9377, Axis(layer_1 = ViewAxis(1:1040, Axis(weight = ViewAxis(1:1024, ShapedAxis((16, 64))), bias = ViewAxis(1025:1040, Shaped1DAxis((16,))))), layer_2 = ViewAxis(1041:1057, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))))}}}, t::Float32, i::Int64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/9RPKK/src/concrete_solve.jl:833
  [3] ReverseLossCallback
    @ ~/.julia/packages/SciMLSensitivity/9RPKK/src/adjoint_common.jl:744 [inlined]
...

I am using the following package versions:

[compat]
Aqua = "0.8.16"
CUDA = "6.1.0"
ChainRulesCore = "1.26.1"
ComponentArrays = "0.15.39"
ConcreteStructs = "0.2.4"
DiffEqGPU = "3.15.0"
DifferentialEquations = "8.0.0"
Lux = "1.31.4"
Optimisers = "0.4.7"
Optimization = "5.6.1"
OptimizationOptimisers = "0.3.17"
Printf = "1.11.0"
Random = "1.11.0"
SciMLSensitivity = "7.111.0"
Statistics = "1.11.1"
Test = "1"
Zygote = "0.7.10"
julia = "~1.11"

Are my choices of sensitivity algorithm and VJP sound? Am I missing something in my code? I would like to use something that supports callbacks, GPU execution, and some form of parallelization with ensembles.