Training large (e.g., neural) SDEs on GPUs fails. The only way I have found to obtain adjoints is via `TrackerAdjoint()`, and that currently works only on the CPU. None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either CPU or GPU.
- I suspect the problem with the continuous methods is the shape of the noise during the backward solve.
- With `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backward pass (a minimal reproduction follows this list); incidentally, the same thing happens for ODEs.
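For what it's worth, the stacktrace below dies in `Tracker.accum!` (frame [24]), which broadcasts the accumulated GPU gradient with an incoming `Δ::Matrix{Float32}`. A minimal sketch (my own, not part of the original script) that reproduces that mixed host/device broadcast in isolation:

```julia
using CUDA

g = CUDA.rand(Float32, 32, 512)  # accumulated gradient, lives on the GPU
Δ = rand(Float32, 32, 512)       # incoming pullback value, materialized on the CPU

# Broadcasting a CuArray with a plain Matrix, as accum! effectively does via
# g[x] .+ Δ, fails with the same "passing and using non-bitstype argument"
# KernelError shown below, because Matrix{Float32} is not isbits:
g .+ Δ
```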
MWE

```julia
using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote,
      BenchmarkTools, LuxCUDA, CUDA, Optimization, OptimizationOptimisers

dev = gpu_device()
sensealg = TrackerAdjoint()  # this works only on the CPU

data = rand32(32, 100, 512) |> dev  # (features, timesteps, batch)
x₀ = rand32(32, 512) |> dev         # initial condition, (features, batch)
ts = range(0.0f0, 1.0f0, length = 100)

drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)
basic_tgrad(u, p, t) = zero(u)

struct NeuralSDE{D, F} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver
    tspan
    sensealg
end

function (model::NeuralSDE)(x₀, ts, p, st)
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad = basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat = ts, dt = 0.01f0, sensealg = model.sensealg)
    return permutedims(cat(sol.u..., dims = 3), (1, 3, 2))  # back to (features, timesteps, batch)
end

function loss!(p, data)
    pred = model(x₀, ts, p, st)
    l = sum(abs2, data .- pred)
    return l, st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters = 10)
```
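To reproduce the continuous-adjoint failures mentioned at the top, only the `sensealg` line of the MWE changes; a sketch of the variants I mean (using the default constructor arguments, which is an assumption rather than a pinned-down configuration):

```julia
# swap in a continuous adjoint instead of TrackerAdjoint();
# as described above, these fail on both CPU and GPU
sensealg = InterpolatingAdjoint()
# sensealg = BacksolveAdjoint()
```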
Error & Stacktrace

```
ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
.2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
.x is of type Matrix{Float32} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
[4]
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
[5]
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
[6] compile
@ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
[7] #1145
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
[8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
[9] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
[10] compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
[11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
[12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
[13] macro expansion
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
[14] macro expansion
@ ./lock.jl:267 [inlined]
[15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
@ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
[16] cufunction
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
[17] macro expansion
@ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
[18] #launch_heuristic#1204
@ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
[19] launch_heuristic
@ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
[20] _copyto!
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:78 [inlined]
[21] copyto!
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:44 [inlined]
[22] copy
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:29 [inlined]
[23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, typeof(+), Tuple{…}})
@ Base.Broadcast ./broadcast.jl:903
[24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
[25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
[26] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[27] #64
@ ./tuple.jl:628 [inlined]
[28] BottomRF
@ ./reduce.jl:86 [inlined]
[29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
@ Base ./reduce.jl:58
[30] foldl_impl
@ ./reduce.jl:48 [inlined]
[31] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[32] mapfoldl
@ ./reduce.jl:175 [inlined]
[33] foldl
@ ./reduce.jl:198 [inlined]
[34] foreach
@ ./tuple.jl:628 [inlined]
[35] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#583#584"{…}, Tuple{…}}, Δ::Vector{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
[36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[37] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[38] #64
@ ./tuple.jl:628 [inlined]
[39] BottomRF
@ ./reduce.jl:86 [inlined]
[40] _foldl_impl
@ ./reduce.jl:58 [inlined]
[41] foldl_impl
@ ./reduce.jl:48 [inlined]
[42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{…}, nt::Nothing, itr::Base.Iterators.Zip{…})
@ Base ./reduce.jl:44
[43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
@ Base ./reduce.jl:175
[44] mapfoldl
@ ./reduce.jl:175 [inlined]
[45] foldl
@ ./reduce.jl:198 [inlined]
[46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
@ Base ./tuple.jl:628
[47] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#552#555"{…}, Tuple{…}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
[48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[49] #710
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
[50] #64
@ ./tuple.jl:628 [inlined]
[51] BottomRF
@ ./reduce.jl:86 [inlined]
[52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
@ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
[1229] foldl_impl
@ ./reduce.jl:48 [inlined]
[1230] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[1231] mapfoldl
@ ./reduce.jl:175 [inlined]
[1232] foldl
@ ./reduce.jl:198 [inlined]
[1233] foreach
@ ./tuple.jl:628 [inlined]
[1234] back_(g::Tracker.Grads, c::Tracker.Call{…}, Δ::RODESolution{…})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
[1235] back(g::Tracker.Grads, x::Tracker.Tracked{…}, Δ::RODESolution{…})
@ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
[1236] #712
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
[1237] #715
@ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
[1238] (::SciMLSensitivity.var"#tracker_adjoint_backpass#368"{…})(ybar::RODESolution{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
[1239] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[1240] (::Zygote.var"#kw_zpullback#53"{…})(dy::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
[1241] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1242] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[1243] #solve#51
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
[1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1245] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1246] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[1247] solve
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
[1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1249] NeuralSDE
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
[1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1251] loss!
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
[1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1253] #39
@ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
[1254] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1255] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[1256] OptimizationFunction
@ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
[1257] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1258] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[1259] #37
@ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
[1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1261] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[1262] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[1263] #39
@ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
[1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[1267] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
[1268] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[1269] macro expansion
@ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
[1270] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[1271] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
[1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.
```
I have been struggling with this for a while now, so I was wondering whether someone has experienced the same issue or can spot a solution. Thanks!