"CuArray only supports bits types" error when performing automatic differentiation in sciml_train (possibly due to array reshaping)

Hello everyone,

I’m trying to convert my code to run on the GPU, but I’ve run into an issue.
My code mainly uses DiffEqFlux, with some array reshaping required inside the neural ODE function itself. When the code gets to the sciml_train part, I always get the same error: `CuArray only supports bits types`. I suspect that when the arrays are converted to `CuArray`s of ReverseDiff tracked values, the reshaping part trips things up (still unconfirmed, and I don’t really know how to confirm it). Below is the stacktrace, and further down a snippet of my code.

ERROR: LoadError: CuArray only supports bits types
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:115
 [3] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:124
 [4] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Int64}) at .\abstractarray.jl:675
 [5] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Base.OneTo{Int64}}) at .\abstractarray.jl:674
 [6] similar(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\broadcast.jl:11
 [7] copy(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}) at .\broadcast.jl:862
 [8] copy(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\derivatives\broadcast.jl:100
 [9] materialize(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at .\broadcast.jl:837
 [10] flux_kernel(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:175
 [11] nn_ode(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:194
 [12] ODEFunction at C:\Users\timot\.julia\packages\SciMLBase\eghDQ\src\scimlfunctions.jl:324 [inlined]
 [13] (::DiffEqSensitivity.var"#78#87"{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:132
 [14] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:207
 [15] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:204
 [16] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}; quad::Bool, noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:131
 [17] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::Function, ::Array{Float64,1}, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}; noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:55
 [18] ODEInterpolatingAdjointSensitivityFunction at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:22 [inlined]
 [19] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; checkpoints::Array{Float64,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:173       
 [20] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float64,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:17
 [21] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:13 (repeats 2 times)
 [22] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:6
 [23] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\concrete_solve.jl:144   
 [24] (::DiffEqBase.var"#275#back#79"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon}})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65     
 [25] predict_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:206 [inlined]
 [26] (::typeof(∂(predict_rd)))(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [27] loss_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:211 [inlined]
 [28] (::typeof(∂(loss_rd)))(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [29] #150 at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
 [30] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(loss_rd)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [31] #74 at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:120 [inlined]
 [32] (::typeof(∂(λ)))(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [33] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172  
 [34] gradient(::Function, ::Zygote.Params) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
 [35] macro expansion at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:119 [inlined]
 [36] macro expansion at C:\Users\timot\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [37] (::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params})() at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:64
 [38] maybe_with_logger(::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:39
 [39] sciml_train(::Function, ::CuArray{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:63
 [40] top-level scope at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318
 [41] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
in expression starting at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318

A snippet of my code:

n_weights = 10

#for the retardation factor
rx_nn = Chain(x -> transpose(x),
                Dense(1, n_weights, tanh),
                Dense(n_weights, 2*n_weights, tanh),
                Dense(2*n_weights, n_weights, tanh),
                Dense(n_weights, 1, σ),
                x -> @view x[1,:]) |> gpu

# initialize the numerical stencil and the exponent for normalization
p_stencil = [-1.1, 1.05]
p_exp = [1.0]

D0 = [0.1]

p1,re1 = Flux.destructure(rx_nn)
p = [p1;D0;p_stencil;p_exp] |> gpu

exp_base = convert(CuArray{Float32,1}, [10.0])

full_restructure(p) = re1(p[1:length(p1)]), p[end-3], p[end-2:end-1], p[end]

function flux_kernel(u,p)
    u_left = @view u[1:Nx-1]
    u_right = @view u[2:Nx]
    stencil_nbr = @view p[end-1:end-1]
    stencil_ctr = @view p[end-2:end-2]
    D0 = @view p[end-3]
    left = stencil_nbr .* u_left .+ stencil_ctr .* u_right
    left_bc_flux = [stencil_nbr .* left_BC .+ stencil_ctr .* u[1:1]][1]
    left_flux = vcat(left_bc_flux, left)
    right = stencil_ctr .* u_left .+ stencil_nbr .* u_right
    u_bc_left = @view u[Nx-1:Nx-1]
    u_bc_right = @view u[Nx:Nx]
    right_BC_cauchy = (u_bc_left .- u_bc_right) .* D0 .* dx
    right_bc_flux = [stencil_ctr .* u[Nx:Nx] .+ stencil_nbr .* right_BC_cauchy][1]
    right_flux = vcat(right, right_bc_flux)
    left_flux + right_flux
end

function state_kernel(u,p,flux)
    rx_nn = re1(p[1:length(p1)])
    ret = rx_nn(u)
    exp_pow = convert(CuArray{Float32,1}, [p[end]])
    ret ./ CUDA.pow.(exp_base,exp_pow) .* flux
end

function nn_ode(u,p,t)
    flux = flux_kernel(u,p)
    state_kernel(u,p,flux)
end

########################
# Solving the neural PDE and setting up the loss function
########################
prob_nn = ODEProblem(nn_ode, c0, (0.0, T), p)
sol_nn = concrete_solve(prob_nn,Tsit5(), c0, p)

function predict_rd(θ)
  # No ReverseDiff if using Flux
  gpu(Array(concrete_solve(prob_nn,Tsit5(),c0,θ,saveat=dt,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))))
end

#match data and force the weights of the CNN to add up to zero
function loss_rd(p)
    pred = predict_rd(p)
    breakthrough_pred = @view pred[end,:]
    rx_nn = re1(p[1:length(p1)])
    u = collect(0:0.01:1) |> gpu
    exp_pow = @view p[end:end]
    D0 = @view p[end-3:end-3]
    ret_temp = rx_nn(u)
    ret = D0 ./ ret_temp .* CUDA.pow.(exp_base, exp_pow)
    loss = convert(Float32,10^3 * sum(abs2, breakthrough_data .- breakthrough_pred) + 10^2 * sum(relu.(ret[2:end] .- ret[1:end-1])) + 10^2 * abs(sum(p[end-2 : end-1])))
    return loss, pred
end

Please let me know if anybody can help and/or needs further information from my side.
Many thanks in advance! :smiley:

Yes, ReverseDiff doesn’t support GPUs. Use ZygoteVJP (or maybe EnzymeVJP).
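For anyone else hitting this: the fix is only a change to the `autojacvec` choice inside `predict_rd`. A minimal sketch against the snippet above (same `prob_nn`, `c0`, `dt`; hedged, since I haven't run this exact code):

```julia
# Sketch: swap ReverseDiffVJP for ZygoteVJP so the vector-Jacobian
# products stay GPU-compatible (ReverseDiff tapes require isbits
# element types, which CuArray enforces).
function predict_rd(θ)
    gpu(Array(concrete_solve(prob_nn, Tsit5(), c0, θ, saveat = dt,
        sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))))
end
```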

It works now! Thank you very much for the clarification.
One more question from my side: Does BFGS work with CuArrays? I seem to have it working with ADAM but not with BFGS.

We’re just waiting on @pkofod for that: https://github.com/JuliaNLSolvers/Optim.jl/pull/931 . Looks like the nightly tests failed, but that seems unrelated?

Funny, I’m using Optim.LBFGS with CUDA, and it works perfectly on large arrays. Maybe that’s acceptable for you as well.
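For reference, here is roughly how that would look through sciml_train, assuming the LBFGS path handles your arrays (hedged sketch; `p_trained` is a hypothetical name for the parameters after an initial ADAM run):

```julia
# Sketch: refine an ADAM-trained parameter vector with LBFGS.
# A common pattern is ADAM first for robustness, then a quasi-Newton
# method for the final convergence.
using Optim: LBFGS

res = DiffEqFlux.sciml_train(loss_rd, p_trained, LBFGS(); maxiters = 100)
```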

I’m pretty sure the nightly failures are just because I’m a tool and not using StableRNGs for the ParticleSwarm tests :slight_smile:

Just saw your presentation. Thanks for the shoutout and good to see that LBFGS seems to work well on the GPU already.