Reverse mode AD for SDEs

Training large (e.g., neural) SDEs on GPUs fails. The only way I have found to obtain adjoints is via `TrackerAdjoint()`, and this currently works only on the CPU.

None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either the CPU or the GPU.

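  • I suspect the problem with the continuous methods is the shape of the noise during the backward solve.
  • With `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backward pass: in the trace below, `Tracker.accum!` receives a `Matrix{Float32}` gradient for a `CuArray` parameter, and the resulting mixed CPU/GPU broadcast fails to compile. The same thing also happens for ODEs.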

MWE



using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote, BenchmarkTools, LuxCUDA, CUDA,
OptimizationOptimisers



# Problem sizes, named once so that the data, the initial condition, the time grid
# and the layers cannot silently disagree (e.g. `length(ts)` must equal
# `size(data, 2)` for `data .- pred` in the loss to line up).
const STATE_DIM = 32   # width of the SDE state and of both layers
const N_SAVES = 100    # number of saved time points; equals length(ts)
const BATCH = 512      # number of trajectories solved side by side

dev = gpu_device()
sensealg = TrackerAdjoint()  # This works only on cpu

# `data` is laid out (state, time, batch) to match the model's output layout.
data = rand32(STATE_DIM, N_SAVES, BATCH) |> dev
x₀ = rand32(STATE_DIM, BATCH) |> dev
ts = range(0.0f0, 1.0f0; length = N_SAVES)
drift = Dense(STATE_DIM, STATE_DIM, tanh)
diffusion = Scale(STATE_DIM, sigmoid)

"""
    basic_tgrad(u, p, t)

Time-gradient callback for `SDEFunction`: always reports `∂f/∂t = 0`, returned as
`zero(u)` (same shape and element type as the state `u`). The drift built in the
model never reads `t`, so a zero time derivative is exact there.
"""
function basic_tgrad(u, p, t)
    dfdt = zero(u)
    return dfdt
end

"""
    NeuralSDE(drift, diffusion, solver, tspan, sensealg)

Lux container layer whose sub-layers are `drift` and `diffusion` (the names listed
in the `AbstractExplicitContainerLayer` parameter, so only these two receive
parameters and state from `Lux.setup`). The other fields configure the SDE `solve`.

Every field is a type parameter so none is stored as `Any`.
"""
struct NeuralSDE{D, F, S, T, A} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver::S      # SDE algorithm handed to `solve`, e.g. `EM()`
    tspan::T       # integration interval handed to `SDEProblem`
    sensealg::A    # sensitivity algorithm handed to `solve`
end

# Keeps the previously valid partially-parameterized form `NeuralSDE{D, F}(...)`
# working now that the struct has more type parameters.
function NeuralSDE{D, F}(drift, diffusion, solver, tspan, sensealg) where {D, F}
    return NeuralSDE{D, F, typeof(solver), typeof(tspan), typeof(sensealg)}(
        drift, diffusion, solver, tspan, sensealg)
end

"""
    (model::NeuralSDE)(x₀, ts, p, st; dt = 0.01f0)

Solve the SDE `du = drift(u) dt + diffusion(u) dW` from `x₀` over `model.tspan`.
Use `model.solver` with fixed step `dt` and save at `ts`.

Parameters come from `p.drift` / `p.diffusion`, and layer state from
`st.drift` / `st.diffusion`. For a matrix `x₀` the result is an array of size
`(size(x₀, 1), number of saved steps, size(x₀, 2))`.

NOTE(review): for a vector `x₀` this returns `(n, T)`. The previous `cat`/`permutedims`
form returned `(n, T, 1)`, with the same elements.
"""
function (model::NeuralSDE)(x₀, ts, p, st; dt = 0.01f0)
    # The `p` arguments here shadow the outer `p`; the solver passes the same
    # parameter object back in. `t` is unused: the dynamics are autonomous.
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad = basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat = ts, dt = dt, sensealg = model.sensealg)
    # Stacking along dims = 2 puts the time axis second. For matrix states this is
    # exactly `permutedims(cat(sol.u...; dims = 3), (1, 3, 2))`, but it avoids
    # splatting one argument per saved step.
    return stack(sol.u; dims = 2)
end

"""
    loss!(p, data)

Sum of squared residuals between `data` and the trajectories that the global
`model` produces from the global initial condition `x₀` on the grid `ts`.

Returns `(loss, st, prediction)`. Only the first element is the objective; the
others are presumably there for an Optimization.jl callback. Despite the `!`,
nothing is mutated.
"""
function loss!(p, data)
    prediction = model(x₀, ts, p, st)
    residual = data .- prediction
    total = sum(abs2, residual)
    return total, st, prediction
end

# Build the model, initialize its parameters/state, and move the parameters to `dev`
# as a flat Float32 ComponentArray.
rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

# Gradients come from Zygote. The SDE solve inside `loss!` is differentiated with
# `sensealg`.
# NOTE(review): on GPU with TrackerAdjoint, the reported failure happens in this
# backward pass. Tracker.accum! receives a Matrix{Float32} cotangent for a CuArray
# (see the trace above).
adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
# `AdamW` is the spelling exported by current Optimisers.jl (`ADAMW` is the old,
# deprecated name).
result = Optimization.solve(optproblem, AdamW(5e-4); maxiters = 10)

Error & Stacktrace

ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
      .x is of type Matrix{Float32} which is not isbits.


Stacktrace:
    [1] check_invocation(job::GPUCompiler.CompilerJob)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
    [2] macro expansion
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
    [3] macro expansion
      @ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
    [4] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
    [5] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
    [6] compile
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
    [7] #1145
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
    [8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
    [9] JuliaContext(f::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
   [10] compile(job::GPUCompiler.CompilerJob)
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
   [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
   [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
   [13] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
   [14] macro expansion
      @ ./lock.jl:267 [inlined]
   [15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
   [16] cufunction
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
   [17] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
   [18] #launch_heuristic#1204
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
   [19] launch_heuristic
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
   [20] _copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:78 [inlined]
   [21] copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:44 [inlined]
   [22] copy
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:29 [inlined]
   [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, typeof(+), Tuple{…}})
      @ Base.Broadcast ./broadcast.jl:903
   [24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
   [25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
   [26] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [27] #64
      @ ./tuple.jl:628 [inlined]
   [28] BottomRF
      @ ./reduce.jl:86 [inlined]
   [29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
   [30] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [31] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
   [32] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [33] foldl
      @ ./reduce.jl:198 [inlined]
   [34] foreach
      @ ./tuple.jl:628 [inlined]
   [35] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#583#584"{…}, Tuple{…}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [37] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [38] #64
      @ ./tuple.jl:628 [inlined]
   [39] BottomRF
      @ ./reduce.jl:86 [inlined]
   [40] _foldl_impl
      @ ./reduce.jl:58 [inlined]
   [41] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{…}, nt::Nothing, itr::Base.Iterators.Zip{…})
      @ Base ./reduce.jl:44
   [43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
      @ Base ./reduce.jl:175
   [44] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [45] foldl
      @ ./reduce.jl:198 [inlined]
   [46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
      @ Base ./tuple.jl:628
   [47] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#552#555"{…}, Tuple{…}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [49] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [50] #64
      @ ./tuple.jl:628 [inlined]
   [51] BottomRF
      @ ./reduce.jl:86 [inlined]
   [52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
 [1229] foldl_impl
      @ ./reduce.jl:48 [inlined]
 [1230] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
 [1231] mapfoldl
      @ ./reduce.jl:175 [inlined]
 [1232] foldl
      @ ./reduce.jl:198 [inlined]
 [1233] foreach
      @ ./tuple.jl:628 [inlined]
 [1234] back_(g::Tracker.Grads, c::Tracker.Call{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
 [1235] back(g::Tracker.Grads, x::Tracker.Tracked{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
 [1236] #712
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
 [1237] #715
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
 [1238] (::SciMLSensitivity.var"#tracker_adjoint_backpass#368"{…})(ybar::RODESolution{…})
      @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
 [1239] ZBack
      @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [1240] (::Zygote.var"#kw_zpullback#53"{…})(dy::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [1241] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1242] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1243] #solve#51
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1245] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1246] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1247] solve
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1249] NeuralSDE
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
 [1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1251] loss!
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
 [1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1253] #39
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
 [1254] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1255] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1256] OptimizationFunction
      @ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
 [1257] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1258] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1259] #37
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
 [1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1261] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1262] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1263] #39
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
 [1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [1267] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
      @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
 [1268] macro expansion
      @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [1269] macro expansion
      @ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
 [1270] __solve(cache::OptimizationCache{…})
      @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [1271] solve!(cache::OptimizationCache{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
 [1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.

I have been struggling with this for a while, so I was wondering whether anyone has experienced the same issue or can spot a solution!
Thanks