Hello everyone,
I’m trying to convert my code so that it would run on GPU, but I still encountered some issues.
My code is mainly using DiffEqFlux, with some array reshaping required inside the Neural ODE function itself. When the the code gets to the sciml_train part, I always get this same error: CuArray only supports bits types, and most likely because when it’s converted to CuArray{ReverseDiff}, my code might mess up due to the reshaping part (still unconfirmed and I don’t really know how to confirm it). Below is the stacktrace log and further down I also attached a snippet of my code.
ERROR: LoadError: CuArray only supports bits types
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:115
[3] CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N(::UndefInitializer, ::Tuple{Int64}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\array.jl:124
[4] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Int64}) at .\abstractarray.jl:675
[5] similar(::Type{CuArray{ReverseDiff.TrackedReal{Float32,Float32,Nothing},N} where N}, ::Tuple{Base.OneTo{Int64}}) at .\abstractarray.jl:674
[6] similar(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}}) at C:\Users\timot\.julia\packages\CUDA\dZvbp\src\broadcast.jl:11
[7] copy(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Tuple{Base.OneTo{Int64}},Base.Broadcast.var"#2#4"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#8#10"{Base.Broadcast.var"#1#3",Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},Base.Broadcast.var"#5#6"{Base.Broadcast.var"#5#6"{Base.Broadcast.var"#7#9"}},Base.Broadcast.var"#11#12"{Base.Broadcast.var"#11#12"{Base.Broadcast.var"#13#14"}},Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#18"}},typeof(*)},typeof(+)},Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32,SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}) at .\broadcast.jl:862
[8] copy(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\derivatives\broadcast.jl:100
[9] materialize(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},Float32}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(*),Tuple{SubArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}},1,ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},Tuple{UnitRange{Int64}},true},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}}}}}) at .\broadcast.jl:837
[10] flux_kernel(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:175
[11] nn_ode(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}) at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:194
[12] ODEFunction at C:\Users\timot\.julia\packages\SciMLBase\eghDQ\src\scimlfunctions.jl:324 [inlined]
[13] (::DiffEqSensitivity.var"#78#87"{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}}, ::ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:132
[14] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,CuArray{Float32,1},CuArray{Float32,1}},ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:207
[15] ReverseDiff.GradientTape(::Function, ::Tuple{CuArray{Float32,1},CuArray{Float32,1},Array{Float64,1}}) at C:\Users\timot\.julia\packages\ReverseDiff\vScHI\src\api\tape.jl:204
[16] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}; quad::Bool, noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\adjoint_common.jl:131
[17] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Nothing, ::Function, ::Array{Float64,1}, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}; noiseterm::Bool) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:55
[18] ODEInterpolatingAdjointSensitivityFunction at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:22 [inlined]
[19] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; checkpoints::Array{Float64,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:173
[20] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#134"{CuArray{Float32,2},CuArray{Float32,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float64,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:17
[21] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Tsit5, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:13 (repeats 2 times)
[22] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float64,Float64},false,CuArray{Float32,1},ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(nn_ode),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{CuArray{Float32,1},1},Array{Float64,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float64}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\sensitivity_interface.jl:6
[23] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\concrete_solve.jl:144
[24] (::DiffEqBase.var"#275#back#79"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},CuArray{Float32,1},CuArray{Float32,1},Tuple{},Colon}})(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65
[25] predict_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:206 [inlined]
[26] (::typeof(∂(predict_rd)))(::CuArray{Float32,2}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[27] loss_rd at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:211 [inlined]
[28] (::typeof(∂(loss_rd)))(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[29] #150 at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
[30] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(loss_rd)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float32,Nothing}) at C:\Users\timot\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[31] #74 at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:120 [inlined]
[32] (::typeof(∂(λ)))(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[33] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172
[34] gradient(::Function, ::Zygote.Params) at C:\Users\timot\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
[35] macro expansion at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:119 [inlined]
[36] macro expansion at C:\Users\timot\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
[37] (::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params})() at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:64
[38] maybe_with_logger(::DiffEqFlux.var"#73#78"{var"#263#266",Int64,Bool,Bool,typeof(loss_rd),CuArray{Float32,1},Zygote.Params}, ::Nothing) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:39
[39] sciml_train(::Function, ::CuArray{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\timot\.julia\packages\DiffEqFlux\8UHw5\src\train.jl:63
[40] top-level scope at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318
[41] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
in expression starting at C:\Users\timot\Documents\Doctoral\Research\SmartANN\Julia Code\Universal ODE\hom_diff_sorp\hom_diff_sorp_sparse_gpu.jl:318
Snippet of my code:
n_weights = 10
#for the retardation factor
rx_nn = Chain(x -> transpose(x),
Dense(1, n_weights, tanh),
Dense(n_weights, 2*n_weights, tanh),
Dense(2*n_weights, n_weights, tanh),
Dense(n_weights, 1, σ),
x -> @view x[1,:]) |> gpu
# initialize the numerical stencil and the exponent for normalization
p_stencil = [-1.1, 1.05]
p_exp = [1.0]
D0 = [0.1]
p1,re1 = Flux.destructure(rx_nn)
p = [p1;D0;p_stencil;p_exp] |> gpu
exp_base = convert(CuArray{Float32,1}, [10.0])
full_restructure(p) = re1(p[1:length(p1)]), p[end-3], p[end-2:end-1], p[end]
function flux_kernel(u,p)
u_left = @view u[1:Nx-1]
u_right = @view u[2:Nx]
stencil_nbr = @view p[end-1:end-1]
stencil_ctr = @view p[end-2:end-2]
D0 = @view p[end-3]
left = stencil_nbr .* u_left .+ stencil_ctr .* u_right
left_bc_flux = [stencil_nbr .* left_BC .+ stencil_ctr .* u[1:1]][1]
left_flux = vcat(left_bc_flux, left)
right = stencil_ctr .* u_left .+ stencil_nbr .* u_right
u_bc_left = @view u[Nx-1:Nx-1]
u_bc_right = @view u[Nx:Nx]
right_BC_cauchy = (u_bc_left .- u_bc_right) .* D0 .* dx
right_bc_flux = [stencil_ctr .* u[Nx:Nx] .+ stencil_nbr .* right_BC_cauchy][1]
right_flux = vcat(right, right_bc_flux)
left_flux + right_flux
end
function state_kernel(u,p,flux)
rx_nn = re1(p[1:length(p1)])
ret = rx_nn(u)
exp_pow = convert(CuArray{Float32,1}, [p[end]])
ret ./ CUDA.pow.(exp_base,exp_pow) .* flux
end
function nn_ode(u,p,t)
flux = flux_kernel(u,p)
state_kernel(u,p,flux)
end
########################
# Soving the neural PDE and setting up loss function
########################
prob_nn = ODEProblem(nn_ode, c0, (0.0, T), p)
sol_nn = concrete_solve(prob_nn,Tsit5(), c0, p)
function predict_rd(θ)
# No ReverseDiff if using Flux
gpu(Array(concrete_solve(prob_nn,Tsit5(),c0,θ,saveat=dt,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))))
end
#match data and force the weights of the CNN to add up to zero
function loss_rd(p)
pred = predict_rd(p)
breakthrough_pred = @view pred[end,:]
rx_nn = re1(p[1:length(p1)])
u = collect(0:0.01:1) |> gpu
exp_pow = @view p[end:end]
D0 = @view p[end-3:end-3]
ret_temp = rx_nn(u)
ret = D0 ./ ret_temp .* CUDA.pow.(exp_base, exp_pow)
loss = convert(Float32,10^3 * sum(abs2, breakthrough_data .- breakthrough_pred) + 10^2 * sum(relu.(ret[2:end] .- ret[1:end-1])) + 10^2 * abs(sum(p[end-2 : end-1])))
return loss, pred
end
Please let me know if anybody can help and/or needs further information from my side.
Many thanks in advance!