Hello there,
I am rather new to Julia, and I am trying to implement a Neural ODE model that relies on event callbacks to terminate integration. I am using `GaussAdjoint`, since it supports callbacks and is recommended in the docs.
This is my current forward pass, which I am running with `use_gpu = false`:
# Abstract supertype for AIT-NODE layers. It subtypes `AbstractLuxContainerLayer`
# parameterized by the field names `:vector_field` and `:halting_unit`.
# NOTE(review): presumably this is what makes Lux collect parameters and states from
# those two fields of a concrete subtype — confirm against the Lux interface docs.
abstract type AITNODELayer <: AbstractLuxContainerLayer{(:vector_field, :halting_unit)} end
# AIT-NODE layer: bundles the two sub-networks with the ODE solve configuration.
#
# Fields
# - `vector_field`: Lux layer applied to the first `dim` state components in the ODE
#   right-hand side.
# - `halting_unit`: Lux layer whose output is integrated as the halting accumulator.
# - `dim`: number of feature components `D` in the ODE state.
# - `eps`: tolerance; integration is terminated when the accumulator reaches `1 - eps`.
# - `use_gpu`: selects `EnsembleGPUArray` instead of `EnsembleThreads` in the forward pass.
# - `tspan`: time span handed to `ODEProblem`.
# - `args`: positional arguments splatted into `solve` (e.g. the ODE solver).
# - `kwargs`: keyword arguments splatted into `solve`.
@concrete struct AITNODE <: AITNODELayer
    vector_field <: AbstractLuxLayer
    halting_unit <: AbstractLuxLayer
    dim::Int
    eps::Float32
    use_gpu::Bool
    # Deliberately left un-annotated: `@concrete` turns un-annotated fields into type
    # parameters, so these are stored with their concrete types instead of as `Any`
    # (an `::Any` field makes every access in the forward pass type-unstable).
    tspan
    args
    kwargs
end
# Forward pass of the AIT-NODE layer.
#
# Each column of `x` is extended to an augmented state `[x; A; x̄]` of length `2D + 1`
# (`D = n.dim`), where `A` integrates the halting-unit output `h` and `x̄` integrates
# `h .* x`. One trajectory per column is solved through an `EnsembleProblem`, and a
# `ContinuousCallback` calls `terminate!` once `A` reaches `1 - n.eps`.
#
# Arguments
# - `x`: input matrix with one sample per column. The indexing below assumes
#   `size(x, 1) == n.dim`; this is not checked here.
# - `θ`: parameters with `vector_field` and `halting_unit` properties.
# - `st`: states with `vector_field` and `halting_unit` properties.
#
# Returns `((x_hats, T_stars), st_out)`:
# - `x_hats`: one column per trajectory, `x̄ + (1 - A) .* x` at the last saved state.
# - `T_stars`: last saved time of each trajectory. It is computed under
#   `@ignore_derivatives`, so no gradient flows through it.
# - `st_out`: the incoming sub-states, repacked unchanged.
function (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)
    D = n.dim
    B = size(x, 2)
    vf_state = st.vector_field
    hu_state = st.halting_unit

    # Out-of-place right-hand side for the augmented state [x; A; x̄].
    function rhs(u, p, t)
        feat = u[1:D]
        dfeat, _ = Lux.apply(n.vector_field, feat, p.vector_field, vf_state)
        halt, _ = Lux.apply(n.halting_unit, feat, p.halting_unit, hu_state)
        return vcat(dfeat, halt, halt .* feat)
    end

    # Zero when the accumulator u[D + 1] equals 1 - eps. The one-element range plus
    # `sum` is kept as written rather than replaced by scalar indexing.
    function condition(u, t, integrator)
        return (one(eltype(u)) - n.eps) - sum(u[(D + 1):(D + 1)])
    end
    halt_cb = ContinuousCallback(condition, terminate!; save_positions = (false, true))

    # Initial augmented states: x on top, D + 1 zero rows (A and x̄) below. The zero
    # block is a constant, hence excluded from differentiation.
    zero_rows = @ignore_derivatives fill!(similar(x, D + 1, B), 0)
    u0_batch = vcat(x, zero_rows)

    # Template problem built from the first column; every trajectory swaps in its own
    # column of `u0_batch` via `prob_func`.
    function prob_func(prob, ctx)
        return remake(prob; u0 = u0_batch[:, ctx.sim_id])
    end
    ode_fn = ODEFunction{false}(rhs)
    base_prob = ODEProblem{false}(ode_fn, u0_batch[:, 1], n.tspan, θ)
    ensemble_prob = EnsembleProblem(base_prob; prob_func = prob_func, safetycopy = false)

    ensemblealg = if n.use_gpu
        EnsembleGPUArray(CUDA.CUDABackend())
    else
        EnsembleThreads()
    end

    # `n.kwargs` is splatted last, so it can override the keywords set explicitly here.
    ensemble_sol = solve(
        ensemble_prob, n.args..., ensemblealg;
        sensealg = GaussAdjoint(autojacvec = EnzymeVJP()),
        callback = halt_cb,
        trajectories = B,
        n.kwargs...
    )

    # Readout from a trajectory's final saved state: x̄ + (1 - A) .* x.
    function readout(sol)
        z = sol.u[end]
        feat = z[1:D]
        acc = z[(D + 1):(D + 1)]
        weighted = z[(D + 2):(2D + 1)]
        return weighted .+ (one(eltype(z)) .- acc) .* feat
    end
    x_hats = hcat(map(readout, ensemble_sol.u)...)

    T_stars = @ignore_derivatives [traj.t[end] for traj in ensemble_sol.u]

    return (x_hats, T_stars), (vector_field = vf_state, halting_unit = hu_state)
end
And I'm getting the following error:
ERROR: LoadError: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:31
[2] (::SciMLSensitivity.var"#df_iip#338"{Float32, Colon})(_out::Vector{Float32}, u::Vector{Float32}, p::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(vector_field = ViewAxis(1:8320, Axis(layer_1 = ViewAxis(1:4160, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))), layer_2 = ViewAxis(4161:8320, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))))), halting_unit = ViewAxis(8321:9377, Axis(layer_1 = ViewAxis(1:1040, Axis(weight = ViewAxis(1:1024, ShapedAxis((16, 64))), bias = ViewAxis(1025:1040, Shaped1DAxis((16,))))), layer_2 = ViewAxis(1041:1057, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))))}}}, t::Float32, i::Int64)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/9RPKK/src/concrete_solve.jl:833
[3] ReverseLossCallback
@ ~/.julia/packages/SciMLSensitivity/9RPKK/src/adjoint_common.jl:744 [inlined]
...
This happens while running training with `loss_fn(logits, y_batch) = CrossEntropyLoss(; logits = true)(logits, y_batch)`. The forward pass works perfectly; it is the backward pass that errors.
export TrainConfig, train!
"""
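    TrainConfig(; epochs, lr, λ_ponder, log_every, use_gpu)

Keyword-constructible training settings: number of epochs, Adam learning rate,
ponder-penalty weight `λ_ponder`, logging interval in iterations, and whether batches
are moved to the GPU.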
"""
@kwdef struct TrainConfig
    # Passed to `solve` as the `epochs` keyword in `train!`.
    epochs::Int = 100
    # Learning rate handed to `OptimizationOptimisers.Adam`.
    lr::Float32 = 1.0f-3
    # Weight multiplying the mean halting time in the total loss.
    λ_ponder::Float32 = 0.01f0
    # Losses are printed on iterations divisible by this value.
    log_every::Int = 25
    # Selects `gpu_device()` over `cpu_device()` for each batch.
    use_gpu::Bool = false
end
"""
train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)
Training loop for the Batched AIT-NODE.
"""
function train!(model, ps, st, tcfg::TrainConfig, dataloader, loss_fn)
    # Arguments
    # - `model`: callable as `model(x, θ, st)` returning `((y_hats, t_halts), st_new)`.
    # - `ps`: initial parameters, used as the optimization variable.
    # - `st`: initial model state; replaced by the callback after each step.
    # - `tcfg`: training settings (epochs, learning rate, ponder weight, logging, device).
    # - `dataloader`: handed to `OptimizationProblem` as its data argument; each item is
    #   destructured as `(x_batch, y_batch)`.
    # - `loss_fn`: called as `loss_fn(y_hats, y_batch)`.
    #
    # Returns `(res.u, st)`: the optimizer result's parameters and the latest state.
    to_device = if tcfg.use_gpu
        gpu_device()
    else
        cpu_device()
    end

    # Objective: task loss plus `λ_ponder` times the mean halting time. The state and
    # the two sub-losses are returned after the total loss so the callback can log them.
    # NOTE(review): confirm that the installed Optimization.jl version still accepts
    # extra objective returns and forwards them to the callback.
    function objective(θ, batch)
        inputs, targets = to_device(batch)
        (preds, halt_times), next_state = model(inputs, θ, st)
        task = loss_fn(preds, targets)
        ponder = tcfg.λ_ponder * mean(halt_times)
        return task + ponder, next_state, task, ponder
    end

    # Logs every `log_every` iterations and stores the new state. Assigning to `st`
    # rebinds the enclosing function's variable, so later objective calls and the final
    # return value see the updated state. Always returns `false` (never stops early).
    function on_step(state, total_loss, st_new, task_loss, ponder_loss)
        should_log = state.iter % tcfg.log_every == 0
        should_log && @printf("Iter %5d | Total: %.4e | Task: %.4e | Ponder: %.4e\n",
            state.iter, total_loss, task_loss, ponder_loss)
        st = st_new
        return false
    end

    objective_fn = OptimizationFunction(objective, Optimization.AutoZygote())
    problem = OptimizationProblem(objective_fn, ps, dataloader)

    println("Starting Training...")
    result = solve(
        problem, OptimizationOptimisers.Adam(tcfg.lr);
        callback = on_step, epochs = tcfg.epochs
    )
    return result.u, st
end
I am using the following package versions (the `[compat]` section of my `Project.toml`):
[compat]
Aqua = "0.8.16"
CUDA = "6.1.0"
ChainRulesCore = "1.26.1"
ComponentArrays = "0.15.39"
ConcreteStructs = "0.2.4"
DiffEqGPU = "3.15.0"
DifferentialEquations = "8.0.0"
Lux = "1.31.4"
Optimisers = "0.4.7"
Optimization = "5.6.1"
OptimizationOptimisers = "0.3.17"
Printf = "1.11.0"
Random = "1.11.0"
SciMLSensitivity = "7.111.0"
Statistics = "1.11.1"
Test = "1"
Zygote = "0.7.10"
julia = "~1.11"
Are my sensitivity algorithm and VJP choices sound? Am I missing something in my code? I would like to use something that supports callbacks, GPU, and some form of parallelization with ensembles.