Hello there,
I am rather new to Julia, and I am trying to implement a Neural ODE model that relies on event callbacks to terminate integration. I am using `GaussAdjoint`, since it supports callbacks and is recommended in the docs.
This is my current forward pass, which I am running with `use_gpu = false`:
# Abstract supertype of `AITNODE`. It subtypes Lux's container-layer type,
# parameterized by the field names `:vector_field` and `:halting_unit`.
abstract type AITNODELayer <: AbstractLuxContainerLayer{(:vector_field, :halting_unit)} end
# Neural ODE layer whose integration is terminated by an event callback.
#
# Fields
#   vector_field : Lux layer applied to the state inside the ODE right-hand side.
#   halting_unit : Lux layer applied to the state; its output is integrated
#                  alongside the state and drives the termination condition.
#   dim          : number of state components `D` (the augmented ODE state has
#                  `2D + 1` entries).
#   eps          : margin used in the termination condition `1 - eps`.
#   use_gpu      : selects `EnsembleGPUArray` instead of `EnsembleThreads`.
#   tspan        : time span passed to `ODEProblem`.
#   args         : extra positional arguments splatted into `solve`
#                  (presumably the ODE solver algorithm; verify against caller).
#   kwargs       : extra keyword arguments splatted into `solve`.
@concrete struct AITNODE <: AITNODELayer
    vector_field <: AbstractLuxLayer
    halting_unit <: AbstractLuxLayer
    dim::Int
    eps::Float32
    use_gpu::Bool
    # Deliberately left unannotated: `@concrete` turns the types of unannotated
    # fields into struct type parameters, so these are stored with their concrete
    # types instead of as abstractly typed `Any` fields.
    tspan
    args
    kwargs
end
# Forward pass: integrate every column of `x` as its own trajectory of an
# `EnsembleProblem` and stop each trajectory through a continuous callback.
#
# The per-trajectory ODE state is laid out as
#     [ x_state (D entries) ; A (1 entry) ; x_bar (D entries) ]
# where `A` integrates the halting-unit output `h` and `x_bar` integrates
# `h .* x_state`. Integration terminates when `A` reaches `1 - n.eps`.
#
# Arguments
#   x  : matrix with one initial state per column. It is assumed to have
#        `n.dim` rows, since the right-hand side returns `2 * n.dim + 1` entries.
#   θ  : parameters exposing `vector_field` and `halting_unit`.
#   st : Lux states exposing `vector_field` and `halting_unit`.
#
# Returns `((x_hats, T_stars), st_out)`:
#   x_hats  : `D × B` matrix, one column per trajectory, computed from the final
#             state as `x_bar + (1 - A) * x_state`.
#   T_stars : final time of every trajectory (excluded from differentiation).
#   st_out  : the incoming sub-layer states, passed through unchanged (the
#             states returned by `Lux.apply` inside the ODE are discarded).
function (n::AITNODE)(x::AbstractMatrix{<:Number}, θ, st)
    st_vf = st.vector_field
    st_hu = st.halting_unit
    D = n.dim
    B = size(x, 2)

    # Augmented right-hand side. The parameter argument is called `p` so that it
    # does not shadow the outer `θ`.
    function dudt(u, p, t)
        x_state = u[1:D]
        dx, _ = Lux.apply(n.vector_field, x_state, p.vector_field, st_vf)
        h, _ = Lux.apply(n.halting_unit, x_state, p.halting_unit, st_hu)
        return vcat(dx, h, h .* x_state)
    end

    # `A` and `x_bar` start at zero; the padding is a constant, so it is kept out
    # of the differentiated graph.
    pad_matrix = @ignore_derivatives fill!(similar(x, D + 1, B), 0)
    u0_batch = vcat(x, pad_matrix)

    base_prob = ODEProblem{false}(ODEFunction{false}(dudt), u0_batch[:, 1], n.tspan, θ)

    # Trajectory `ctx.sim_id` starts from the matching column of the batch.
    prob_func(prob, ctx) = remake(prob; u0 = u0_batch[:, ctx.sim_id])

    # Root of this function marks the event `A == 1 - eps`. `A` is read as a
    # scalar; slicing one element and summing it would allocate on every
    # evaluation of the condition.
    condition(u, t, integrator) = (one(eltype(u)) - n.eps) - u[D + 1]
    cb = ContinuousCallback(condition, terminate!; save_positions = (false, true))

    ensemble_prob = EnsembleProblem(base_prob; prob_func = prob_func, safetycopy = false)
    ensemblealg = n.use_gpu ? EnsembleGPUArray(CUDA.CUDABackend()) : EnsembleThreads()

    # `n.kwargs` is splatted last, so entries in it take precedence over the
    # `sensealg`, `callback` and `trajectories` set here.
    ensemble_sol = solve(ensemble_prob, n.args..., ensemblealg;
        sensealg = GaussAdjoint(autojacvec = EnzymeVJP()),
        callback = cb,
        trajectories = B,
        n.kwargs...
    )

    # Extract x_hat from each trajectory's final state.
    # NOTE(review): the cotangent reaches the adjoint through `sol.u[end]` of each
    # trajectory; confirm that this access pattern is supported for ensemble
    # solutions by the installed SciMLSensitivity version.
    final_estimates = map(ensemble_sol.u) do sol
        z = sol.u[end]
        x_state = z[1:D]
        A_val = z[(D + 1):(D + 1)]
        x_bar = z[(D + 2):(2D + 1)]
        x_bar .+ (one(eltype(z)) .- A_val) .* x_state
    end
    # `reduce(hcat, ...)` avoids splatting a batch-sized collection into `hcat`.
    x_hats = reduce(hcat, final_estimates)

    T_stars = @ignore_derivatives [sol.t[end] for sol in ensemble_sol.u]
    return (x_hats, T_stars), (vector_field = st_vf, halting_unit = st_hu)
end
And I'm getting the following error:
ERROR: LoadError: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:31
[2] (::SciMLSensitivity.var"#df_iip#338"{Float32, Colon})(_out::Vector{Float32}, u::Vector{Float32}, p::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(vector_field = ViewAxis(1:8320, Axis(layer_1 = ViewAxis(1:4160, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))), layer_2 = ViewAxis(4161:8320, Axis(weight = ViewAxis(1:4096, ShapedAxis((64, 64))), bias = ViewAxis(4097:4160, Shaped1DAxis((64,))))))), halting_unit = ViewAxis(8321:9377, Axis(layer_1 = ViewAxis(1:1040, Axis(weight = ViewAxis(1:1024, ShapedAxis((16, 64))), bias = ViewAxis(1025:1040, Shaped1DAxis((16,))))), layer_2 = ViewAxis(1041:1057, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))))}}}, t::Float32, i::Int64)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/9RPKK/src/concrete_solve.jl:833
[3] ReverseLossCallback
@ ~/.julia/packages/SciMLSensitivity/9RPKK/src/adjoint_common.jl:744 [inlined]
...
I am using the following package versions:
[compat]
Aqua = "0.8.16"
CUDA = "6.1.0"
ChainRulesCore = "1.26.1"
ComponentArrays = "0.15.39"
ConcreteStructs = "0.2.4"
DiffEqGPU = "3.15.0"
DifferentialEquations = "8.0.0"
Lux = "1.31.4"
Optimisers = "0.4.7"
Optimization = "5.6.1"
OptimizationOptimisers = "0.3.17"
Printf = "1.11.0"
Random = "1.11.0"
SciMLSensitivity = "7.111.0"
Statistics = "1.11.1"
Test = "1"
Zygote = "0.7.10"
julia = "~1.11"
Are my sensitivity algorithm and VJP choice sound? Am I missing something in my code? I would like to use something that supports callbacks, GPUs, and some form of parallelization with ensembles.