DiffEqFlux neural_ode used with Flux.train! is slower on GPU than CPU

***It appears that Flux training of the DiffEqFlux method neural_ode is incompatible with (fast) GPU operations, but I'm hoping that I'm doing something wrong and that you can point me to example code where GPU ops are fast when using neural_ode.***

In my simple tests, the GPU (an RTX 2070) is about 10x slower than the CPU (an Intel i7).

CPU timing:
1.287959 seconds (8.92 M allocations: 358.429 MiB, 7.67% gc time)

GPU timing:
12.238495 seconds (52.03 M allocations: 1.937 GiB, 5.78% gc time)

Likely Cause:
Monitoring nvidia-smi, I see only ~3 to 15% GPU activity for the julia process. If I run a simple ANN instead of a neural_ode, I get GPU utilization above 80%.

Low utilization is frequently caused either by arrays moving repeatedly between GPU and CPU memory, or by non-parallel operations (often the result of indexed loops rather than vector ops). Support for the former comes from the timing data above: the GPU run makes far more allocations and uses far more memory (52.03 M allocations and 1.937 GiB versus 8.92 M and 358.429 MiB), suggesting that arrays may be mirrored between GPU and CPU unnecessarily.

Support for the latter hypothesis (indexed looping) comes from a warning. Using neural_ode with Flux.train!() gives me a warning that scalar operations are very slow, but the code still runs.

┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:16
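
To make the second hypothesis concrete, here is a minimal sketch on plain CPU arrays (so it runs without a GPU). The function names `sumsq_loop` and `sumsq_vec` are made up for illustration and are not part of Flux or DiffEqFlux. Both compute the same value, but the loop reads one element at a time. With a CuArray argument, each `x[i]` would be a scalar `getindex`, which is the kind of access the warning above is about, whereas `sum(abs2, x)` is a single whole-array operation.

```julia
# Indexed loop: touches one element per iteration.
# With a CuArray, every x[i] here would be a scalar getindex.
function sumsq_loop(x)
    s = zero(eltype(x))
    for i in eachindex(x)
        s += abs2(x[i])
    end
    return s
end

# Whole-array form: a single reduction over the array.
sumsq_vec(x) = sum(abs2, x)

x = Float32[1, 2, 3, 4]
println(sumsq_loop(x))
println(sumsq_vec(x))
```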

If I set CuArrays.allowscalar(false), then instead of a warning I get an error raised in Flux.train!, and the error relates to scalar indexing.

Since the traceback shows that the call into indexing.jl passes through Tracker early on, I have tried reading the Tracker code to figure out why, but I can't understand it.
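
For what it's worth, the wrapper type in the traceback further down, `Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}`, can be reproduced on the CPU. The sketch below only illustrates how such a type can arise, assuming a lazy adjoint that is then reshaped to a vector. It is not taken from Tracker or DiffEqFlux, and the variable names are made up. Slicing the wrapper with a range is not forwarded to the underlying array as one operation: as frames [2] to [13] of the traceback show, Base falls back to generic `getindex` code that reads element by element, which is harmless for an `Array` but is scalar indexing for a `CuArray`.

```julia
using LinearAlgebra

v = Float32[1, 2, 3, 4]

a = v'              # lazy Adjoint wrapper, size (1, 4), no copy
r = reshape(a, :)   # lazy reshape of the adjoint back to a vector, no copy

println(typeof(r))  # a ReshapedArray wrapping an Adjoint wrapping a Vector

# Range indexing goes through Base's generic AbstractArray fallback,
# which reads the requested elements one at a time.
s = r[1:2]
println(s)
```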
###########
I’m going to paste some code to reproduce this below, followed by the error traceback

using Flux, CuArrays, DiffEqFlux
using DifferentialEquations # needed to define Tsit5

# Uncomment one of the following modes of operation
#gc = Array  # uncomment if you want to run on cpu
gc = cu      # uncomment if you want to run on gpu

# uncomment one of the following to trap the origin of the warning
#CuArrays.allowscalar(false)  # Flux.train! gives an error in GPU mode
CuArrays.allowscalar(true)  # Flux.train! gives a warning in GPU mode

# True model used to create the label data
dudt = Dense(2,2)
if gc==cu
    dudt = dudt |> gpu
end

# initial boundary condition input
u0 = gc(Float32[1,1])

tspan = (0.0f0,1.0f0)
t = range(tspan...,length=8)  # can make the data set larger 


###
# create the labels
###
n_ode_truth = x->neural_ode(dudt,x,tspan,Tsit5(),saveat=t,reltol=1e-7,abstol=1e-9)
labels= n_ode_truth(u0).data # Get the prediction using the correct initial condition

### 
# create the model we will train
###

dudt_train = Dense(2,2)  # random initialization gives arbitrary starting params
if gc==cu
    dudt_train = dudt_train |> gpu
end

tracking = params(dudt_train)


n_ode = x->neural_ode(dudt_train,x,tspan,Tsit5(),saveat=t,reltol=1e-7,abstol=1e-9)
# show that the problem isn't in neural_ode by itself
pred = n_ode(u0)


function loss()
    sum(abs2, n_ode(u0)-labels)
end

data = Iterators.repeated((), 3)  # can make this do more iterations
@time Flux.train!(loss, tracking, data, ADAM(0.1f0))
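
One caveat about the timings reported above, offered as a general Julia point rather than a diagnosis: the first call to a function includes JIT compilation, so a single `@time` around only 3 training iterations may mostly measure compile time. The sketch below uses a made-up stand-in function `work` (not the actual training step) to show the usual pattern of timing a second call separately.

```julia
# Hypothetical stand-in for one training step; any function works here.
work(x) = sum(abs2, x .- 1f0)

x = rand(Float32, 1000)

t_first  = @elapsed work(x)   # includes JIT compilation of `work`
t_second = @elapsed work(x)   # runs the already-compiled code

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

Applied to the script above, that would mean running `Flux.train!` once as a warm-up and timing a second run, ideally with more iterations. Whether that changes the CPU/GPU ratio here is untested.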


Here is the error when CuArrays.allowscalar(false) is set in GPU mode (gc==cu):

julia> Flux.train!(loss, tracking, data, ADAM(0.1f0))
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] assertscalar(::String) at /home/cems/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:14
 [2] getindex at /home/cems/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:54 [inlined]
 [3] getindex at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/adjtrans.jl:177 [inlined]
 [4] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
 [5] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
 [6] getindex at ./reshapedarray.jl:231 [inlined]
 [7] macro expansion at ./multidimensional.jl:699 [inlined]
 [8] macro expansion at ./cartesian.jl:64 [inlined]
 [9] macro expansion at ./multidimensional.jl:694 [inlined]
 [10] _unsafe_getindex! at ./multidimensional.jl:690 [inlined]
 [11] _unsafe_getindex(::IndexLinear, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}, ::UnitRange{Int64}) at ./multidimensional.jl:684
 [12] _getindex at ./multidimensional.jl:670 [inlined]
 [13] getindex(::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}, ::UnitRange{Int64}) at ./abstractarray.jl:981
 [14] (::getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}})(::TrackedArray{…,CuArray{Float32,1}}) at /home/cems/.julia/packages/Tracker/cpxco/src/lib/array.jl:196
 [15] iterate at ./generator.jl:47 [inlined]
 [16] collect(::Base.Generator{Tuple{TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}}},getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}}}) at ./array.jl:606
 [17] #428 at /home/cems/.julia/packages/Tracker/cpxco/src/lib/array.jl:193 [inlined]
 [18] back_(::Tracker.Call{getfield(Tracker, Symbol("##428#431")){Tuple{TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}}}},Tuple{Tracker.Tracked{CuArray{Float32,1}},Tracker.Tracked{CuArray{Float32,1}}}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:35
 [19] back(::Tracker.Tracked{CuArray{Float32,1}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:58
 [20] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{CuArray{Float32,1}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}}) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38
 [21] foreach(::Function, ::Tuple{Tracker.Tracked{CuArray{Float32,1}},Nothing,Nothing,Nothing}, ::Tuple{Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},Tuple{}},CuArray{Float32,1},Nothing,Nothing}) at ./abstractarray.jl:1921
 [22] back_(::Tracker.Call{getfield(DiffEqFlux, Symbol("##25#28")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},TrackedArray{…,CuArray{Float32,1}},CuArray{Float32,1},Tuple{Tsit5},ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},false,CuArray{Float32,1},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1},1},Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Bool},Tuple{Tracker.Tracked{CuArray{Float32,1}},Nothing,Nothing,Nothing}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38
 [23] back(::Tracker.Tracked{CuArray{Float32,2}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:58
 [24] #13 at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
 [25] foreach at ./abstractarray.jl:1921 [inlined]
 [26] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#550")){2,typeof(-),Tuple{TrackedArray{…,CuArray{Float32,2}},CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}},Nothing}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38
 [27] back(::Tracker.Tracked{CuArray{Float32,2}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:58
 [28] foreach at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
 [29] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#550")){1,typeof(abs2),Tuple{TrackedArray{…,CuArray{Float32,2}}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38
 [30] back(::Tracker.Tracked{CuArray{Float32,2}}, ::CuArray{Float32,2}, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:58
 [31] #13 at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
 [32] foreach at ./abstractarray.jl:1921 [inlined]
 [33] back_(::Tracker.Call{getfield(Tracker, Symbol("##484#485")){TrackedArray{…,CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::Float32, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:38
 [34] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:58
 [35] #back!#15 at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:77 [inlined]
 [36] #back! at ./none:0 [inlined]
 [37] #back!#32 at /home/cems/.julia/packages/Tracker/cpxco/src/lib/real.jl:16 [inlined]
 [38] back!(::Tracker.TrackedReal{Float32}) at /home/cems/.julia/packages/Tracker/cpxco/src/lib/real.jl:14
 [39] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss),Tuple{}}, ::Tracker.Params) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:4
 [40] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:164
 [41] gradient at /home/cems/.julia/packages/Tracker/cpxco/src/back.jl:164 [inlined]
 [42] macro expansion at /home/cems/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
 [43] macro expansion at /home/cems/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [44] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at /home/cems/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
 [45] train!(::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at /home/cems/.julia/packages/Flux/dkJUV/src/optimise/train.jl:67
 [46] top-level scope at none:0