"CuArray only supports bits types" error when performing automatic differentiation in sciml_train (possibly due to array reshaping)

Hello everyone,

I’m trying to convert my code so that it runs on the GPU, but I’ve run into an issue.
My code mainly uses DiffEqFlux, with some array reshaping required inside the neural ODE function itself. When the code gets to the sciml_train part, I always get the same error: `CuArray only supports bits types`. My guess is that once things are converted to a CuArray of ReverseDiff tracked values, the reshaping part breaks (still unconfirmed, and I don’t really know how to confirm it). Below is the stacktrace, and further down a snippet of my code.

ERROR: LoadError: CuArray only supports bits types
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:115
 [3] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:124
 [4] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Int64}) at .\abstractarray.jl:675
 [5] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Base.OneTo{Int64}}) at .\abstractarray.jl:674
 [6] similar(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\broadcast.jl:11
 [7] copy(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}) at .\broadcast.jl:862
 [8] copy(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\derivatives\broadcast.jl:100
 [9] materialize(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at .\broadcast.jl:837
 [10] flux_kernel(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:175
 [11] nn_ode(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:194
 [12] ODEFunction at C:\Users\timot\.julia\packages\SciMLBase\eghDQ\src\scimlfunctions.jl:324 [inlined]
 [13] (::DiffEqSensitivity.var"#78#87"{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:132
 [14] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:207
 [15] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:204
 [16] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}; quad::Bool, noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:131
 [17] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::Function, ::Array{Float64,1}, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}; noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:55
 [18] ODEInterpolatingAdjointSensitivityFunction at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:22 [inlined]
 [19] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; checkpoints::Array{Float64,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:173       
 [20] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float64,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:17
 [21] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:13 (repeats 2 times)
 [22] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:6
 [23] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\concrete_solve.jl:144   
 [24] (::DiffEqBase.var"#275#back#79"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon}})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65     
 [25] predict_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:206 [inlined]
 [26] (::typeof(∂(predict_rd)))(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [27] loss_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:211 [inlined]
 [28] (::typeof(∂(loss_rd)))(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [29] #150 at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
 [30] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(loss_rd)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [31] #74 at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:120 [inlined]
 [32] (::typeof(∂(λ)))(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [33] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172  
 [34] gradient(::Function, ::Zygote.Params) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
 [35] macro expansion at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:119 [inlined]
 [36] macro expansion at C:\Users\timot\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [37] (::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params})() at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:64
 [38] maybe_with_logger(::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:39
 [39] sciml_train(::Function, ::CuArray{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:63
 [40] top-level scope at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318
 [41] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
in expression starting at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318

Snippet of my code:

n_weights = 10

#for the retardation factor
rx_nn = Chain(x -> transpose(x),
                Dense(1, n_weights, tanh),
                Dense(n_weights, 2*n_weights, tanh),
                Dense(2*n_weights, n_weights, tanh),
                Dense(n_weights, 1, σ),
                x -> @view x[1,:]) |> gpu

# initialize the numerical stencil and the exponent for normalization
p_stencil = [-1.1, 1.05]
p_exp = [1.0]

D0 = [0.1]

p1,re1 = Flux.destructure(rx_nn)
p = [p1;D0;p_stencil;p_exp] |> gpu

exp_base = convert(CuArray{Float32,1}, [10.0])

full_restructure(p) = re1(p[1:length(p1)]), p[end-3], p[end-2:end-1], p[end]

function flux_kernel(u,p)
    u_left = @view u[1:Nx-1]
    u_right = @view u[2:Nx]
    stencil_nbr = @view p[end-1:end-1]
    stencil_ctr = @view p[end-2:end-2]
    D0 = @view p[end-3]
    left = stencil_nbr .* u_left .+ stencil_ctr .* u_right
    left_bc_flux = [stencil_nbr .* left_BC .+ stencil_ctr .* u[1:1]][1]
    left_flux = vcat(left_bc_flux, left)
    right = stencil_ctr .* u_left .+ stencil_nbr .* u_right
    u_bc_left = @view u[Nx-1:Nx-1]
    u_bc_right = @view u[Nx:Nx]
    right_BC_cauchy = (u_bc_left .- u_bc_right) .* D0 .* dx
    right_bc_flux = [stencil_ctr .* u[Nx:Nx] .+ stencil_nbr .* right_BC_cauchy][1]
    right_flux = vcat(right, right_bc_flux)
    left_flux + right_flux
end

function state_kernel(u,p,flux)
    rx_nn = re1(p[1:length(p1)])
    ret = rx_nn(u)
    exp_pow = convert(CuArray{Float32,1}, [p[end]])
    ret ./ CUDA.pow.(exp_base,exp_pow) .* flux
end

function nn_ode(u,p,t)
    flux = flux_kernel(u,p)
    state_kernel(u,p,flux)
end

########################
# Solving the neural PDE and setting up the loss function
########################
prob_nn = ODEProblem(nn_ode, c0, (0.0, T), p)
sol_nn = concrete_solve(prob_nn,Tsit5(), c0, p)

function predict_rd(θ)
  # No ReverseDiff if using Flux
  gpu(Array(concrete_solve(prob_nn,Tsit5(),c0,θ,saveat=dt,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))))
end

#match data and force the weights of the CNN to add up to zero
function loss_rd(p)
    pred = predict_rd(p)
    breakthrough_pred = @view pred[end,:]
    rx_nn = re1(p[1:length(p1)])
    u = collect(0:0.01:1) |> gpu
    exp_pow = @view p[end:end]
    D0 = @view p[end-3:end-3]
    ret_temp = rx_nn(u)
    ret = D0 ./ ret_temp .* CUDA.pow.(exp_base, exp_pow)
    loss = convert(Float32,10^3 * sum(abs2, breakthrough_data .- breakthrough_pred) + 10^2 * sum(relu.(ret[2:end] .- ret[1:end-1])) + 10^2 * abs(sum(p[end-2 : end-1])))
    return loss, pred
end

Please let me know if anybody can help and/or needs further information from my side.
Many thanks in advance! :smiley:

Yes, ReverseDiff doesn’t support GPUs. Use ZygoteVJP (or maybe EnzymeVJP).
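
For reference, a minimal sketch of that change against the `predict_rd` from the snippet above — the call is identical, only the `autojacvec` choice is swapped (`prob_nn`, `c0`, and `dt` are the names from the original code):

```julia
function predict_rd(θ)
  # ZygoteVJP is GPU-compatible, unlike ReverseDiffVJP
  gpu(Array(concrete_solve(prob_nn, Tsit5(), c0, θ, saveat = dt,
                           sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))))
end
```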

It works now! Thank you very much for the clarification.
One more question from my side: Does BFGS work with CuArrays? I seem to have it working with ADAM but not with BFGS.

We’re just waiting on @pkofod for that: https://github.com/JuliaNLSolvers/Optim.jl/pull/931 . Looks like the nightly tests failed, but that seems unrelated?

Funny, I’m using Optim.LBFGS and CUDA which works perfectly using large arrays. Maybe that’s acceptable to you as well.
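
In case it helps, this is roughly what that looks like — a sketch only, reusing `loss_rd` and `p` from the snippet above; the `maxiters` value is arbitrary:

```julia
using Optim  # provides LBFGS

# Same training call as with ADAM, just passing the Optim optimizer instead;
# p is the CuArray of parameters from the original snippet.
res = DiffEqFlux.sciml_train(loss_rd, p, LBFGS(), maxiters = 100)
```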


I’m pretty sure the nightly failures are just because I’m a tool and not using StableRNGs for the ParticleSwarm tests :slight_smile:

Just saw your presentation. Thanks for the shoutout and good to see that LBFGS seems to work well on the GPU already.
