SciML: Issues using `GaussAdjoint` on an `EnsembleProblem` with callbacks

Hello there,

I am rather new to Julia, and I am trying to implement a Neural ODE model that relies on event callbacks to terminate integration. I am using the GaussAdjoint, given that it supports callbacks and that it was recommended on the docs.

This is my current forward pass, which I am running with use_gpu=false:


# Common supertype for AIT-NODE layers. It registers `vector_field` and
# `halting_unit` as the sub-layer fields of the Lux container.
abstract type AITNODELayer <: AbstractLuxContainerLayer{(:vector_field, :halting_unit)} end

# Neural-ODE layer made of a vector-field network and a halting-unit network.
#
# Fields
#   vector_field - Lux layer evaluated on the latent state inside the ODE right-hand side
#   halting_unit - Lux layer evaluated on the latent state inside the ODE right-hand side
#   dim          - number of latent-state entries (`D` in the forward pass)
#   eps          - offset subtracted from one in the termination condition
#   use_gpu      - selects the ensemble algorithm in the forward pass
#   tspan        - time span handed to `ODEProblem`
#   args         - positional arguments splatted into `solve`
#   kwargs       - keyword arguments splatted into `solve`
@concrete struct AITNODE <: AITNODELayer
    vector_field <: AbstractLuxLayer
    halting_unit <: AbstractLuxLayer
    dim::Int
    eps::Float32
    use_gpu::Bool
    # Deliberately left without `::Any`: `@concrete` then gives each of these its
    # own type parameter, so the stored values keep their concrete types.
    tspan
    args
    kwargs
end

# Forward pass of the AIT-NODE layer. Every column of `x` is integrated as its own
# trajectory of an `EnsembleProblem`, and a `ContinuousCallback` terminates each
# trajectory at a root of `condition`.
#
# Augmented state per trajectory (length 2D + 1, all accumulators start at zero):
#   u[1:D]              latent state, driven by `vector_field`
#   u[D + 1]            accumulated halting-unit output `h`
#   u[(D + 2):(2D + 1)] accumulated `h .* x_state`
#
# Arguments
#   x  - matrix with one trajectory per column.
#        NOTE(review): the slicing below assumes `size(x, 1) == n.dim` — confirm.
#   θ  - parameters exposing `θ.vector_field` and `θ.halting_unit`
#   st - state exposing `st.vector_field` and `st.halting_unit`
#
# Returns `((x_hats, T_stars), st_out)`: `x_hats` has one column per trajectory,
# `T_stars` holds each trajectory's final time (excluded from differentiation), and
# `st_out` repackages the two sub-states exactly as they were read from `st`.
function (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)
    st_vf = st.vector_field
    st_hu = st.halting_unit
    D = n.dim
    B = size(x, 2)
    use_gpu = n.use_gpu

    # Out-of-place right-hand side. The states returned by `Lux.apply` are discarded,
    # so the sub-layer states never change during integration.
    function dudt(u, θ, t)
        x_state = u[1:D]
        dx, _ = Lux.apply(n.vector_field, x_state, θ.vector_field, st_vf)
        h, _ = Lux.apply(n.halting_unit, x_state, θ.halting_unit, st_hu)
        return vcat(dx, h, h .* x_state)
    end

    # Zero block for the D + 1 accumulator rows; built outside the AD graph.
    pad_matrix = @ignore_derivatives fill!(similar(x, D + 1, B), 0)
    u0_batch = vcat(x, pad_matrix)

    ff = ODEFunction{false}(dudt)
    # Template problem built from the first column; `prob_func` swaps in the
    # column that belongs to each trajectory.
    base_prob = ODEProblem{false}(ff, u0_batch[:, 1], n.tspan, θ)

    function prob_func(prob, ctx)
        u0_i = u0_batch[:, ctx.sim_id]
        remake(prob; u0 = u0_i)
    end

    # Zero exactly when the accumulated halting output u[D + 1] equals 1 - eps.
    function condition(u, t, integrator)
        return (one(eltype(u)) - n.eps) - sum(u[(D + 1):(D + 1)])
    end

    # Stop the trajectory at the event; save the state after the event only.
    cb = ContinuousCallback(condition, terminate!; save_positions = (false, true))

    ensemble_prob = EnsembleProblem(
        base_prob;
        prob_func = prob_func,
        safetycopy = false
    )

    ensemblealg = use_gpu ? EnsembleGPUArray(CUDA.CUDABackend()) : EnsembleThreads()

    # `n.args` is splatted ahead of the ensemble algorithm (presumably the ODE
    # solver — verify against the constructor call); `n.kwargs` is forwarded last.
    ensemble_sol = solve(ensemble_prob, n.args..., ensemblealg;
        sensealg = GaussAdjoint(autojacvec = EnzymeVJP()),
        callback = cb,
        trajectories = B,
        n.kwargs...
    )

    # Extract x_hat and T_star from each trajectory's final state:
    # x_hat = x_bar + (1 - A) * x, evaluated on the last saved state.
    x_hats = hcat(map(ensemble_sol.u) do sol
        z = sol.u[end]
        x_state = z[1:D]
        A_val = z[(D + 1):(D + 1)]
        x_bar = z[(D + 2):(2D + 1)]
        x_bar .+ (one(eltype(z)) .- A_val) .* x_state
    end...)
    # Final times are collected without contributing to the gradient.
    T_stars = @ignore_derivatives [sol.t[end] for sol in ensemble_sol.u]

    return (x_hats, T_stars), (vector_field = st_vf, halting_unit = st_hu)
end

And I'm getting the following error:

ERROR: LoadError: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base ./tuple.jl:31
  [2] (::SciMLSensitivity.var"#df_iip#338"{Float32, Colon})(_out::Vector{Float32}, u::Vector{Float32}, p::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(vector_field = ViewAxis(1:8320, Axis(layer_1 = ViewAxis(1:4160, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))), layer_2 = ViewAxis(4161:8320, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))))), halting_unit = ViewAxis(8321:9377, Axis(layer_1 = ViewAxis(1:1040, Axis(weight = ViewAxis(1:1024, ShapedAxis((16, 64))), bias = ViewAxis(1025:1040, Shaped1DAxis((16,))))), layer_2 = ViewAxis(1041:1057, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))))}}}, t::Float32, i::Int64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/9RPKK/src/concrete_solve.jl:833
  [3] ReverseLossCallback
    @ ~/.julia/packages/SciMLSensitivity/9RPKK/src/adjoint_common.jl:744 [inlined]
...

This happens while running training with loss_fn(logits, y_batch) = CrossEntropyLoss(; logits = true)(logits, y_batch). The forward pass works perfectly; it is the backward pass that errors.

export TrainConfig, train!

"""
    TrainConfig(; epochs, lr, λ_ponder, log_every, use_gpu)

Settings read by `train!`.

# Keywords
- `epochs::Int = 100`: forwarded to the optimizer `solve` call as `epochs`.
- `lr::Float32 = 1.0f-3`: learning rate handed to `OptimizationOptimisers.Adam`.
- `λ_ponder::Float32 = 0.01f0`: weight applied to the mean halting time in the loss.
- `log_every::Int = 25`: a progress line is printed whenever the iteration count is a
  multiple of this value.
- `use_gpu::Bool = false`: selects `gpu_device()` instead of `cpu_device()` for batches.
"""
Base.@kwdef struct TrainConfig
    epochs::Int = 100
    lr::Float32 = 1.0f-3
    λ_ponder::Float32 = 0.01f0
    log_every::Int = 25
    use_gpu::Bool = false
end

"""
    train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)

Training loop for the Batched AIT-NODE.
"""
function train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)
    to_device = tcfg.use_gpu ? gpu_device() : cpu_device()

    # Objective evaluated on one minibatch. The scalar to minimise comes first;
    # the model state and the two loss terms follow so the callback can use them.
    function objective(params, minibatch)
        inputs, targets = to_device(minibatch)
        (predictions, halting_times), next_state = model(inputs, params, st)

        data_term = loss_fn(predictions, targets)
        ponder_term = tcfg.λ_ponder * mean(halting_times)

        return data_term + ponder_term, next_state, data_term, ponder_term
    end

    # Per-iteration hook: logs every `log_every` iterations and stores the latest
    # model state in the enclosing `st`. Returning `false` keeps the solve running.
    function on_step(opt_state, loss_value, next_state, data_term, ponder_term)
        if iszero(opt_state.iter % tcfg.log_every)
            @printf("Iter %5d | Total: %.4e | Task: %.4e | Ponder: %.4e\n",
                opt_state.iter, loss_value, data_term, ponder_term)
        end
        st = next_state

        return false
    end

    problem = OptimizationProblem(
        OptimizationFunction(objective, Optimization.AutoZygote()), ps, dataloader)

    println("Starting Training...")
    result = solve(problem, OptimizationOptimisers.Adam(tcfg.lr);
        callback = on_step, epochs = tcfg.epochs)

    return result.u, st
end

I am using

[compat]
Aqua = "0.8.16"
CUDA = "6.1.0"
ChainRulesCore = "1.26.1"
ComponentArrays = "0.15.39"
ConcreteStructs = "0.2.4"
DiffEqGPU = "3.15.0"
DifferentialEquations = "8.0.0"
Lux = "1.31.4"
Optimisers = "0.4.7"
Optimization = "5.6.1"
OptimizationOptimisers = "0.3.17"
Printf = "1.11.0"
Random = "1.11.0"
SciMLSensitivity = "7.111.0"
Statistics = "1.11.1"
Test = "1"
Zygote = "0.7.10"
julia = "~1.11"

Is my choice of sensitivity algorithm and VJP sound? Am I missing something in my code? I would like to use something that supports Callbacks, GPU and some form of parallelization with Ensembles.

The issue also occurs with QuadratureAdjoint.

Can you give a runnable code?

  1. You don’t define θ here, but my guess is it’s a ComponentArray?
  2. What are the solver kwargs?
  3. Do you get an error when it’s not an ensemble?

This is probably better as an issue in SciMLSensitivity.jl. My first guess is that either the non-ensemble version has an issue (1 or 2), or there’s some odd caching thing to handle for callbacks + ensembles which would just need a bugfix (because come to think about it, I don’t know if I tried that exact combination yet, and I can imagine an odd interaction in the caches showing up), but to do the bugfix I’d need code I could run.

Hey @ChrisRackauckas, thanks for lending a hand.

Here is a self contained code that produces the error. Before I was naively looping and it worked (very slowly), the issue arose from using Ensemble.

using LuxCUDA
using Lux
using Random
using Statistics
using Optimisers, Optimization, OptimizationOptimisers
using ComponentArrays
using OrdinaryDiffEqLowOrderRK
using MLDatasets, MLUtils, NNlib
using OneHotArrays
using Printf
using ConcreteStructs: @concrete
using ChainRulesCore: @ignore_derivatives
using DifferentialEquations
using DiffEqGPU
using SciMLSensitivity

# Common supertype for AIT-NODE layers. It registers `vector_field` and
# `halting_unit` as the sub-layer fields of the Lux container.
abstract type AITNODELayer <: AbstractLuxContainerLayer{(:vector_field, :halting_unit)} end

# Neural-ODE layer made of a vector-field network and a halting-unit network.
#
# Fields
#   vector_field - Lux layer evaluated on the latent state inside the ODE right-hand side
#   halting_unit - Lux layer evaluated on the latent state inside the ODE right-hand side
#   dim          - number of latent-state entries (`D` in the forward pass)
#   eps          - offset subtracted from one in the termination condition
#   use_gpu      - selects the ensemble algorithm in the forward pass
#   tspan        - time span handed to `ODEProblem`
#   args         - positional arguments splatted into `solve`
#   kwargs       - keyword arguments splatted into `solve`
@concrete struct AITNODE <: AITNODELayer
    vector_field <: AbstractLuxLayer
    halting_unit <: AbstractLuxLayer
    dim::Int
    eps::Float32
    use_gpu::Bool
    # Deliberately left without `::Any`: `@concrete` then gives each of these its
    # own type parameter, so the stored values keep their concrete types.
    tspan
    args
    kwargs
end

# Convenience constructor: any network that is not already a Lux layer is passed
# through `FromFluxAdaptor()`, and the trailing positional and keyword arguments
# are stored as-is in the `args` and `kwargs` fields.
function AITNODE(vf, hu, dim::Int, eps::Float32, use_gpu::Bool, tspan, args...; kwargs...)
    field_net = vf isa AbstractLuxLayer ? vf : FromFluxAdaptor()(vf)
    halting_net = hu isa AbstractLuxLayer ? hu : FromFluxAdaptor()(hu)
    return AITNODE(field_net, halting_net, dim, eps, use_gpu, tspan, args, kwargs)
end

# Forward pass of the AIT-NODE layer. Every column of `x` is integrated as its own
# trajectory of an `EnsembleProblem`, and a `ContinuousCallback` terminates each
# trajectory at a root of `condition`.
#
# Augmented state per trajectory (length 2D + 1, all accumulators start at zero):
#   u[1:D]              latent state, driven by `vector_field`
#   u[D + 1]            accumulated halting-unit output `h`
#   u[(D + 2):(2D + 1)] accumulated `h .* x_state`
#
# Arguments
#   x  - matrix with one trajectory per column.
#        NOTE(review): the slicing below assumes `size(x, 1) == n.dim` — confirm.
#   θ  - parameters exposing `θ.vector_field` and `θ.halting_unit`
#   st - state exposing `st.vector_field` and `st.halting_unit`
#
# Returns `((x_hats, T_stars), st_out)`: `x_hats` has one column per trajectory,
# `T_stars` holds each trajectory's final time (excluded from differentiation), and
# `st_out` repackages the two sub-states exactly as they were read from `st`.
function (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)
    st_vf = st.vector_field
    st_hu = st.halting_unit
    D = n.dim
    B = size(x, 2)
    use_gpu = n.use_gpu

    # Out-of-place right-hand side. The states returned by `Lux.apply` are discarded,
    # so the sub-layer states never change during integration.
    function dudt(u, θ, t)
        x_state = u[1:D]
        dx, _ = Lux.apply(n.vector_field, x_state, θ.vector_field, st_vf)
        h, _ = Lux.apply(n.halting_unit, x_state, θ.halting_unit, st_hu)
        return vcat(dx, h, h .* x_state)
    end

    # Zero block for the D + 1 accumulator rows; built outside the AD graph.
    pad_matrix = @ignore_derivatives fill!(similar(x, D + 1, B), 0)
    u0_batch = vcat(x, pad_matrix)

    ff = ODEFunction{false}(dudt)
    # Template problem built from the first column; `prob_func` swaps in the
    # column that belongs to each trajectory.
    base_prob = ODEProblem{false}(ff, u0_batch[:, 1], n.tspan, θ)

    function prob_func(prob, ctx)
        u0_i = u0_batch[:, ctx.sim_id]
        remake(prob; u0 = u0_i)
    end

    # Zero exactly when the accumulated halting output u[D + 1] equals 1 - eps.
    function condition(u, t, integrator)
        return (one(eltype(u)) - n.eps) - sum(u[(D + 1):(D + 1)])
    end

    # Stop the trajectory at the event; save the state after the event only.
    cb = ContinuousCallback(condition, terminate!; save_positions = (false, true))

    ensemble_prob = EnsembleProblem(
        base_prob;
        prob_func = prob_func,
        safetycopy = false
    )

    ensemblealg = use_gpu ? EnsembleGPUArray(CUDA.CUDABackend()) : EnsembleThreads()

    # `n.args` is splatted ahead of the ensemble algorithm (presumably the ODE
    # solver — verify against the constructor call); `n.kwargs` is forwarded last.
    ensemble_sol = solve(ensemble_prob, n.args..., ensemblealg;
        sensealg = GaussAdjoint(autojacvec = EnzymeVJP(), checkpointing = true),
        callback = cb,
        trajectories = B,
        n.kwargs...
    )

    # Extract x_hat and T_star from each trajectory's final state:
    # x_hat = x_bar + (1 - A) * x, evaluated on the last saved state.
    x_hats = hcat(map(ensemble_sol.u) do sol
        z = sol.u[end]
        x_state = z[1:D]
        A_val = z[(D + 1):(D + 1)]
        x_bar = z[(D + 2):(2D + 1)]
        x_bar .+ (one(eltype(z)) .- A_val) .* x_state
    end...)
    # Final times are collected without contributing to the gradient.
    T_stars = @ignore_derivatives [sol.t[end] for sol in ensemble_sol.u]

    return (x_hats, T_stars), (vector_field = st_vf, halting_unit = st_hu)
end

# Lux container that chains three sub-layers: `stem`, `ait` and `classifier`.
# The same three names are registered as the container's sub-layer fields.
struct ImageAITClassifier{S, A, C} <: AbstractLuxContainerLayer{(:stem, :ait, :classifier)}
    stem::S        # applied first, to the raw input
    ait::A         # applied to the stem output; returns a pair of outputs
    classifier::C  # applied to the first output of `ait`
end

# Run the three stages in sequence (stem, then AIT integration, then classifier)
# and return `((logits, T_stars), state)`, where `state` collects the state each
# stage handed back.
function (m::ImageAITClassifier)(x, p, st)
    # Stage 1: encode the input into a latent representation.
    latent, stem_state = m.stem(x, p.stem, st.stem)

    # Stage 2: integrate the latent representation; keep the halting times too.
    (halted, halt_times), ait_state = m.ait(latent, p.ait, st.ait)

    # Stage 3: map the integrated representation to logits.
    scores, head_state = m.classifier(halted, p.classifier, st.classifier)

    updated_state = (stem = stem_state, ait = ait_state, classifier = head_state)
    return (scores, halt_times), updated_state
end

"""
    TrainConfig(; epochs, lr, λ_ponder, log_every, use_gpu)

Settings read by `train!`.

# Keywords
- `epochs::Int = 100`: forwarded to the optimizer `solve` call as `epochs`.
- `lr::Float32 = 1.0f-3`: learning rate handed to `OptimizationOptimisers.Adam`.
- `λ_ponder::Float32 = 0.01f0`: weight applied to the mean halting time in the loss.
- `log_every::Int = 25`: a progress line is printed whenever the iteration count is a
  multiple of this value.
- `use_gpu::Bool = false`: selects `gpu_device()` instead of `cpu_device()` for batches.
"""
Base.@kwdef struct TrainConfig
    epochs::Int = 100
    lr::Float32 = 1.0f-3
    λ_ponder::Float32 = 0.01f0
    log_every::Int = 25
    use_gpu::Bool = false
end

"""
    train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)

Training loop for the Batched AIT-NODE.
"""
function train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)
    to_device = tcfg.use_gpu ? gpu_device() : cpu_device()

    # Objective evaluated on one minibatch. The scalar to minimise comes first;
    # the model state and the two loss terms follow so the callback can use them.
    function objective(params, minibatch)
        inputs, targets = to_device(minibatch)
        (predictions, halting_times), next_state = model(inputs, params, st)

        data_term = loss_fn(predictions, targets)
        ponder_term = tcfg.λ_ponder * mean(halting_times)

        return data_term + ponder_term, next_state, data_term, ponder_term
    end

    # Per-iteration hook: logs every `log_every` iterations and stores the latest
    # model state in the enclosing `st`. Returning `false` keeps the solve running.
    function on_step(opt_state, loss_value, next_state, data_term, ponder_term)
        if iszero(opt_state.iter % tcfg.log_every)
            @printf("Iter %5d | Total: %.4e | Task: %.4e | Ponder: %.4e\n",
                opt_state.iter, loss_value, data_term, ponder_term)
        end
        st = next_state

        return false
    end

    problem = OptimizationProblem(
        OptimizationFunction(objective, Optimization.AutoZygote()), ps, dataloader)

    println("Starting Training...")
    result = solve(problem, OptimizationOptimisers.Adam(tcfg.lr);
        callback = on_step, epochs = tcfg.epochs)

    return result.u, st
end

"""
    get_mnist_loaders(batchsize = 128)

Load MNIST and return `(train_loader, test_loader)`.

Images are converted to `Float32` and reshaped to `(28, 28, 1, N)`; labels are one-hot
encoded over the classes `0:9`. The training loader shuffles, the test loader does not.
"""
function get_mnist_loaders(batchsize = 128)
    println("Loading MNIST Dataset...")

    # `MNIST.traindata()` / `MNIST.testdata()` are deprecated in MLDatasets; the
    # dataset objects expose the same arrays as `features` and `targets`.
    train_set = MLDatasets.MNIST(; split = :train)
    test_set = MLDatasets.MNIST(; split = :test)

    # Reshape to standard Vision format: (W, H, C, Batch)
    X_train = reshape(Float32.(train_set.features), 28, 28, 1, :)
    Y_train = onehotbatch(train_set.targets, 0:9)

    X_test = reshape(Float32.(test_set.features), 28, 28, 1, :)
    Y_test = onehotbatch(test_set.targets, 0:9)

    train_loader = DataLoader((X_train, Y_train), batchsize = batchsize, shuffle = true)
    test_loader = DataLoader((X_test, Y_test), batchsize = batchsize, shuffle = false)

    return train_loader, test_loader
end

"""
    build_mnist_ait(; seed = 42, λ_ponder, use_gpu = true)

Assemble the MNIST classifier (conv stem, AIT-NODE block, dense head) and return
`(model, parameters, state)`, with parameters wrapped in a `ComponentArray` and both
parameters and state moved to the selected device.

`λ_ponder` is a required keyword but is not read anywhere in this function.
"""
function build_mnist_ait(; seed = 42, λ_ponder, use_gpu = true)
    latent_dim = 64
    to_device = use_gpu ? gpu_device() : cpu_device()

    # Conv Stem: 28x28x1 -> 14x14x16 -> 7x7x32 -> Flatten -> 64
    encoder = Chain(
        Conv((3, 3), 1 => 16, relu; pad = 1, stride = 2),
        Conv((3, 3), 16 => 32, relu; pad = 1, stride = 2),
        FlattenLayer(),
        Dense(7 * 7 * 32 => latent_dim, swish)
    )

    # AIT-NODE dynamics: the vector field keeps the latent width, while the halting
    # unit produces a single softplus-activated output.
    field_net = Chain(Dense(latent_dim => latent_dim, swish), Dense(latent_dim => latent_dim))
    halting_net = Chain(Dense(latent_dim => 16, swish), Dense(16 => 1, softplus))
    node_block = AITNODE(field_net, halting_net, latent_dim, Float32(1e-4), use_gpu,
        (0.0f0, 2.5f0), RK4(); dt = 0.01f0)

    # Classification head over the ten digit classes.
    head = Dense(latent_dim => 10)

    model = ImageAITClassifier(encoder, node_block, head)
    raw_params, raw_state = Lux.setup(Random.Xoshiro(seed), model)

    params = to_device(ComponentArray(raw_params))
    state = to_device(raw_state)

    return model, params, state
end

# Cross-entropy between raw logits and the target batch.
function loss_fn(logits, y_batch)
    criterion = CrossEntropyLoss(; logits = true)
    return criterion(logits, y_batch)
end

# Single-epoch CPU run: load MNIST, build the model, then train it. Returns
# whatever `train!` returns.
function run_experiment()
    on_gpu = false
    ponder_weight = 0.01f0

    # Only the training loader is used here; the test loader is dropped.
    train_loader = first(get_mnist_loaders(128))

    model, ps, st = build_mnist_ait(; seed = 42, λ_ponder = ponder_weight, use_gpu = on_gpu)

    train_config = TrainConfig(;
        epochs = 1,
        lr = 1.0f-3,
        λ_ponder = ponder_weight,
        log_every = 1,
        use_gpu = on_gpu
    )

    return train!(model, ps, st, train_config, train_loader, loss_fn)
end

# Execute the experiment when the script is loaded; `results` holds the value
# returned by `run_experiment` (the output of its final `train!` call).
results = run_experiment()
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqGPU = "071ae1c0-96b5-11e9-1965-c90190d839ea"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Okay yeah that’s a bug. Can you open an issue on SciMLSensitivity.jl? And I’ll try to fix over the weekend.

Done `GaussAdjoint` and `QuadratureAdjoint` fail on `EnsembleProblem` with Callbacks · Issue #1478 · SciML/SciMLSensitivity.jl · GitHub.

If you have an idea where the issue might be, maybe I can start looking into it.

TBH the callback caching code is a bit crazy so it might not be easy to track it down, but it would be in the callback tracking code.

Doing some debugging with good ol' Claude, I seem to have found a solution.

adding this overload

# Index an ensemble solution by delegating to its underlying `u` container.
__getindex(x::AbstractEnsembleSolution, i) = __getindex(x.u, i)

to

seems to fix it.

The issue arises when trying to iterate sol.u and read per-trajectory fields. Zygote's reverse pass produces a cotangent for the ensemble that is an EnsembleSolution wrapping a 3D VectorOfArray, and calling getindex on it yields a scalar value. Using Array(sol) bypasses this and works.

The cotangent Δ that arrives at ∇responsible_map_internal is:

Δ :: EnsembleSolution{Float32, 1, VectorOfArray{Float32, 3, Vector{DiffEqArray{...}}}}

Its structure:

Δ.u :: VectorOfArray{Float32, 3, Vector{DiffEqArray{...}}}
Δ.u.u :: Vector{DiffEqArray{...}}   # length = n_traj, one per trajectory
Δ.u.u[i] :: DiffEqArray{...}        # trajectory i's cotangent (state × time)

I am currently checking that it doesn't break anything; if it doesn't, I can start on the PR and also add some tests for it.
I would appreciate it if you could look into this.

Hey @ChrisRackauckas,
I’ve created a draft PR Fix(ad): accesing per-trajectory fields · Pull Request #1380 · SciML/SciMLBase.jl, but I think it is related to your PR fix(ad): EnsembleSolution cotangent shape + redundant remake reconstructions · Pull Request #1347 · SciML/SciMLBase.jl. I would like to know if we are both solving the same issue before moving it to review.

I think we are looking at the same thing.

# Index an ensemble solution by delegating to its underlying `u` container.
__getindex(x::AbstractEnsembleSolution, i) = __getindex(x.u, i)

AbstractEnsembleSolution isa AbstractVectorOfArray though?

Yes, but the cotangent is a nested AbstractVectorOfArray, and without the fix it indexes at the wrong level.

I think it would be better to fix the cotangent shape issue than to use this hack, though. But currently I'm getting around it with an @eval of __getindex, given that I need to finish something soon.

Solved by Fix Zygote ensemble adjoints (RAT v4 cotangent iteration + mutable-problem over-counting) and remove __getindex - Pull Request #1384 - SciML/SciMLBase.jl - GitHub

Thanks Chris!