# Code using Flux slow on GPU

I am trying to get Julia code running on the GPU for quantile function regression, where the probability distribution is specified by a linear combination of Bernstein polynomials and its parameters are linked to the covariates through a neural network. A pure CPU implementation seems to work fine and is fast, but my GPU version (see below) is extremely slow. Any hints on how to improve the code would be greatly appreciated. My main suspicion is the way the training data is declared, but as a novice Julia programmer I am not sure what the alternatives are. Thanks in advance.

``````julia
using Flux, CuArrays, Statistics

##  quantile levels for training
nprob = 10 |> gpu
prob  = Float32.(1:nprob) / (nprob+1) |> gpu

##  Bernstein design matrix (size(B): nprob, degree+1)
degree = 8
dbin(x, n, p) = binomial(n, x) * p^x * (1-p)^(n-x)
B = [dbin(d, degree, p) for p in prob, d in 0:degree] |> gpu

##  quantile loss function (size(b): degree+1, #cases)
function qtloss(b, y)
    mean( ((y .< (B*b)') .- prob') .* ((B*b)' .- y) )
end

##  some random training data with batch size 100
n = 10_000
p = 50
x = rand(Float32, p, n)  |> gpu
y = rand(Float32, n)     |> gpu
trdata = [(x[:,i], y[i]) for i in Flux.chunk(1:n, n/100)]

##  define network
model = Chain(Dense(p, 32, relu),
              Dense(32, 16, relu),
              Dense(16, degree + 1)) |> gpu

##  training for 1 epoch
loss(x, y) = qtloss(model(x), y)
@time Flux.@epochs 1 Flux.train!(loss, Flux.params(model), trdata, Flux.ADAM())
``````
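For reference, `qtloss` is the pinball (check) loss averaged over the quantile levels τ in `prob` and over the cases. Each term is `((y < q) - τ) * (q - y)` with `q = B*b`, and it is never negative.

The sketch below writes the same loss without the `'` adjoints and computes `B*b` only once. It runs on the CPU and has not been tested on a GPU. `qtloss_orig` and `qtloss_noadj` are hypothetical names introduced here. `B` and `prob` are passed as arguments instead of being read from non-constant globals.

The motivation is an assumption. The stack trace further down shows the backward pass of `B*b` receiving a reshaped `Adjoint` of a `CuArray` and falling back to `_generic_matmatmul!`, which indexes element by element.

```julia
using Statistics, Random

# The question's formulation, with B and prob passed in instead of read from globals.
qtloss_orig(b, y, B, prob) = mean(((y .< (B*b)') .- prob') .* ((B*b)' .- y))

# Same pinball loss, written without adjoints (').
#   B    : nprob × (degree+1) Bernstein design matrix
#   b    : (degree+1) × batch coefficients (network output)
#   y    : length-batch vector of responses
#   prob : length-nprob vector of quantile levels τ
function qtloss_noadj(b, y, B, prob)
    q  = B * b                  # nprob × batch predicted quantiles, computed once
    yr = reshape(y, 1, :)       # 1 × batch row instead of y'
    mean(((yr .< q) .- prob) .* (q .- yr))   # prob broadcasts down the rows
end

# Quick CPU check on random data (B is a random stand-in, not the Bernstein matrix).
Random.seed!(1)
nprob, degree, batch = 10, 8, 100
prob = Float32.(1:nprob) ./ (nprob + 1)
B = rand(Float32, nprob, degree + 1)
b = randn(Float32, degree + 1, batch)
y = rand(Float32, batch)

l_orig  = qtloss_orig(b, y, B, prob)
l_noadj = qtloss_noadj(b, y, B, prob)
println("original: ", l_orig, "  adjoint-free: ", l_noadj)
```

On the GPU this would be called as `qtloss_noadj(model(x), y, B, prob)`, with `B` and `prob` already on the device. They should ideally be declared `const`, since `loss` would still read them as globals. Whether this actually avoids the generic fallback with your CuArrays and Tracker versions is something to confirm by running with `CuArrays.allowscalar(false)`.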

Do you get any warnings about scalar iteration? Did you try running with scalar indexing disabled (`CuArrays.allowscalar(false)`)?

Yes, I get the warning when constructing the `B` matrix. It is only 10x9 in the example and builds quickly, so I did not pursue this further. With `CuArrays.allowscalar(false)` I get the following error:

``````
julia> B = [dbin(d, degree, p) for p in prob, d in 0:degree] |> gpu
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] assertscalar(::String) at /home/johnbb/.julia/packages/GPUArrays/fLiQ1/src/indexing.jl:14
[3] getindex at /home/johnbb/.julia/packages/GPUArrays/fLiQ1/src/indexing.jl:54 [inlined]
[4] iterate at ./abstractarray.jl:914 [inlined]
[5] iterate at ./abstractarray.jl:912 [inlined]
[6] _piterate at ./iterators.jl:810 [inlined]
[7] iterate at ./iterators.jl:818 [inlined]
[8] iterate at ./generator.jl:44 [inlined]
[9] collect(::Base.Generator{Base.Iterators.ProductIterator{Tuple{CuArray{Float32,1},UnitRange{Int64}}},getfield(Main, Symbol("##3#4"))}) at ./array.jl:606
[10] top-level scope at REPL[7]:1
``````

I should add that I get some errors when testing CuArrays (related to CUFFT) and Flux (related to RNNs).

``````
(v1.2) pkg> test CuArrays
Testing CuArrays
Resolving package versions...
Status `/tmp/jl_86KTwF/Manifest.toml`
[621f4979] AbstractFFTs v0.4.1
[9e28174c] BinDeps v0.8.10
[b99e7846] BinaryProvider v0.5.7
[fa961155] CEnum v0.2.0
[00ebfdb7] CSTParser v1.0.0
[3895d2a7] CUDAapi v1.2.0
[be33ccc6] CUDAnative v2.4.0
[bbf7d656] CommonSubexpressions v0.2.0
[34da2185] Compat v2.2.0
[8f4d0f93] Conda v1.3.0
[a8cc5b0e] Crayons v4.0.0
[3a865a2d] CuArrays v1.2.1
[864edb3b] DataStructures v0.17.5
[163ba53b] DiffResults v0.0.4
[b552c78f] DiffRules v0.0.10
[7a1cc6ca] FFTW v1.0.1
[1a297f60] FillArrays v0.7.4
[f6369f11] ForwardDiff v0.10.5
[0c68f7d7] GPUArrays v1.0.4
[682c06a0] JSON v0.21.0
[929cbde3] LLVM v1.3.1
[1914dd2f] MacroTools v0.5.1
[872c559c] NNlib v0.6.0
[77ba4419] NaNMath v0.3.2
[bac558e1] OrderedCollections v1.1.0
[69de0a69] Parsers v0.3.7
[189a3867] Reexport v0.2.0
[ae029012] Requires v0.5.2
[276daf66] SpecialFunctions v0.8.0
[90137ffa] StaticArrays v0.12.0
[a759f4b9] TimerOutputs v0.5.0
[0796e94c] Tokenize v0.5.6
[30578b45] URIParser v0.4.0
[81def892] VersionParsing v1.1.3
[2a0f44e3] Base64  [`@stdlib/Base64`]
[8bb1440f] DelimitedFiles  [`@stdlib/DelimitedFiles`]
[8ba89e20] Distributed  [`@stdlib/Distributed`]
[b77e0a4c] InteractiveUtils  [`@stdlib/InteractiveUtils`]
[76f85450] LibGit2  [`@stdlib/LibGit2`]
[8f399da3] Libdl  [`@stdlib/Libdl`]
[37e2e46d] LinearAlgebra  [`@stdlib/LinearAlgebra`]
[56ddb016] Logging  [`@stdlib/Logging`]
[d6f4376e] Markdown  [`@stdlib/Markdown`]
[44cfe95a] Pkg  [`@stdlib/Pkg`]
[de0858da] Printf  [`@stdlib/Printf`]
[3fa0cd96] REPL  [`@stdlib/REPL`]
[9a3f8284] Random  [`@stdlib/Random`]
[ea8e919c] SHA  [`@stdlib/SHA`]
[9e88b42a] Serialization  [`@stdlib/Serialization`]
[1a1011a3] SharedArrays  [`@stdlib/SharedArrays`]
[6462fe0b] Sockets  [`@stdlib/Sockets`]
[2f01184e] SparseArrays  [`@stdlib/SparseArrays`]
[10745b16] Statistics  [`@stdlib/Statistics`]
[8dfed614] Test  [`@stdlib/Test`]
[cf7118a7] UUIDs  [`@stdlib/UUIDs`]
[4ec0a83e] Unicode  [`@stdlib/Unicode`]
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:16
Batch 2D (in 4D): Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:62
Expression: isapprox(Y, fftw_X, rtol=MYRTOL, atol=MYATOL)
Evaluated: isapprox(Complex{Float32}[30.93433f0 + 27.453215f0im -1.6714139f0 - 3.6637213f0im … 1.452174f0 + 0.04918611f0im 0.7687962f0 - 0.57745886f0im; 0.29303554f0 + 0.9587163f0im 1.5358994f0 + 3.0074072f0im … -1.807018f0 - 0.3964303f0im 4.198574f0 - 2.245281f0im; … ; 1.3176649f0 + 3.5774553f0im -0.25220183f0 - 0.41377443f0im … 1.0596181f0 - 0.36518365f0im -3.8873076f0 - 0.4233678f0im; 3.9277287f0 - 1.6763877f0im 2.3764422f0 - 2.5789008f0im … 1.1053141f0 + 1.7568265f0im 1.7967496f0 + 1.9728156f0im]
................
Stacktrace:
[1] (::getfield(Main, Symbol("#batched#177")){Float64,Float64})(::Array{Complex{Float32},4}, ::Tuple{Int64,Int64}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:62
[2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:135
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:132
[5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:70
[7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
2D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
Evaluated: isapprox(Float32[16.026249 0.20003724 … -0.14261723 3.4577808; 14.708965 0.24957383 … 2.4554594 -0.011748552; … ; 15.179053 1.21625 … 0.6149055 1.0816126; 17.156378 -2.9685822 … -0.51396996 -2.361169], Float32[0.5479944 0.69968283 … 0.35572374 0.008343816; 0.8192854 0.568097 … 0.26755 0.37813723; … ; 0.8134997 0.8640851 … 0.7806642 0.43586862; 0.0038974285 0.7791145 … 0.78436923 0.010035396]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
[1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float32,2}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
[2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:210
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:209
[5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
[7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
3D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
Evaluated: isapprox(Float32[1027.6874 -14.55261 … 12.393054 -7.6938467; 1024.6332 8.651075 … 7.487852 -13.836114; … ; 1018.8914 25.11908 … -4.143159 5.174135; 1002.0034 -8.577894 … 1.4080315 -12.264083]

Float32[3.8303516 23.954319 … 15.5727825 10.477801; -4.425709 -13.273114 … 13.321645 2.6528811; … ; -2.1250527 4.9475956 … -16.986565 -5.5888686; 1.3448858 -2.2884483 … -2.8479586 19.10226]
...........................
Float32[0.73248553 0.47862995 … 0.08591521 0.4444232; 0.6592467 0.6849257 … 0.47312415 0.40733826; … ; 0.9162493 0.82751226 … 0.085659266 0.8531437; 0.38045 0.3253126 … 0.85089886 0.03872478]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
[1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float32,3}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
[2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:238
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:237
[5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
[7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
2D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
Evaluated: isapprox([13.110226890046071 1.3772673622431166 … -1.534417373612804 -2.5842606057456314; 14.498238941932932 1.357461983938537 … 0.2196697162595076 0.6284427764869798; … ; 16.414946940034493 0.6907537499873704 … 3.2728619718604675 -1.3926039426756092; 16.924688849739088 -1.085251641446253 … 1.2158256481955552 0.9847502665068547], [0.23603334566204692 0.25166218303197185 … 0.9363392981141667 0.0992834999475467; 0.34651701419196046 0.9866663668987996 … 0.26826295612968565 0.3794133659941419; … ; 0.951916339835734 0.2811902322857298 … 0.8025337387522802 0.16828405178779438; 0.9999046588986136 0.20947237319807077 … 0.6998008592631042 0.7892684010289475]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
[1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float64,2}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
[2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:210
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:209
[5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
[7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
3D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
Evaluated: isapprox([1049.9017421270862 16.472043665886325 … -18.26215551035024 -12.210517732053521; 1025.6632398346655 -6.661014983218484 … -5.195857775569898 4.546808897557256; … ; 1036.0445034495613 4.915146189883673 … 21.033289503207456 12.505645347032504; 1023.607860961061 24.31438337979879 … -0.6654002225101903 4.115047865085936]
.....................
[0.6778807836814413 0.16599513698823087 … 0.7853702692584121 0.9362500948944805; 0.6699133180023191 0.49007506866428474 … 0.6248287567699398 0.4948073880973143; … ; 0.4894338006908039 0.9424596895465458 … 0.7183333513853476 0.760581981651117; 0.44988663531118345 0.8486461546211002 … 0.0903150721160646 0.1578958025629098]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
[1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float64,3}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
[2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:238
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:237
[5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
[7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
[ Info: Testing CUDNN 7.6.4
[ Info: Testing ForwardDiff integration
Total GPU memory usage: 20.44% (1.593 GiB/7.794 GiB)
CuArrays.jl pool usage: 1.94% (16.008 KiB in use by 2 buffers, 31.611 MiB idle)
────────────────────────────────────────────────
Time
──────────────────────
Tot / % measured:           319s / 0.10%

Section           ncalls     time   %tot     avg
────────────────────────────────────────────────
pooled alloc       9.22k    181ms  54.4%  19.6μs
  1 try alloc      1.88k    177ms  53.3%  94.2μs
background task       10    152ms  45.6%  15.2ms
  reclaim             10   8.22ms  2.47%   822μs
  scan                10   12.6μs  0.00%  1.26μs
────────────────────────────────────────────────
Test Summary:            | Pass  Fail  Total
CuArrays                 | 4451     5   4456
  GPUArrays test suite   | 1104         1104
  Memory                 |    5            5
  Array                  |   22           22
  Cufunc                 |    8            8
  Reduce                 |    6            6
  0D                     |    2            2
  SubArray               |   20           20
  reshape                |    1            1
  triu! with diagonal -2 |    1            1
  triu! with diagonal -1 |    1            1
  triu! with diagonal 0  |    1            1
  triu! with diagonal 1  |    1            1
  triu! with diagonal 2  |    1            1
  tril! with diagonal -2 |    1            1
  tril! with diagonal -1 |    1            1
  tril! with diagonal 0  |    1            1
  tril! with diagonal 1  |    1            1
  tril! with diagonal 2  |    1            1
  Utilities              |    2            2
  accumulate             |    8            8
  logical indexing       |   15           15
  generic fallbacks      |   11           11
  reverse                |    6            6
  permutedims            |    2            2
  CUBLAS                 | 1184         1184
  CURAND                 |   99           99
  CUFFT                  |  145     5    150
    T = Complex{Float32} |   36     1     37
      1D                 |    3            3
      1D inplace         |    2            2
      2D                 |    3            3
      2D inplace         |    2            2
      Batch 1D           |    6            6
      3D                 |    3            3
      3D inplace         |    2            2
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    8     1      9
    T = Complex{Float64} |   37           37
    T = Float32          |   32     2     34
      1D                 |    4            4
      Batch 1D           |    6            6
      2D                 |    3     1      4
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    9            9
      3D                 |    3     1      4
    T = Float64          |   32     2     34
      1D                 |    4            4
      Batch 1D           |    6            6
      2D                 |    3     1      4
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    9            9
      3D                 |    3     1      4
    T = Complex{Int32}   |    2            2
    T = Complex{Int64}   |    2            2
    T = Int32            |    2            2
    T = Int64            |    2            2
    streams              |             No tests
  CUSPARSE               | 1229         1229
  CUSOLVER               |  297          297
  CUSPARSE + CUSOLVER    |   84           84
  CUDNN                  |   70           70
  ForwardDiff            |   96           96
ERROR: LoadError: Some tests did not pass: 4451 passed, 5 failed, 0 errored, 0 broken.
in expression starting at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/runtests.jl:18
ERROR: Package CuArrays errored during testing

(v1.2) pkg> test Flux
Testing Flux
Resolving package versions...
Status `/tmp/jl_qzeiZA/Manifest.toml`
[621f4979] AbstractFFTs v0.4.1
[1520ce14] AbstractTrees v0.2.1
[9e28174c] BinDeps v0.8.10
[b99e7846] BinaryProvider v0.5.7
[fa961155] CEnum v0.2.0
[00ebfdb7] CSTParser v1.0.0
[3895d2a7] CUDAapi v1.2.0
[be33ccc6] CUDAnative v2.4.0
[944b1d66] CodecZlib v0.6.0
[3da002f7] ColorTypes v0.8.0
[5ae59095] Colors v0.9.6
[bbf7d656] CommonSubexpressions v0.2.0
[34da2185] Compat v2.2.0
[8f4d0f93] Conda v1.3.0
[a8cc5b0e] Crayons v4.0.0
[3a865a2d] CuArrays v1.2.1
[9a962f9c] DataAPI v1.1.0
[864edb3b] DataStructures v0.17.5
[163ba53b] DiffResults v0.0.4
[b552c78f] DiffRules v0.0.10
[7a1cc6ca] FFTW v1.0.1
[1a297f60] FillArrays v0.7.4
[53c48c17] FixedPointNumbers v0.6.1
[587475ba] Flux v0.9.0
[f6369f11] ForwardDiff v0.10.5
[0c68f7d7] GPUArrays v1.0.4
[682c06a0] JSON v0.21.0
[e5e0dc1b] Juno v0.7.2
[929cbde3] LLVM v1.3.1
[1914dd2f] MacroTools v0.5.1
[e89f7d12] Media v0.5.0
[e1d29d7a] Missings v0.4.3
[872c559c] NNlib v0.6.0
[77ba4419] NaNMath v0.3.2
[bac558e1] OrderedCollections v1.1.0
[69de0a69] Parsers v0.3.7
[189a3867] Reexport v0.2.0
[ae029012] Requires v0.5.2
[a2af1166] SortingAlgorithms v0.3.1
[276daf66] SpecialFunctions v0.8.0
[90137ffa] StaticArrays v0.12.0
[2913bbd2] StatsBase v0.32.0
[a759f4b9] TimerOutputs v0.5.0
[0796e94c] Tokenize v0.5.6
[3bb67fe8] TranscodingStreams v0.9.5
[30578b45] URIParser v0.4.0
[81def892] VersionParsing v1.1.3
[a5390f91] ZipFile v0.8.3
[2a0f44e3] Base64  [`@stdlib/Base64`]
[8bb1440f] DelimitedFiles  [`@stdlib/DelimitedFiles`]
[8ba89e20] Distributed  [`@stdlib/Distributed`]
[b77e0a4c] InteractiveUtils  [`@stdlib/InteractiveUtils`]
[76f85450] LibGit2  [`@stdlib/LibGit2`]
[8f399da3] Libdl  [`@stdlib/Libdl`]
[37e2e46d] LinearAlgebra  [`@stdlib/LinearAlgebra`]
[56ddb016] Logging  [`@stdlib/Logging`]
[d6f4376e] Markdown  [`@stdlib/Markdown`]
[44cfe95a] Pkg  [`@stdlib/Pkg`]
[de0858da] Printf  [`@stdlib/Printf`]
[9abbd945] Profile  [`@stdlib/Profile`]
[3fa0cd96] REPL  [`@stdlib/REPL`]
[9a3f8284] Random  [`@stdlib/Random`]
[ea8e919c] SHA  [`@stdlib/SHA`]
[9e88b42a] Serialization  [`@stdlib/Serialization`]
[1a1011a3] SharedArrays  [`@stdlib/SharedArrays`]
[6462fe0b] Sockets  [`@stdlib/Sockets`]
[2f01184e] SparseArrays  [`@stdlib/SparseArrays`]
[10745b16] Statistics  [`@stdlib/Statistics`]
[8dfed614] Test  [`@stdlib/Test`]
[cf7118a7] UUIDs  [`@stdlib/UUIDs`]
[4ec0a83e] Unicode  [`@stdlib/Unicode`]
[ Info: Testing Basics
[ Info: Testing Layers
[ Info: Testing GPU Support
[ Info: Testing Flux/CUDNN
batch_size = 1: Error During Test at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
Got exception outside of a @test
Stacktrace:
[1] macro expansion at /home/johnbb/.julia/packages/CuArrays/wXQp8/src/dnn/error.jl:19 [inlined]
[2] cudnnRNNBackwardData(::Flux.CUDA.RNNDesc{Float32}, ::Int64, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::CuArrays.CUDNN.FilterDesc, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::CuArray{UInt8,1}, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:174
[3] backwardData(::Flux.CUDA.RNNDesc{Float32}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::Nothing, ::CuArray{Float32,1}, ::Nothing, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:191
[4] backwardData(::Flux.CUDA.RNNDesc{Float32}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:199
[5] (::getfield(Flux.CUDA, Symbol("##8#9")){Flux.GRUCell{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}},CuArray{UInt8,1},Tuple{CuArray{Float32,1},CuArray{Float32,1}}})(::Tuple{CuArray{Float32,1},CuArray{Float32,1}}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:310
[6] back_(::Tracker.Call{getfield(Flux.CUDA, Symbol("##8#9")){Flux.GRUCell{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}},CuArray{UInt8,1},Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Tuple{Tracker.Tracked{CuArray{Float32,1}},Tracker.Tracked{CuArray{Float32,1}},Tracker.Tracked{CuArray{Float32,2}},Tracker.Tracked{CuArray{Float32,2}},Tracker.Tracked{CuArray{Float32,1}}}}, ::Tuple{CuArray{Float32,1},CuArray{Float32,1}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:35
[7] back(::Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}}, ::Tuple{CuArray{Float32,1},Int64}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[8] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}}, ::Tuple{CuArray{Float32,1},Int64}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
[9] foreach(::Function, ::Tuple{Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Nothing}, ::Tuple{Tuple{CuArray{Float32,1},Int64},Nothing}) at ./abstractarray.jl:1921
[10] back_(::Tracker.Call{getfield(Tracker, Symbol("##361#363")){Tracker.TrackedTuple{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Int64},Tuple{Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Nothing}}, ::CuArray{Float32,1}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
[11] back(::Tracker.Tracked{CuArray{Float32,1}}, ::CuArray{Float32,1}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[12] back!(::TrackedArray{…,CuArray{Float32,1}}, ::CuArray{Float32,1}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:77
[13] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:23
[14] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[15] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
[16] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[17] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[18] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[19] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[20] include at ./boot.jl:328 [inlined]
[22] include(::Module, ::String) at ./Base.jl:31
[23] include(::String) at ./client.jl:431
[24] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/cuda.jl:55
[25] include at ./boot.jl:328 [inlined]
[27] include(::Module, ::String) at ./Base.jl:31
[28] include(::String) at ./client.jl:431
[29] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:30
[30] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[31] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:11
[32] include at ./boot.jl:328 [inlined]
[34] include(::Module, ::String) at ./Base.jl:31
[35] include(::String) at ./client.jl:431
[36] top-level scope at none:5
[37] eval(::Module, ::Any) at ./boot.jl:330
[38] exec_options(::Base.JLOptions) at ./client.jl:271
[39] _start() at ./client.jl:464

batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:26
Evaluated: Float32[0.026422147 0.031162314 … 0.04015291 0.048164792; -0.00580488 -0.006842452 … -0.0006331135 -0.0051276702; … ; -0.48899454 -0.3973075 … -0.5430576 -0.59323865; -1.5976363 -1.7279379 … -1.2472157 -2.0624979] ≈ Float32[0.0163969 0.020081667 … 0.035748813 0.03631903; -0.0038312192 -0.004661015 … 0.00023391606 -0.002795606; … ; -0.6556128 -0.5814663 … -0.616253 -0.79011357; -1.3221657 -1.4234674 … -1.1262015 -1.7370036]
Stacktrace:
[1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:26
[2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
[4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:27
Evaluated: Float32[-0.015565421 0.0032692542 … -0.02177089 -0.0079453755; 0.0033148595 -0.00096146535 … 0.002088631 -0.00068250747; … ; 0.060095496 -0.037989482 … 0.0877803 0.05621599; 0.13518983 -0.049921982 … 0.1269001 0.046479825] ≈ Float32[-0.0102084465 0.0006350695 … -0.017785752 -0.0077674175; 0.0022602363 -0.00044287564 … 0.0013040807 -0.00071754144; … ; 0.08005668 -0.047804993 … 0.10262971 0.056879077; 0.12019784 -0.042549968 … 0.11574733 0.0459818]
Stacktrace:
[1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:27
[2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
[4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:28
Evaluated: Float32[0.043129362, -0.0052406434, 0.0032806913, 0.03839735, -0.027703835, -0.1904509, 0.3789641, 0.11353702, -1.1046532, -0.10988483, -0.3808116, -0.806893, -0.59165734, -0.8405457, -2.1673152] ≈ Float32[0.030957595, -0.002844398, 0.007849413, 0.050801173, -0.025155623, -0.11647805, 0.19690844, 0.15651214, -1.3034244, -0.10899977, -0.033872187, -0.53707606, -0.8408148, -1.0428388, -1.8328632]
Stacktrace:
[1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:28
[2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
[4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:29
Evaluated: Float32[-0.012821687, -0.7067668, -0.35861096, -1.4159613, -0.54712236] ≈ Float32[0.066714235, -0.12755066, -0.45448342, -1.7976848, -0.41404387]
Stacktrace:
[1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:29
[2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
[4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
[5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
[6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
[7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
Test Summary:                         | Pass  Fail  Error  Total
Flux                                  |  274     4      1    279
  Throttle                            |   11                  11
  Jacobian                            |    1                   1
  Initialization                      |   12                  12
  Params                              |    4                   4
  Basic Stacking                      |    1                   1
  Precision                           |    6                   6
  Stacking                            |    3                   3
  onecold                             |    4                   4
  onehotbatch indexing                |    2                   2
  Optimise                            |   11                  11
  Optimiser                           |    3                   3
  Training Loop                       |    2                   2
  ExpDecay                            |    3                   3
  basic                               |   27                  27
  Dropout                             |   10                  10
  BatchNorm                           |   14                  14
  InstanceNorm                        |   16                  16
  GroupNorm                           |   16                  16
  losses                              |   30                  30
  Pooling                             |    2                   2
  CNN                                 |    1                   1
  Depthwise Conv                      |    3                   3
  ConvTranspose                       |    1                   1
  CrossCor                            |    4                   4
  Conv with non quadratic window #700 |    4                   4
  Tracker                             |    4                   4
  CuArrays                            |    8                   8
  onecold gpu                         |    2                   2
  CUDNN BatchNorm                     |   10                  10
  RNN                                 |   40     4      1     45
    R = RNN                           |   16                  16
    R = GRU                           |    6     4      1     11
      batch_size = 1                  |    2            1      3
      batch_size = 5                  |    4     4             8
    R = LSTM                          |   18                  18
ERROR: LoadError: Some tests did not pass: 274 passed, 4 failed, 1 errored, 0 broken.
in expression starting at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:9
ERROR: Package Flux errored during testing

``````

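The warning is only printed once, so the one triggered while building `B` hides any later scalar indexing in the training loop. Try calling `CuArrays.allowscalar(false)` only right before your main loop. Alternatively, fix the initialization of `B` by keeping a copy of `prob` on the CPU.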
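Here is a minimal sketch of the second suggestion, using the question's `nprob`, `degree` and `dbin`. It builds `prob` and `B` as ordinary CPU arrays and only then moves them to the GPU. `prob_cpu` and `B_cpu` are illustrative names. The GPU lines are commented out so the snippet also runs without a GPU; they assume Flux and CuArrays are loaded.

```julia
# Build the Bernstein design matrix on the CPU, where scalar indexing is cheap.
nprob  = 10
degree = 8
prob_cpu = Float32.(1:nprob) ./ (nprob + 1)             # Vector{Float32}, stays on the CPU
dbin(x, n, p) = binomial(n, x) * p^x * (1 - p)^(n - x)  # Bernstein basis polynomial
B_cpu = [dbin(d, degree, p) for p in prob_cpu, d in 0:degree]  # nprob × (degree+1)

# With Flux and CuArrays loaded, transfer only the finished arrays:
# using Flux, CuArrays
# prob = prob_cpu |> gpu
# B    = B_cpu    |> gpu
# CuArrays.allowscalar(false)   # from here on, scalar GPU indexing raises an error
```

Each row of `B_cpu` should sum to 1, since the Bernstein basis polynomials form a partition of unity, which makes a quick sanity check. Once `allowscalar(false)` is set, any remaining element-by-element access to a GPU array raises an error instead of silently running slowly.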

Those test failures are unrelated to your problem and are fixed on the latest master branches (not yet released).

Not sure I fully understand, but setting it just before `Flux.train!` and leaving everything else as before resulted in the following error:

``````
julia> CuArrays.allowscalar(false)

julia> @time Flux.@epochs 1 Flux.train!(loss, Flux.params(model), trdata, Flux.ADAM())
[ Info: Epoch 1
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] assertscalar(::String) at /home/johnbb/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:14
[3] getindex at /home/johnbb/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:54 [inlined]
[4] _getindex at ./abstractarray.jl:1004 [inlined]
[5] getindex at ./abstractarray.jl:981 [inlined]
[7] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
[8] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
[9] getindex at ./reshapedarray.jl:231 [inlined]
[10] _generic_matmatmul!(::CuArray{Float32,2}, ::Char, ::Char, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::CuArray{Float32,2}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:614
[11] generic_matmatmul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:584 [inlined]
[12] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:255 [inlined]
[13] * at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:143 [inlined]
[14] _forward at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:415 [inlined]
[15] #track#1 at /home/johnbb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
[16] track at /home/johnbb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
[17] * at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:379 [inlined]
[18] #509 at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:416 [inlined]
[19] back_(::Tracker.Call{getfield(Tracker, Symbol("##509#510")){CuArray{Float32,2},TrackedArray{…,CuArray{Float32,2}}},Tuple{Nothing,Tracker.Tracked{CuArray{Float32,2}}}}, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:35
[20] back(::Tracker.Tracked{CuArray{Float32,2}}, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[21] foreach at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[22] back_(::Tracker.Call{getfield(Tracker, Symbol("##390#391")){TrackedArray{…,CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
[23] back(::Tracker.Tracked{LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[24] #13 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[25] foreach at ./abstractarray.jl:1921 [inlined]
[27] back(::Tracker.Tracked{CuArray{Float32,2}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[28] #13 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[29] foreach at ./abstractarray.jl:1921 [inlined]
[30] back_(::Tracker.Call{getfield(Tracker, Symbol("##497#498")){Colon,TrackedArray{…,CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::Float32, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
[31] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
[32] #back!#15 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:77 [inlined]
[33] #back! at ./none:0 [inlined]
[34] #back!#32 at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/real.jl:16 [inlined]
[35] back!(::Tracker.TrackedReal{Float32}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/real.jl:14
[36] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{CuArray{Float32,2},CuArray{Float32,1}}}, ::Tracker.Params) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:4
[39] macro expansion at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
[40] macro expansion at /home/johnbb/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
[41] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,1}},1}, ::ADAM) at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
[42] train!(::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,1}},1}, ::ADAM) at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:67
[43] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:106
[44] top-level scope at /home/johnbb/.julia/packages/Juno/oLB1d/src/progress.jl:134
[45] top-level scope at util.jl:156
julia> B
10×9 CuArray{Float32,2}:
0.466507    0.373206     0.130622     0.0261244    0.00326555  …  1.30622e-5   3.73206e-7   4.66507e-9
0.200816    0.357006     0.277672     0.12341      0.0342805      0.000677145  4.29933e-5   1.19426e-6
0.078267    0.234801     0.308176     0.231132     0.108343       0.0060943    0.000652961  3.06076e-5
0.0268932   0.12294      0.245881     0.281007     0.200719       0.0262164    0.00428022   0.00030573
0.00783553  0.0522369    0.152358     0.253929     0.26451        0.0734749    0.017494     0.00182229
0.00182229  0.017494     0.0734749    0.17634      0.26451     …  0.152358     0.0522369    0.00783553
0.00030573  0.00428022   0.0262164    0.0917573    0.200719       0.245881     0.12294      0.0268932
3.06075e-5  0.000652961  0.0060943    0.0325029    0.108343       0.308176     0.234801     0.078267
1.19426e-6  4.29933e-5   0.000677145  0.0060943    0.0342805      0.277672     0.357006     0.200816
4.66506e-9  3.73205e-7   1.30622e-5   0.000261244  0.00326555     0.130622     0.373206     0.466507
julia>

``````

You understood correctly! This reveals a scalar operation inside your main loop, caused by a call to `mul!(::CuArray, ::ReshapedArray{Adjoint{CuArray}})`. This is an unfortunate case where we don’t know that we should dispatch to a GPU-optimized generic matmul, so we fall back to the Base implementation, which reads the operands element by element. Our type system falls a little short here.
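To make the dispatch problem concrete, here is a toy CPU sketch. `matmul_kind` is a hypothetical function, not part of Base or CuArrays. One method accepts only plain `Matrix` arguments and stands in for a specialized kernel. The other accepts any `AbstractMatrix` and stands in for Base's generic fallback. Wrapping a dense array in `Adjoint` or `ReshapedArray` is enough to miss the specialized method.

``````julia
using LinearAlgebra

# Hypothetical stand-in for a specialized kernel: matches only plain dense matrices.
matmul_kind(A::Matrix, B::Matrix) = :specialized
# Generic fallback for any AbstractMatrix, analogous to Base's
# _generic_matmatmul!, which reads the operands element by element.
matmul_kind(A::AbstractMatrix, B::AbstractMatrix) = :generic

X = rand(Float32, 4, 6)
W = rand(Float32, 6, 3)

matmul_kind(X, W)                                       # :specialized
matmul_kind(W', X')                                     # :generic (Adjoint wrappers)
matmul_kind(reshape(X', 2, 12), rand(Float32, 12, 3))   # :generic (ReshapedArray of an Adjoint)
``````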

A workaround would be to figure out where the `ReshapedArray` comes from: calling `reshape` on a `CuArray` should give a `CuArray` again, which avoids this situation, since `mul!(::CuArray, ::Adjoint{CuArray})` does dispatch to the correct implementation. cc @MikeInnes.
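As for where a `ReshapedArray` of an `Adjoint` can come from, the sketch below uses plain CPU `Array`s as a stand-in for `CuArray`s. This is an assumption: it shows only the Base behavior, not CuArrays internals. Reshaping a lazy adjoint produces a `Base.ReshapedArray` wrapper. Materializing the adjoint first with `copy` gives back a dense array, so later operations dispatch on the dense type.

``````julia
using LinearAlgebra

A  = rand(Float32, 3, 4)
At = A'                             # lazy Adjoint wrapper, not a Matrix

# Reshaping the lazy adjoint cannot reuse a dense buffer,
# so Base wraps it in a ReshapedArray ...
R_lazy = reshape(At, 2, 6)

# ... whereas materializing the adjoint first yields a plain Matrix again.
R_dense = reshape(copy(At), 2, 6)

R_lazy isa Base.ReshapedArray       # true
R_dense isa Matrix{Float32}         # true
R_lazy == R_dense                   # true: same values, different types
``````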

It could be in the `qtloss` function. For example, `y` is a length-100 vector (it broadcasts like a 100×1 column), while `(B*b)'` is of size (100, nprob). I'm not sure what happens in the background when `y .< (B*b)'` is evaluated. On the other hand, `qtloss` itself seems to be very fast.
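If the adjoint in `qtloss` is the culprit, one option is to avoid `'` altogether. The sketch below is a hypothetical rewrite and has not been tested on the GPU. `qtloss_noadj` is not from the original code, and `B` and `prob` are passed explicitly instead of read as globals. It keeps the fitted quantiles as an `nprob × ncases` matrix and turns `y` into a `1 × ncases` row with `reshape`, which on a `CuArray` should return a `CuArray`. The last line compares it with the original formulation on random CPU data of the same shapes.

``````julia
using Statistics

# Original formulation from the question, with B and prob passed explicitly.
qtloss_orig(B, prob, b, y) = mean(((y .< (B*b)') .- prob') .* ((B*b)' .- y))

# Hypothetical rewrite without adjoints.
function qtloss_noadj(B, prob, b, y)
    q  = B * b                  # fitted quantiles, size: nprob × ncases
    yr = reshape(y, 1, :)       # observations as a 1 × ncases row
    mean(((yr .< q) .- prob) .* (q .- yr))
end

# CPU check with the shapes used in the question.
nprob, degree, ncases = 10, 8, 100
prob = Float32.(1:nprob) ./ (nprob + 1)
dbin(x, n, p) = binomial(n, x) * p^x * (1 - p)^(n - x)
B = [dbin(d, degree, p) for p in prob, d in 0:degree]
b = randn(Float32, degree + 1, ncases)
y = rand(Float32, ncases)

qtloss_orig(B, prob, b, y) ≈ qtloss_noadj(B, prob, b, y)
``````

Both versions sum the same `nprob × ncases` terms, only transposed, so the loss value is unchanged up to floating-point rounding. In the training script it would be used as `loss(x, y) = qtloss_noadj(B, prob, model(x), y)`.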

Are there any thoughts on a general solution to this class of issues? I’ve seen this crop up before.

There have been a couple of PRs, but nothing GPU-specific (https://github.com/JuliaLang/julia/pull/31563, https://github.com/JuliaLang/julia/pull/25558). Here’s something similar to Adapt.jl: https://github.com/JuliaGPU/GPUArrays.jl/issues/147#issuecomment-417255267. I know @keno had some thoughts about this too, but I don’t think he’s had the time to do anything with them.
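For readers wondering what a general solution might look like, here is a toy sketch of one idea: dispatch on where the data is stored rather than on the wrapper type, peeling off lazy wrappers through `parent`. `storage_type` and `kernel_for` are hypothetical names. This is not the approach of the linked PRs or of Adapt.jl, and real code would also have to check that the wrapper's memory layout is something the chosen kernel can actually handle.

``````julia
using LinearAlgebra

# Hypothetical storage trait: recurse through lazy wrappers with `parent`
# until a concrete storage array is reached.
storage_type(::Array) = Array
storage_type(A::Union{Adjoint,Transpose,SubArray,Base.ReshapedArray}) = storage_type(parent(A))
storage_type(A::AbstractArray) = typeof(A)   # anything else: report its own type

# Pick a code path from the storage of both operands, not from their wrappers.
kernel_for(A, B) = _kernel_for(storage_type(A), storage_type(B))
_kernel_for(::Type{Array}, ::Type{Array}) = :dense_kernel
_kernel_for(::Type, ::Type) = :generic_fallback

X = rand(Float32, 4, 6)
kernel_for(reshape(X', 2, 12), rand(Float32, 12, 3))   # :dense_kernel
kernel_for(1:3, X)                                      # :generic_fallback
``````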