Hello,

I’m trying to run the following neural ODE program on a GPU (NVIDIA GeForce GTX 1050). On the GPU, the `Flux.train!(loss_n_ode, ps, data, opt, cb = cb)` step takes much longer than on the CPU alone (roughly 10× slower), as the `@time` output below shows:

GPU: 362.607972 seconds (719.08 M allocations: 24.526 GiB, 3.86% gc time)

CPU: 35.925501 seconds (150.99 M allocations: 13.473 GiB, 9.14% gc time)

How can I fix this so that the GPU version trains faster?

```julia
using Flux, DiffEqFlux, DifferentialEquations, Plots, CSV, CuArrays

# Read the reference solution from CSV and keep every second row (3001 rows)
flamedata = CSV.read("./results.csv")
flamedata2 = flamedata[2:2:6002, :]

# Initial state (4 components), moved to the GPU
u0 = Float32[0.; 0.11189834407236525; 0.8881016559276348; 0.] |> gpu
datasize = 3001
tspan = (0.0f0, 0.0003f0) |> gpu
t = range(tspan[1], tspan[2], length=datasize) |> gpu

# Training data: columns 1, 4, 7 and 9 as a 4×3001 Float32 matrix on the GPU
ode_data2 = Matrix(flamedata2[[1, 4, 7, 9]])
ode_data2 = transpose(ode_data2)
ode_data2 = convert(Array{Float32}, ode_data2)
ode_data = ode_data2 |> gpu
# Map the first variable through tanh(100 * x)
ode_data[1, :] = tanh.(ode_data[1, :] * 100) |> gpu

# Neural network used as the ODE right-hand side
dudt = Chain(
    Dense(4, 32, swish),
    Dense(32, 16, swish),
    Dense(16, 8, swish),
    Dense(8, 4)
) |> gpu
ps = Flux.params(dudt) |> gpu

# Neural ODE; AutoTsit5(Rodas5(...)) switches from Tsit5 to Rodas5 when stiffness is detected
n_ode = x -> neural_ode(dudt, gpu(x), gpu(tspan), AutoTsit5(Rodas5(autodiff=false)),
                        saveat=t, dtmin=1.0E-14, maxiters=1e10, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
    n_ode(u0)
end

# Sum-of-squares loss between data and prediction
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 5)  # 5 training iterations
opt = ADAM(0.1, (0.9, 0.999))

cb = function ()  # callback function to observe training
    display(loss_n_ode())
    # plot current prediction against data
    cur_pred = Flux.data(predict_n_ode())
    pl1 = plot(t, ode_data[1, :], label="data1", lw=2)
    plot!(pl1, t, cur_pred[1, :], label="prediction1", lw=2)
    plot!(pl1, t, ode_data[2, :], label="data2", lw=2)
    plot!(pl1, t, cur_pred[2, :], label="prediction2", lw=2)
    plot!(pl1, t, ode_data[3, :], label="data3", lw=2)
    plot!(pl1, t, cur_pred[3, :], label="prediction3", lw=2)
    plot!(pl1, t, ode_data[4, :], label="data4", lw=2)
    plot!(pl1, t, cur_pred[4, :], label="prediction4", lw=2)
    gui(plot(pl1))
end

# Display the ODE with the initial parameter values.
cb()

@time Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
```
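The network above is tiny, so each right-hand-side evaluation is only a few small matrix-vector products. A minimal sketch for checking how cheap one evaluation is on the CPU: it is a hypothetical stand-in for the same 4→32→16→8→4 `dudt` architecture, written in plain Julia (no Flux, no GPU) with random weights, so the `Layer` type and `mlp` function are illustrative names, not part of any library. Assuming the GPU version pays a fixed launch cost per kernel on every layer, a per-call CPU time of only a few microseconds would suggest that overhead, rather than arithmetic, dominates on the GPU.

```julia
# Hypothetical CPU-only stand-in for the 4→32→16→8→4 `dudt` network.
# Plain matrices replace Flux's `Dense` layers; the weights are random.
using Random

swish(x) = x / (1 + exp(-x))  # swish(x) = x * sigmoid(x)

# One dense layer: act.(W*x .+ b), mirroring what `Dense` computes
struct Layer{F}
    W::Matrix{Float32}
    b::Vector{Float32}
    act::F
end

(l::Layer)(x) = l.act.(l.W * x .+ l.b)

rng = MersenneTwister(0)
layers = [Layer(randn(rng, Float32, o, i), zeros(Float32, o), a)
          for (i, o, a) in ((4, 32, swish), (32, 16, swish), (16, 8, swish), (8, 4, identity))]

# Apply the layers in order, like `Chain`
mlp(u) = foldl((x, l) -> l(x), layers; init=u)

u = Float32[0.0, 0.11189834, 0.88810166, 0.0]
du = mlp(u)  # first call also compiles, so it is not timed

# Average wall time of one right-hand-side evaluation over many calls
n = 10_000
elapsed = @elapsed for _ in 1:n
    mlp(u)
end
println("output length: ", length(du))
println("mean time per RHS evaluation: ", 1e6 * elapsed / n, " μs")
```

The timing covers only the network arithmetic, not the ODE solver or the backward pass, so it is a lower bound on the per-step cost rather than a full profile.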