Hello,
In my code I would like to use the integrator interface of DifferentialEquations inside CUDA kernels. Unfortunately, the integrator type is not isbits and therefore cannot be used inside a kernel. I wonder if anyone can suggest any workarounds.
Here is a more detailed description of my problem. I have a differential equation of the following type:
$$
\frac{\partial u(r,t)}{\partial t} = 2 \, g(r,t) \, [1 - u(r,t)],
$$
where the function g(r,t) is predefined on a grid (r,t). As you can see, the variable r plays the role of a parameter. Therefore, the above equation can be represented as a set of independent equations (I do not want to treat them as a coupled system, because in the original code the above equation is only one equation out of a larger system). As a result, the solution can be efficiently parallelized, and I would like to use a GPU for this purpose.
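To make the independence over r explicit: for each fixed r this is a scalar linear ODE, so each grid line even has a closed-form solution (a standard integrating-factor calculation, shown only to illustrate that the equations never couple):

$$
u(r,t) = 1 - \left[1 - u(r,t_0)\right] \exp\!\left(-2 \int_{t_0}^{t} g(r,t')\,dt'\right),
$$

i.e. the solution at one value of r never depends on the solution at another, which is what makes the per-r parallelization safe.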
The serial CPU code can be written as follows:
```julia
using PyPlot
using DifferentialEquations

function f(u, p, t)
    a, g = p
    return a * g * (1 - u)
end

function mysolve!(u, integrator, dt, g)
    Nr, Nt = size(u)
    for i = 1:Nr
        for j = 1:Nt
            integrator.p = (2., g[i, j])
            step!(integrator, dt, true)
            u[i, j] = integrator.u
        end
        reinit!(integrator)
    end
    return nothing
end

function main()
    Nr, Nt = 512, 1024
    r = range(0., 3., length=Nr)
    t = range(-3., 3., length=Nt)
    dt = t[2] - t[1]
    g = zeros((Nr, Nt))
    for j = 1:Nt
        for i = 1:Nr
            g[i, j] = exp(-r[i]^2) * exp(-t[j]^2)
        end
    end
    u0 = 0.
    p = (2., g[1, 1])
    tspan = (t[1], t[end])
    prob = ODEProblem(f, u0, tspan, p)
    integrator = init(prob, Tsit5(), dense=false)
    u = zeros((Nr, Nt))
    mysolve!(u, integrator, dt, g)
    plot(t, g[1, :])
    plot(t, u[1, :])
    show()
end

main()
```
A straightforward translation to GPU code would be:
```julia
using PyPlot
using DifferentialEquations
using CUDAnative
using CuArrays

function f(u, p, t)
    a, g = p
    return a * g * (1 - u)
end

function mysolve!(u, integrator, dt, g)
    Nr, Nt = size(u)
    nth = 256
    nbl = Int(ceil(Nr / nth))
    @cuda blocks=nbl threads=nth mysolve_kernel(u, integrator, dt, g)
    return nothing
end

function mysolve_kernel(u, integrator, dt, g)
    id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride = blockDim().x * gridDim().x
    Nr, Nt = size(u)
    for i = id:stride:Nr
        for j = 1:Nt
            integrator.p = (2., g[i, j])
            step!(integrator, dt, true)
            u[i, j] = integrator.u
        end
        reinit!(integrator)
    end
    return nothing
end

function main()
    Nr, Nt = 512, 1024
    r = range(0., 3., length=Nr)
    t = range(-3., 3., length=Nt)
    dt = t[2] - t[1]
    g = zeros((Nr, Nt))
    for j = 1:Nt
        for i = 1:Nr
            g[i, j] = exp(-r[i]^2) * exp(-t[j]^2)
        end
    end
    u0 = 0.
    p = (2., g[1, 1])
    tspan = (t[1], t[end])
    prob = ODEProblem(f, u0, tspan, p)
    integrator = init(prob, Tsit5(), dense=false)
    gd = CuArrays.CuArray(g)
    ud = CuArrays.zeros((Nr, Nt))
    mysolve!(ud, integrator, dt, gd)
    u = CuArrays.collect(ud)
    plot(t, g[1, :])
    plot(t, u[1, :])
    show()
end

main()
```
Unfortunately, this GPU code does not compile. I guess the main reason is the non-isbits type of the integrator:
```
typeof(integrator) = OrdinaryDiffEq.ODEIntegrator{Tsit5,false,Float64,Float64,Tuple{Float64,Float64},Float64,Float64,Float64,Array{Float64,1},ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Tuple{Float64,Float64},ODEFunction{false,typeof(f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},DiffEqBase.DEStats},ODEFunction{false,typeof(f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float64,1},Array{Float64,1},Array{Float64,1}},Float64,Float64,Nothing}
```
Here you can see the extensive presence of Array{Float64,1} types, which are not isbits. Therefore, I wonder: is there any way to tell the integrator to use, for example, StaticArrays instead? If anyone can suggest any other solution, I would be very happy.
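For reference, the fallback I am currently considering is to bypass the integrator object entirely and hand-code a fixed-step stepper inside the kernel, so that only isbits scalars and device arrays cross the kernel boundary. This is only a minimal sketch of that idea (rhs, rk4_step, and the kernel below are my own hypothetical helpers, not part of the DifferentialEquations API; g is frozen over each step), and it gives up adaptive stepping and error control:

```julia
using CUDAnative
using CuArrays

# Right-hand side, same equation as above: du/dt = a * g * (1 - u)
@inline rhs(u, a, g) = a * g * (1 - u)

# One classical RK4 step with g held constant over the step
@inline function rk4_step(u, a, g, dt)
    k1 = rhs(u, a, g)
    k2 = rhs(u + dt / 2 * k1, a, g)
    k3 = rhs(u + dt / 2 * k2, a, g)
    k4 = rhs(u + dt * k3, a, g)
    return u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Each thread integrates its own r-lines; all arguments are isbits
# scalars or device arrays, so the kernel should compile.
function mysolve_kernel(u, dt, g, u0, a)
    id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride = blockDim().x * gridDim().x
    Nr, Nt = size(u)
    for i = id:stride:Nr
        ui = u0       # reinitialize for each r-line
        for j = 1:Nt
            ui = rk4_step(ui, a, g[i, j], dt)
            u[i, j] = ui
        end
    end
    return nothing
end
```

This works here because the equation is scalar per grid point, but of course it loses everything the integrator interface provides, which is why I would prefer a proper solution.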
Thank you.