Training large (e.g., neural) SDEs on GPUs fails. The only way to obtain adjoints is via `TrackerAdjoint()`, and this currently works only on the CPU.

None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, works on either CPU or GPU.

• I suspect the problem with the continuous methods is the shape of the noise during the backwards solve.
• W.r.t. `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backwards pass. This also happens for ODEs, by the way.

MWE

```julia

using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote, BenchmarkTools, LuxCUDA, CUDA,
OptimizationOptimisers

# Target device (GPU if available) and the adjoint method under test.
dev = gpu_device()
sensealg = TrackerAdjoint()  # This works only on cpu

# Synthetic training data: 32 states × 100 time points × 512 trajectories,
# plus matching initial conditions and a uniform Float32 time grid on [0, 1].
data = rand32(32, 100, 512) |> dev
x₀ = rand32(32, 512) |> dev
ts = range(0.0f0, 1.0f0; length = 100)

# Drift is a dense layer; diffusion is a per-state (diagonal) scaling layer.
drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)

# Container layer pairing a drift network with a diffusion network, plus the
# solver configuration needed to simulate the SDE.
#
# All fields are type parameters: the original left `solver`, `tspan`, and
# `sensealg` untyped (`::Any`), which boxes them and blocks specialization.
# The positional constructor is unchanged, so existing call sites still work.
struct NeuralSDE{D, F, S, T, A} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D      # drift network μ(u, t)
    diffusion::F  # diffusion network σ(u, t)
    solver::S     # SDE solver instance, e.g. EM()
    tspan::T      # integration time span, e.g. (0.0f0, 1.0f0)
    sensealg::A   # sensitivity algorithm for the adjoint
end

# Simulate the neural SDE from initial state `x₀`, saving at times `ts`.
# `p`/`st` are the Lux parameters and states for the drift/diffusion networks.
# Returns an array of shape (states, length(ts), batch) to match `data`.
function (model::NeuralSDE)(x₀, ts, p, st)
    # Out-of-place drift and diffusion closures over the Lux layers;
    # `[1]` extracts the layer output, discarding the returned state.
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    # Bug fix: the original passed an undefined `func` and never used μ/σ.
    # An SDEProblem takes the drift and diffusion functions separately.
    prob = SDEProblem{false}(μ, σ, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat = ts, dt = 0.01f0, sensealg = model.sensealg)
    # sol.u is a vector of (states × batch) snapshots; stack along dim 3
    # giving (states, batch, time), then swap to (states, time, batch).
    return permutedims(cat(sol.u..., dims = 3), (1, 3, 2))
end

# Sum-of-squares reconstruction loss for parameters `p` against `data`.
# Relies on the script-level globals `model`, `x₀`, `ts`, and `st`.
# Returns (loss, state, prediction) as Optimization.jl allows extra outputs.
function loss!(p, data)
    pred = model(x₀, ts, p, st)
    return sum(abs2, data .- pred), st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
# Flatten the parameter NamedTuple into a flat Float32 vector on the device.
p = p |> ComponentArray{Float32} |> dev

# Bug fix: `adtype` was used below but never defined in the original script.
# AutoZygote comes from ADTypes — re-exported via Optimization; confirm it is
# in scope in your environment (otherwise `using Optimization` explicitly).
adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
``````

Error & Stacktrace

``````
ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
.2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
.x is of type Matrix{Float32} which is not isbits.

Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
[4]
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
[5]
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
[6] compile
@ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
[7] #1145
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
[8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
[9] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
[10] compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
[12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
[13] macro expansion
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
[14] macro expansion
@ ./lock.jl:267 [inlined]
[15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
@ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
[16] cufunction
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
[17] macro expansion
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
[18] #launch_heuristic#1204
@ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
[19] launch_heuristic
@ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
[20] _copyto!
[21] copyto!
[22] copy
[24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
[25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
[26] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[27] #64
@ ./tuple.jl:628 [inlined]
[28] BottomRF
@ ./reduce.jl:86 [inlined]
[29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
@ Base ./reduce.jl:58
[30] foldl_impl
@ ./reduce.jl:48 [inlined]
[31] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[32] mapfoldl
@ ./reduce.jl:175 [inlined]
[33] foldl
@ ./reduce.jl:198 [inlined]
[34] foreach
@ ./tuple.jl:628 [inlined]
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
[36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[37] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[38] #64
@ ./tuple.jl:628 [inlined]
[39] BottomRF
@ ./reduce.jl:86 [inlined]
[40] _foldl_impl
@ ./reduce.jl:58 [inlined]
[41] foldl_impl
@ ./reduce.jl:48 [inlined]
[42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{…}, nt::Nothing, itr::Base.Iterators.Zip{…})
@ Base ./reduce.jl:44
[43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
@ Base ./reduce.jl:175
[44] mapfoldl
@ ./reduce.jl:175 [inlined]
[45] foldl
@ ./reduce.jl:198 [inlined]
[46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
@ Base ./tuple.jl:628
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
[48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[49] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[50] #64
@ ./tuple.jl:628 [inlined]
[51] BottomRF
@ ./reduce.jl:86 [inlined]
[52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
@ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
[1229] foldl_impl
@ ./reduce.jl:48 [inlined]
[1230] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[1231] mapfoldl
@ ./reduce.jl:175 [inlined]
[1232] foldl
@ ./reduce.jl:198 [inlined]
[1233] foreach
@ ./tuple.jl:628 [inlined]
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[1236] #712
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
[1237] #715
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
[1239] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[1240] (::Zygote.var"#kw_zpullback#53"{…})(dy::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
[1241] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1242] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
[1243] #solve#51
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
[1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1245] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1246] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
[1247] solve
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
[1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1249] NeuralSDE
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
[1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1251] loss!
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
[1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1253] #39
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
[1254] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1255] #2169#back
[1256] OptimizationFunction
@ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
[1257] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1258] #2169#back
[1259] #37
@ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
[1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1261] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1262] #2169#back
[1263] #39
@ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
[1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[1267] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
[1268] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[1269] macro expansion
@ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
[1270] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[1271] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
[1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.
``````

I have been struggling with this for a while now, so I was wondering if someone has experienced the same issue or can spot a solution!
Thanks