Errors on 2nd Iteration of EnsembleGPUArray when using Lux

I’ve been trying to train some Lux UDEs using EnsembleGPUArray. I’m mostly following the SciML Automatic Model Discovery workflow to approximate an unmodeled disturbance in ODEs from measured data. Setting up the UDEs with Lux has worked fine, and I’ve had some success speeding things up with EnsembleThreads. Since I’ve set up training as solving many short problems with few states/inputs, trying EnsembleGPUArray seemed worthwhile, but I kept getting errors. Even after peeling away everything related to training the NN, the EnsembleGPUArray problem still errors, and the error doesn’t appear until it tries the second iteration.
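
For context, this is roughly the EnsembleThreads pattern that has been working for me (a minimal sketch with placeholder dynamics and initial conditions, not my actual training code):

using OrdinaryDiffEq

# toy dynamics standing in for the UDE; u0s holds the per-trajectory initial conditions
u0s = [rand(Float32, 2) for _ in 1:64]
prob = ODEProblem((u, p, t) -> -p .* u, u0s[1], (0.0f0, 1.0f0), Float32[1.0, 2.0])

prob_func(prob, i, repeat) = remake(prob; u0 = u0s[i])
ensemble_prob = EnsembleProblem(prob; prob_func = prob_func)

# many short problems solved in parallel on CPU threads
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(); trajectories = length(u0s), saveat = 0.1f0)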

I’ve run through a number of examples that use Lux and others that use GPU execution, and they all work as expected.

I’ve got two MWEs derived from existing examples. The first MWE is based on a DiffEqGPU example that also uses an EnsembleProblem, so I assume the issue is Flux vs. Lux? The second MWE is based on an example from DiffEqFlux, where the problem seems to be running on the GPU correctly until I try it as an EnsembleProblem.

I’m new to using CUDA, so I wouldn’t be surprised if I’m doing something wrong there. But I don’t know what direction to take the debugging, since single problems on the GPU work correctly; the error is only thrown once the second iteration of the EnsembleProblem occurs.

MWE1

# based on https://docs.sciml.ai/DiffEqGPU/stable/examples/ad/
using OrdinaryDiffEq, SciMLSensitivity, DiffEqGPU, CUDA, Test
using Lux, LuxCUDA, StableRNGs
CUDA.allowscalar(false)

u0 = Float32[0.0]
pa = rand(Float32, 2)

U = Lux.Chain(Lux.Dense(1, 5), Lux.Dense(5, 1))
rng = StableRNG(1234)
p_nn, st = Lux.setup(rng, U)

function modelf(du, u, p, t)
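    # NN forward pass; U, p_nn, and st are captured from the surrounding (global) scope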
    du[1] = 1.01 * u[1] * p[1] * p[2] + U(u, p_nn, st)[1][1]
end
# Check - not running on GPU.
p = ODEProblem(modelf, u0, (0.0, 1.0), pa)
s = solve(p)

function model(nun_iter)
    prob = ODEProblem(modelf, u0, (0.0, 1.0), pa)

    function prob_func(prob, i, repeat)
        println("Prob_Func was called!")
        remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    solve(ensemble_prob, Tsit5(), EnsembleGPUArray(CUDA.CUDABackend()), saveat = 0.1,
        trajectories = nun_iter)
end

# loss function
model(1) # Works fine
model(2) # Throws error on second iteration

Produces the error:

ERROR: LoadError: GPUCompiler.InvalidIRError(GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(MethodInstance for DiffEqGPU.gpu_gpu_kernel(::KernelAbstractions.CompilerMetadata{…}, ::typeof(modelf), ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::Float64), GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(GPUCompiler.PTXCompilerTarget(v"7.5.0", v"7.5.0", true, nothing, nothing, nothing, nothing, false, nothing, nothing), CUDA.CUDACompilerParams(v"7.5.0", v"8.3.0"), true, nothing, :specfunc, false, 2), 0x0000000000007c5b), Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}[("dynamic function invocation", [modelf at mwe_fin1.jl:14, macro expansion at kernels.jl:43, gpu_gpu_kernel at macros.jl:95, gpu_gpu_kernel at none:0], nothing), ("dynamic function invocation", [modelf at mwe_fin1.jl:14, macro expansion at kernels.jl:43, gpu_gpu_kernel at macros.jl:95, gpu_gpu_kernel at none:0], getindex), ("dynamic function invocation", [modelf at mwe_fin1.jl:14, macro expansion at kernels.jl:43, gpu_gpu_kernel at macros.jl:95, gpu_gpu_kernel at none:0], +), ("dynamic function invocation", [setindex! at array.jl:166, setindex! at subarray.jl:355, modelf at mwe_fin1.jl:14, macro expansion at kernels.jl:43, gpu_gpu_kernel at macros.jl:95, gpu_gpu_kernel at none:0], convert)])
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\validation.jl:147
  [2] macro expansion
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:460 [inlined]
  [3] macro expansion
    @ C:\Users\coopert\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:459 [inlined]
  [5]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\utils.jl:103
  [6] emit_llvm
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\utils.jl:97 [inlined]
  [7]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:136
  [8]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:111
  [9] compile
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:103 [inlined]
 [10] #1145
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\compilation.jl:254 [inlined]
 [11] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:52
 [12] JuliaContext(f::Function)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:42
 [13] compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\compilation.jl:253
 [14] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\execution.jl:128
 [15] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\execution.jl:103
 [16] macro expansion
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:369 [inlined]
 [17] macro expansion
    @ .\lock.jl:267 [inlined]
 [18] cufunction(f::typeof(DiffEqGPU.gpu_gpu_kernel), tt::Type{…}; kwargs::@Kwargs{…})
    @ CUDA C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:364
 [19] macro expansion
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:112 [inlined]
 [20] (::KernelAbstractions.Kernel{…})(::Function, ::Vararg{…}; ndrange::Int64, workgroupsize::Int64)
    @ CUDA.CUDAKernels C:\Users\coopert\.julia\packages\CUDA\75aiI\src\CUDAKernels.jl:103
 [21] Kernel
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\CUDAKernels.jl:89 [inlined]
 [22] #12
    @ C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\ensemblegpuarray\problem_generation.jl:10 [inlined]
 [23] ODEFunction
    @ C:\Users\coopert\.julia\packages\SciMLBase\sakPO\src\scimlfunctions.jl:2296 [inlined]
 [24] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
    @ OrdinaryDiffEq C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\perform_step\low_order_rk_perform_step.jl:799
 [25] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Float64, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqGPU.diffeqgpunorm), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::DiffEqGPU.var"#114#120", verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:524
 [26] __init (repeats 5 times)
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:11 [inlined]
 [27] #__solve#670
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:6 [inlined]
 [28] __solve
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:1 [inlined]
 [29] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:612
 [30] solve_call
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:569 [inlined]
 [31] #solve_up#53
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1080 [inlined]
 [32] solve_up
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1066 [inlined]
 [33] #solve#51
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1003 [inlined]
 [34] batch_solve_up(ensembleprob::EnsembleProblem{…}, probs::Vector{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}, I::UnitRange{…}, u0::Matrix{…}, p::Matrix{…}; kwargs::@Kwargs{…})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:315
 [35] batch_solve(ensembleprob::EnsembleProblem{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}, I::UnitRange{…}, adaptive::Bool; kwargs::@Kwargs{…})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:242
 [36] macro expansion
    @ .\timing.jl:395 [inlined]
 [37] __solve(ensembleprob::EnsembleProblem{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}; trajectories::Int64, batch_size::Int64, unstable_check::Function, adaptive::Bool, kwargs::@Kwargs{…})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:55
 [38] __solve
    @ C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:1 [inlined]
 [39] #solve#55
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1096 [inlined]
 [40] model(nun_iter::Int64)
    @ Main D:\temp\mwes\mwe_fin1.jl:29
 [41] top-level scope
    @ D:\temp\mwes\mwe_fin1.jl:35
in expression starting at D:\temp\mwes\mwe_fin1.jl:35
Some type information was truncated. Use `show(err)` to see complete types.

MWE2

# based on https://docs.sciml.ai/DiffEqFlux/dev/examples/GPUs/
using OrdinaryDiffEq, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays, Random
using DiffEqGPU, CUDA  # CUDA loaded explicitly for CUDA.CUDABackend() below
rng = Xoshiro(0)

const cdev = cpu_device()
const gdev = gpu_device()

model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray |> gdev
st = st |> gdev
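# the RHS closes over the global model and st defined above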
dudt(u, p, t) = model(u, p, st)[1]

tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0

u0 = Float32[2.0; 0.0] |> gdev
prob_gpu = ODEProblem(dudt, u0, tspan, ps)
# Check - single solution works. Should run on GPU. 
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)

function prob_func(prob, i, repeat)
    remake(prob)
end

ensemble_prob = EnsembleProblem(prob_gpu, 
        prob_func = prob_func, 
        safetycopy = false)
sim = solve(ensemble_prob, Tsit5(), 
        EnsembleGPUArray(CUDA.CUDABackend()), 
        trajectories = 1) # Works fine
sim = solve(ensemble_prob, Tsit5(), 
        EnsembleGPUArray(CUDA.CUDABackend()), 
        trajectories = 2) # Throws error on 2nd iteration

Error

ERROR: LoadError: GPUCompiler.InvalidIRError(GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(MethodInstance for DiffEqGPU.gpu_gpu_kernel_oop(::KernelAbstractions.CompilerMetadata{…}, ::typeof(dudt), ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::Float32), GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(GPUCompiler.PTXCompilerTarget(v"7.5.0", v"7.5.0", true, nothing, nothing, nothing, nothing, false, nothing, nothing), CUDA.CUDACompilerParams(v"7.5.0", v"8.3.0"), true, nothing, :specfunc, false, 2), 0x0000000000007c65), Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}[("dynamic function invocation", [macro expansion at kernels.jl:57, gpu_gpu_kernel_oop at macros.jl:95, gpu_gpu_kernel_oop at none:0], getindex), ("dynamic function invocation", [setindex! at array.jl:166, setindex! at array.jl:178, macro expansion at kernels.jl:57, gpu_gpu_kernel_oop at macros.jl:95, gpu_gpu_kernel_oop at none:0], convert), ("dynamic function invocation", [dudt at mwe_fin2.jl:13, macro expansion at kernels.jl:52, gpu_gpu_kernel_oop at macros.jl:95, gpu_gpu_kernel_oop at none:0], nothing), ("dynamic function invocation", [dudt at mwe_fin2.jl:13, macro expansion at kernels.jl:52, gpu_gpu_kernel_oop at macros.jl:95, gpu_gpu_kernel_oop at none:0], getindex)])
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\validation.jl:147
  [2] macro expansion
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:460 [inlined]
  [3] macro expansion
    @ C:\Users\coopert\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:459 [inlined]
  [5]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\utils.jl:103
  [6] emit_llvm
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\utils.jl:97 [inlined]
  [7]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:136
  [8]
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:111
  [9] compile
    @ C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:103 [inlined]
 [10] #1145
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\compilation.jl:254 [inlined]
 [11] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:52
 [12] JuliaContext(f::Function)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\driver.jl:42
 [13] compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\compilation.jl:253
 [14] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\execution.jl:128
 [15] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler C:\Users\coopert\.julia\packages\GPUCompiler\nWT2N\src\execution.jl:103
 [16] macro expansion
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:369 [inlined]
 [17] macro expansion
    @ .\lock.jl:267 [inlined]
 [18] cufunction(f::typeof(DiffEqGPU.gpu_gpu_kernel_oop), tt::Type{…}; kwargs::@Kwargs{…})
    @ CUDA C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:364
 [19] macro expansion
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\compiler\execution.jl:112 [inlined]
 [20] (::KernelAbstractions.Kernel{…})(::Function, ::Vararg{…}; ndrange::Int64, workgroupsize::Int64)
    @ CUDA.CUDAKernels C:\Users\coopert\.julia\packages\CUDA\75aiI\src\CUDAKernels.jl:103
 [21] Kernel
    @ C:\Users\coopert\.julia\packages\CUDA\75aiI\src\CUDAKernels.jl:89 [inlined]
 [22] #12
    @ C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\ensemblegpuarray\problem_generation.jl:10 [inlined]
 [23] ODEFunction
    @ C:\Users\coopert\.julia\packages\SciMLBase\sakPO\src\scimlfunctions.jl:2296 [inlined]
 [24] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
    @ OrdinaryDiffEq C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\perform_step\low_order_rk_perform_step.jl:799
 [25] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Float32, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqGPU.diffeqgpunorm), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::DiffEqGPU.var"#114#120", verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:524
 [26] __init (repeats 5 times)
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:11 [inlined]
 [27] #__solve#670
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:6 [inlined]
 [28] __solve
    @ C:\Users\coopert\.julia\packages\OrdinaryDiffEq\Knuk0\src\solve.jl:1 [inlined]
 [29] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:612
 [30] solve_call
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:569 [inlined]
 [31] #solve_up#53
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1080 [inlined]
 [32] solve_up
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1066 [inlined]
 [33] #solve#51
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1003 [inlined]
 [34] batch_solve_up(ensembleprob::EnsembleProblem{…}, probs::Vector{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}, I::UnitRange{…}, u0::Matrix{…}, p::Matrix{…}; kwargs::@Kwargs{…})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:315
 [35] batch_solve(ensembleprob::EnsembleProblem{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}, I::UnitRange{…}, adaptive::Bool; kwargs::@Kwargs{…})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:242
 [36] macro expansion
    @ .\timing.jl:395 [inlined]
 [37] __solve(ensembleprob::EnsembleProblem{…}, alg::Tsit5{…}, ensemblealg::EnsembleGPUArray{…}; trajectories::Int64, batch_size::Int64, unstable_check::Function, adaptive::Bool, kwargs::@Kwargs{})
    @ DiffEqGPU C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:55
 [38] __solve
    @ C:\Users\coopert\.julia\packages\DiffEqGPU\I999k\src\solve.jl:1 [inlined]
 [39] #solve#55
    @ C:\Users\coopert\.julia\packages\DiffEqBase\DS1sd\src\solve.jl:1096 [inlined]
 [40] top-level scope
    @ D:\temp\mwes\mwe_fin2.jl:33
in expression starting at D:\temp\mwes\mwe_fin2.jl:33
Some type information was truncated. Use `show(err)` to see complete types.

The way DiffEqGPU works, you won’t be able to compose it with Lux at the moment (@ChrisRackauckas can correct me here if I’m mistaken).

That said, Lux networks can be put inside GPU kernels (see the “Neural Networks Inside GPU Kernels” tutorial in the Lux.jl docs; use the unreleased #main version for this), but we might have to add dispatches in DiffEqGPU to handle these.
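
To illustrate the underlying requirement (not Lux’s actual API; see the linked tutorial for that): anything whose parameters are all isbits, e.g. StaticArrays, can be passed into and evaluated inside a KernelAbstractions kernel. A rough sketch with a hand-rolled TinyMLP standing in for the network:

using StaticArrays, KernelAbstractions

# hand-rolled two-layer MLP whose parameters are static arrays, so the whole
# struct is isbits and can be passed into a GPU kernel
struct TinyMLP{W1, B1, W2, B2}
    W1::W1
    b1::B1
    W2::W2
    b2::B2
end
(m::TinyMLP)(x) = m.W2 * tanh.(m.W1 * x .+ m.b1) .+ m.b2

@kernel function mlp_kernel!(out, @Const(xs), m)
    i = @index(Global, Linear)
    out[i] = m(xs[i])  # full forward pass runs inside the kernel
end

m = TinyMLP(@SMatrix(randn(Float32, 4, 2)), @SVector(zeros(Float32, 4)),
            @SMatrix(randn(Float32, 2, 4)), @SVector(zeros(Float32, 2)))
xs = [@SVector(randn(Float32, 2)) for _ in 1:16]  # on the GPU this would be a CuArray of SVectors
out = similar(xs)
backend = KernelAbstractions.get_backend(xs)
mlp_kernel!(backend)(out, xs, m; ndrange = length(xs))
KernelAbstractions.synchronize(backend)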

Indeed, EnsembleGPUArray is the wrong thing to use here. In theory you could get this to work, but it won’t be as fast as batching the normal way, so I’d just GPU-accelerate the neural network calls the normal way shown in the docs.
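
That is, roughly the single-problem path from MWE2, with the model, parameters, and state kept on the GPU and the ensemble handled by an ordinary (non-kernel) ensemble algorithm. A hedged sketch of that pattern, reusing MWE2’s model shape and using EnsembleSerial for simplicity:

using OrdinaryDiffEq, Lux, LuxCUDA, ComponentArrays, Random, CUDA

gdev = gpu_device()

model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(Xoshiro(0), model)
ps = ComponentArray(ps) |> gdev   # parameters live on the GPU
st = st |> gdev

# each RHS evaluation runs the NN on the GPU; the time stepping stays on the host
dudt(u, p, t) = model(u, p, st)[1]

u0 = Float32[2.0; 0.0] |> gdev
prob = ODEProblem(dudt, u0, (0.0f0, 10.0f0), ps)

ensemble_prob = EnsembleProblem(prob; prob_func = (prob, i, repeat) -> remake(prob))
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(); trajectories = 2, saveat = 0.1f0)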