Code using Flux slow on GPU

I am trying to write Julia code that runs on the GPU for quantile function regression, where the probability distribution is specified by a linear combination of Bernstein polynomials and its parameters are linked to the covariates through a neural network. A pure CPU implementation seems to work fine and is fast, but my GPU version (see below) is extremely slow. Any hints on how to improve the code would be greatly appreciated. My main suspicion is the way I declare the training data, but as a novice Julia programmer I am at the moment not sure of the alternatives. Thanks in advance.

using Flux, CuArrays, Statistics

##  quantile levels for training
nprob = 10 |> gpu
prob  = Float32.(1:nprob) / (nprob+1) |> gpu

##  Bernstein design matrix (size(B): nprob, degree+1)
degree = 8
dbin(x, n, p) = binomial(n, x) * p^x * (1-p)^(n-x)
B = [dbin(d, degree, p) for p in prob, d in 0:degree] |> gpu  

##  quantile loss function (size(b): degree+1, #cases)
function qtloss(b, y)
    mean( ((y .< (B*b)') .- prob') .* ((B*b)' .- y) ) 
end

##  some random training data with batch size 100
n = 10_000
p = 50
x = rand(Float32, p, n)  |> gpu
y = rand(Float32, n)     |> gpu
trdata = [(x[:,i], y[i]) for i in Flux.chunk(1:n, n/100)]   

##  define network
model = Chain(Dense(p, 32, relu),
              Dense(32, 16, relu),
              Dense(16, degree + 1)) |> gpu

##  training for 1 epoch
loss(x, y) = qtloss(model(x), y)
@time Flux.@epochs 1 Flux.train!(loss, Flux.params(model), trdata, Flux.ADAM())

Do you get any warnings about scalar iteration? Did you try running with it disabled (CuArrays.allowscalar(false))?

Yes, I get the warning when constructing the B matrix, but since it is only 10x9 in the example and fast to build, I did not pursue this further. With
CuArrays.allowscalar(false) I get the following error:

julia> B = [dbin(d, degree, p) for p in prob, d in 0:degree] |> gpu
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/johnbb/.julia/packages/GPUArrays/fLiQ1/src/indexing.jl:14
 [3] getindex at /home/johnbb/.julia/packages/GPUArrays/fLiQ1/src/indexing.jl:54 [inlined]
 [4] iterate at ./abstractarray.jl:914 [inlined]
 [5] iterate at ./abstractarray.jl:912 [inlined]
 [6] _piterate at ./iterators.jl:810 [inlined]
 [7] iterate at ./iterators.jl:818 [inlined]
 [8] iterate at ./generator.jl:44 [inlined]
 [9] collect(::Base.Generator{Base.Iterators.ProductIterator{Tuple{CuArray{Float32,1},UnitRange{Int64}}},getfield(Main, Symbol("##3#4"))}) at ./array.jl:606
 [10] top-level scope at REPL[7]:1

I should add that I get some errors when testing CuArrays (related to CUFFT) and Flux (related to RNNs).

(v1.2) pkg> test CuArrays
   Testing CuArrays
 Resolving package versions...
    Status `/tmp/jl_86KTwF/Manifest.toml`
  [621f4979] AbstractFFTs v0.4.1
  [79e6a3ab] Adapt v1.0.0
  [9e28174c] BinDeps v0.8.10
  [b99e7846] BinaryProvider v0.5.7
  [fa961155] CEnum v0.2.0
  [00ebfdb7] CSTParser v1.0.0
  [3895d2a7] CUDAapi v1.2.0
  [c5f51814] CUDAdrv v3.1.0
  [be33ccc6] CUDAnative v2.4.0
  [bbf7d656] CommonSubexpressions v0.2.0
  [34da2185] Compat v2.2.0
  [8f4d0f93] Conda v1.3.0
  [a8cc5b0e] Crayons v4.0.0
  [3a865a2d] CuArrays v1.2.1
  [864edb3b] DataStructures v0.17.5
  [163ba53b] DiffResults v0.0.4
  [b552c78f] DiffRules v0.0.10
  [7a1cc6ca] FFTW v1.0.1
  [1a297f60] FillArrays v0.7.4
  [f6369f11] ForwardDiff v0.10.5
  [0c68f7d7] GPUArrays v1.0.4
  [682c06a0] JSON v0.21.0
  [929cbde3] LLVM v1.3.1
  [1914dd2f] MacroTools v0.5.1
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [bac558e1] OrderedCollections v1.1.0
  [69de0a69] Parsers v0.3.7
  [189a3867] Reexport v0.2.0
  [ae029012] Requires v0.5.2
  [276daf66] SpecialFunctions v0.8.0
  [90137ffa] StaticArrays v0.12.0
  [a759f4b9] TimerOutputs v0.5.0
  [0796e94c] Tokenize v0.5.6
  [30578b45] URIParser v0.4.0
  [81def892] VersionParsing v1.1.3
  [2a0f44e3] Base64  [`@stdlib/Base64`]
  [ade2ca70] Dates  [`@stdlib/Dates`]
  [8bb1440f] DelimitedFiles  [`@stdlib/DelimitedFiles`]
  [8ba89e20] Distributed  [`@stdlib/Distributed`]
  [b77e0a4c] InteractiveUtils  [`@stdlib/InteractiveUtils`]
  [76f85450] LibGit2  [`@stdlib/LibGit2`]
  [8f399da3] Libdl  [`@stdlib/Libdl`]
  [37e2e46d] LinearAlgebra  [`@stdlib/LinearAlgebra`]
  [56ddb016] Logging  [`@stdlib/Logging`]
  [d6f4376e] Markdown  [`@stdlib/Markdown`]
  [a63ad114] Mmap  [`@stdlib/Mmap`]
  [44cfe95a] Pkg  [`@stdlib/Pkg`]
  [de0858da] Printf  [`@stdlib/Printf`]
  [3fa0cd96] REPL  [`@stdlib/REPL`]
  [9a3f8284] Random  [`@stdlib/Random`]
  [ea8e919c] SHA  [`@stdlib/SHA`]
  [9e88b42a] Serialization  [`@stdlib/Serialization`]
  [1a1011a3] SharedArrays  [`@stdlib/SharedArrays`]
  [6462fe0b] Sockets  [`@stdlib/Sockets`]
  [2f01184e] SparseArrays  [`@stdlib/SparseArrays`]
  [10745b16] Statistics  [`@stdlib/Statistics`]
  [8dfed614] Test  [`@stdlib/Test`]
  [cf7118a7] UUIDs  [`@stdlib/UUIDs`]
  [4ec0a83e] Unicode  [`@stdlib/Unicode`]
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:16
Batch 2D (in 4D): Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:62
  Expression: isapprox(Y, fftw_X, rtol=MYRTOL, atol=MYATOL)
   Evaluated: isapprox(Complex{Float32}[30.93433f0 + 27.453215f0im -1.6714139f0 - 3.6637213f0im … 1.452174f0 + 0.04918611f0im 0.7687962f0 - 0.57745886f0im; 0.29303554f0 + 0.9587163f0im 1.5358994f0 + 3.0074072f0im … -1.807018f0 - 0.3964303f0im 4.198574f0 - 2.245281f0im; … ; 1.3176649f0 + 3.5774553f0im -0.25220183f0 - 0.41377443f0im … 1.0596181f0 - 0.36518365f0im -3.8873076f0 - 0.4233678f0im; 3.9277287f0 - 1.6763877f0im 2.3764422f0 - 2.5789008f0im … 1.1053141f0 + 1.7568265f0im 1.7967496f0 + 1.9728156f0im]
................
Stacktrace:
 [1] (::getfield(Main, Symbol("#batched#177")){Float64,Float64})(::Array{Complex{Float32},4}, ::Tuple{Int64,Int64}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:62
 [2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:135
 [3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:132
 [5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:70
 [7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
2D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
  Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
   Evaluated: isapprox(Float32[16.026249 0.20003724 … -0.14261723 3.4577808; 14.708965 0.24957383 … 2.4554594 -0.011748552; … ; 15.179053 1.21625 … 0.6149055 1.0816126; 17.156378 -2.9685822 … -0.51396996 -2.361169], Float32[0.5479944 0.69968283 … 0.35572374 0.008343816; 0.8192854 0.568097 … 0.26755 0.37813723; … ; 0.8134997 0.8640851 … 0.7806642 0.43586862; 0.0038974285 0.7791145 … 0.78436923 0.010035396]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
 [1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float32,2}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
 [2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:210
 [3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:209
 [5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
 [7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
3D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
  Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
   Evaluated: isapprox(Float32[1027.6874 -14.55261 … 12.393054 -7.6938467; 1024.6332 8.651075 … 7.487852 -13.836114; … ; 1018.8914 25.11908 … -4.143159 5.174135; 1002.0034 -8.577894 … 1.4080315 -12.264083]

Float32[3.8303516 23.954319 … 15.5727825 10.477801; -4.425709 -13.273114 … 13.321645 2.6528811; … ; -2.1250527 4.9475956 … -16.986565 -5.5888686; 1.3448858 -2.2884483 … -2.8479586 19.10226]
...........................
Float32[0.73248553 0.47862995 … 0.08591521 0.4444232; 0.6592467 0.6849257 … 0.47312415 0.40733826; … ; 0.9162493 0.82751226 … 0.085659266 0.8531437; 0.38045 0.3253126 … 0.85089886 0.03872478]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
 [1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float32,3}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
 [2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:238
 [3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:237
 [5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
 [7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
2D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
  Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
   Evaluated: isapprox([13.110226890046071 1.3772673622431166 … -1.534417373612804 -2.5842606057456314; 14.498238941932932 1.357461983938537 … 0.2196697162595076 0.6284427764869798; … ; 16.414946940034493 0.6907537499873704 … 3.2728619718604675 -1.3926039426756092; 16.924688849739088 -1.085251641446253 … 1.2158256481955552 0.9847502665068547], [0.23603334566204692 0.25166218303197185 … 0.9363392981141667 0.0992834999475467; 0.34651701419196046 0.9866663668987996 … 0.26826295612968565 0.3794133659941419; … ; 0.951916339835734 0.2811902322857298 … 0.8025337387522802 0.16828405178779438; 0.9999046588986136 0.20947237319807077 … 0.6998008592631042 0.7892684010289475]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
 [1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float64,2}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
 [2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:210
 [3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:209
 [5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
 [7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
3D: Test Failed at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
  Expression: isapprox(Z, X, rtol=MYRTOL, atol=MYATOL)
   Evaluated: isapprox([1049.9017421270862 16.472043665886325 … -18.26215551035024 -12.210517732053521; 1025.6632398346655 -6.661014983218484 … -5.195857775569898 4.546808897557256; … ; 1036.0445034495613 4.915146189883673 … 21.033289503207456 12.505645347032504; 1023.607860961061 24.31438337979879 … -0.6654002225101903 4.115047865085936]
.....................
[0.6778807836814413 0.16599513698823087 … 0.7853702692584121 0.9362500948944805; 0.6699133180023191 0.49007506866428474 … 0.6248287567699398 0.4948073880973143; … ; 0.4894338006908039 0.9424596895465458 … 0.7183333513853476 0.760581981651117; 0.44988663531118345 0.8486461546211002 … 0.0903150721160646 0.1578958025629098]; rtol=1.0e-5, atol=1.0e-8)
Stacktrace:
 [1] (::getfield(Main, Symbol("#out_of_place#175")){Float64,Float64})(::Array{Float64,3}) at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:165
 [2] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:238
 [3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [4] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:237
 [5] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [6] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:187
 [7] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [8] top-level scope at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/fft.jl:3
[ Info: Testing CUDNN 7.6.4
[ Info: Testing ForwardDiff integration
Total GPU memory usage: 20.44% (1.593 GiB/7.794 GiB)
CuArrays.jl pool usage: 1.94% (16.008 KiB in use by 2 buffers, 31.611 MiB idle)
 ────────────────────────────────────────────────
                                   Time          
                           ──────────────────────
     Tot / % measured:           319s / 0.10%    

 Section           ncalls     time   %tot     avg
 ────────────────────────────────────────────────
 pooled alloc       9.22k    181ms  54.4%  19.6μs
   1 try alloc      1.88k    177ms  53.3%  94.2μs
 background task       10    152ms  45.6%  15.2ms
   reclaim             10   8.22ms  2.47%   822μs
   scan                10   12.6μs  0.00%  1.26μs
 ────────────────────────────────────────────────
Test Summary:            | Pass  Fail  Total
CuArrays                 | 4451     5   4456
  GPUArrays test suite   | 1104         1104
  Memory                 |    5            5
  Array                  |   22           22
  Adapt                  |    2            2
  Broadcast              |   18           18
  Cufunc                 |    8            8
  Ref Broadcast          |    1            1
  Broadcast Fix          |    4            4
  Reduce                 |    6            6
  0D                     |    2            2
  SubArray               |   20           20
  reshape                |    1            1
  triu! with diagonal -2 |    1            1
  triu! with diagonal -1 |    1            1
  triu! with diagonal 0  |    1            1
  triu! with diagonal 1  |    1            1
  triu! with diagonal 2  |    1            1
  tril! with diagonal -2 |    1            1
  tril! with diagonal -1 |    1            1
  tril! with diagonal 0  |    1            1
  tril! with diagonal 1  |    1            1
  tril! with diagonal 2  |    1            1
  Utilities              |    2            2
  accumulate             |    8            8
  logical indexing       |   15           15
  generic fallbacks      |   11           11
  reverse                |    6            6
  permutedims            |    2            2
  CUBLAS                 | 1184         1184
  CURAND                 |   99           99
  CUFFT                  |  145     5    150
    T = Complex{Float32} |   36     1     37
      1D                 |    3            3
      1D inplace         |    2            2
      2D                 |    3            3
      2D inplace         |    2            2
      Batch 1D           |    6            6
      3D                 |    3            3
      3D inplace         |    2            2
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    8     1      9
    T = Complex{Float64} |   37           37
    T = Float32          |   32     2     34
      1D                 |    4            4
      Batch 1D           |    6            6
      2D                 |    3     1      4
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    9            9
      3D                 |    3     1      4
    T = Float64          |   32     2     34
      1D                 |    4            4
      Batch 1D           |    6            6
      2D                 |    3     1      4
      Batch 2D (in 3D)   |    7            7
      Batch 2D (in 4D)   |    9            9
      3D                 |    3     1      4
    T = Complex{Int32}   |    2            2
    T = Complex{Int64}   |    2            2
    T = Int32            |    2            2
    T = Int64            |    2            2
    streams              |             No tests
  CUSPARSE               | 1229         1229
  CUSOLVER               |  297          297
  CUSPARSE + CUSOLVER    |   84           84
  CUDNN                  |   70           70
  ForwardDiff            |   96           96
ERROR: LoadError: Some tests did not pass: 4451 passed, 5 failed, 0 errored, 0 broken.
in expression starting at /home/johnbb/.julia/packages/CuArrays/wXQp8/test/runtests.jl:18
ERROR: Package CuArrays errored during testing


(v1.2) pkg> test Flux
   Testing Flux
 Resolving package versions...
    Status `/tmp/jl_qzeiZA/Manifest.toml`
  [621f4979] AbstractFFTs v0.4.1
  [1520ce14] AbstractTrees v0.2.1
  [79e6a3ab] Adapt v1.0.0
  [9e28174c] BinDeps v0.8.10
  [b99e7846] BinaryProvider v0.5.7
  [fa961155] CEnum v0.2.0
  [00ebfdb7] CSTParser v1.0.0
  [3895d2a7] CUDAapi v1.2.0
  [c5f51814] CUDAdrv v3.1.0
  [be33ccc6] CUDAnative v2.4.0
  [944b1d66] CodecZlib v0.6.0
  [3da002f7] ColorTypes v0.8.0
  [5ae59095] Colors v0.9.6
  [bbf7d656] CommonSubexpressions v0.2.0
  [34da2185] Compat v2.2.0
  [8f4d0f93] Conda v1.3.0
  [a8cc5b0e] Crayons v4.0.0
  [3a865a2d] CuArrays v1.2.1
  [9a962f9c] DataAPI v1.1.0
  [864edb3b] DataStructures v0.17.5
  [163ba53b] DiffResults v0.0.4
  [b552c78f] DiffRules v0.0.10
  [7a1cc6ca] FFTW v1.0.1
  [1a297f60] FillArrays v0.7.4
  [53c48c17] FixedPointNumbers v0.6.1
  [587475ba] Flux v0.9.0
  [f6369f11] ForwardDiff v0.10.5
  [0c68f7d7] GPUArrays v1.0.4
  [682c06a0] JSON v0.21.0
  [e5e0dc1b] Juno v0.7.2
  [929cbde3] LLVM v1.3.1
  [1914dd2f] MacroTools v0.5.1
  [e89f7d12] Media v0.5.0
  [e1d29d7a] Missings v0.4.3
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [bac558e1] OrderedCollections v1.1.0
  [69de0a69] Parsers v0.3.7
  [189a3867] Reexport v0.2.0
  [ae029012] Requires v0.5.2
  [a2af1166] SortingAlgorithms v0.3.1
  [276daf66] SpecialFunctions v0.8.0
  [90137ffa] StaticArrays v0.12.0
  [2913bbd2] StatsBase v0.32.0
  [a759f4b9] TimerOutputs v0.5.0
  [0796e94c] Tokenize v0.5.6
  [9f7883ad] Tracker v0.2.3
  [3bb67fe8] TranscodingStreams v0.9.5
  [30578b45] URIParser v0.4.0
  [81def892] VersionParsing v1.1.3
  [a5390f91] ZipFile v0.8.3
  [2a0f44e3] Base64  [`@stdlib/Base64`]
  [ade2ca70] Dates  [`@stdlib/Dates`]
  [8bb1440f] DelimitedFiles  [`@stdlib/DelimitedFiles`]
  [8ba89e20] Distributed  [`@stdlib/Distributed`]
  [b77e0a4c] InteractiveUtils  [`@stdlib/InteractiveUtils`]
  [76f85450] LibGit2  [`@stdlib/LibGit2`]
  [8f399da3] Libdl  [`@stdlib/Libdl`]
  [37e2e46d] LinearAlgebra  [`@stdlib/LinearAlgebra`]
  [56ddb016] Logging  [`@stdlib/Logging`]
  [d6f4376e] Markdown  [`@stdlib/Markdown`]
  [a63ad114] Mmap  [`@stdlib/Mmap`]
  [44cfe95a] Pkg  [`@stdlib/Pkg`]
  [de0858da] Printf  [`@stdlib/Printf`]
  [9abbd945] Profile  [`@stdlib/Profile`]
  [3fa0cd96] REPL  [`@stdlib/REPL`]
  [9a3f8284] Random  [`@stdlib/Random`]
  [ea8e919c] SHA  [`@stdlib/SHA`]
  [9e88b42a] Serialization  [`@stdlib/Serialization`]
  [1a1011a3] SharedArrays  [`@stdlib/SharedArrays`]
  [6462fe0b] Sockets  [`@stdlib/Sockets`]
  [2f01184e] SparseArrays  [`@stdlib/SparseArrays`]
  [10745b16] Statistics  [`@stdlib/Statistics`]
  [8dfed614] Test  [`@stdlib/Test`]
  [cf7118a7] UUIDs  [`@stdlib/UUIDs`]
  [4ec0a83e] Unicode  [`@stdlib/Unicode`]
[ Info: Testing Basics
[ Info: Testing Layers
[ Info: Running Gradient Checks
[ Info: Testing GPU Support
[ Info: Testing Flux/CUDNN
batch_size = 1: Error During Test at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
  Got exception outside of a @test
  CUDNNError(code 3, CUDNN_STATUS_BAD_PARAM)
  Stacktrace:
   [1] macro expansion at /home/johnbb/.julia/packages/CuArrays/wXQp8/src/dnn/error.jl:19 [inlined]
   [2] cudnnRNNBackwardData(::Flux.CUDA.RNNDesc{Float32}, ::Int64, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::CuArrays.CUDNN.FilterDesc, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::Array{CuArrays.CUDNN.TensorDesc,1}, ::CuArray{Float32,1}, ::CuArrays.CUDNN.TensorDesc, ::CuArray{Float32,1}, ::Ptr{Nothing}, ::CUDAdrv.CuPtr{Nothing}, ::CuArray{UInt8,1}, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:174
   [3] backwardData(::Flux.CUDA.RNNDesc{Float32}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::Nothing, ::CuArray{Float32,1}, ::Nothing, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:191
   [4] backwardData(::Flux.CUDA.RNNDesc{Float32}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{UInt8,1}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:199
   [5] (::getfield(Flux.CUDA, Symbol("##8#9")){Flux.GRUCell{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}},CuArray{UInt8,1},Tuple{CuArray{Float32,1},CuArray{Float32,1}}})(::Tuple{CuArray{Float32,1},CuArray{Float32,1}}) at /home/johnbb/.julia/packages/Flux/dkJUV/src/cuda/curnn.jl:310
   [6] back_(::Tracker.Call{getfield(Flux.CUDA, Symbol("##8#9")){Flux.GRUCell{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},TrackedArray{…,CuArray{Float32,1}},TrackedArray{…,CuArray{Float32,1}},CuArray{UInt8,1},Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Tuple{Tracker.Tracked{CuArray{Float32,1}},Tracker.Tracked{CuArray{Float32,1}},Tracker.Tracked{CuArray{Float32,2}},Tracker.Tracked{CuArray{Float32,2}},Tracker.Tracked{CuArray{Float32,1}}}}, ::Tuple{CuArray{Float32,1},CuArray{Float32,1}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:35
   [7] back(::Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}}, ::Tuple{CuArray{Float32,1},Int64}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
   [8] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}}, ::Tuple{CuArray{Float32,1},Int64}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
   [9] foreach(::Function, ::Tuple{Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Nothing}, ::Tuple{Tuple{CuArray{Float32,1},Int64},Nothing}) at ./abstractarray.jl:1921
   [10] back_(::Tracker.Call{getfield(Tracker, Symbol("##361#363")){Tracker.TrackedTuple{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Int64},Tuple{Tracker.Tracked{Tuple{CuArray{Float32,1},CuArray{Float32,1}}},Nothing}}, ::CuArray{Float32,1}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
   [11] back(::Tracker.Tracked{CuArray{Float32,1}}, ::CuArray{Float32,1}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
   [12] back!(::TrackedArray{…,CuArray{Float32,1}}, ::CuArray{Float32,1}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:77
   [13] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:23
   [14] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
   [15] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
   [16] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
   [17] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
   [18] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
   [19] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
   [20] include at ./boot.jl:328 [inlined]
   [21] include_relative(::Module, ::String) at ./loading.jl:1094
   [22] include(::Module, ::String) at ./Base.jl:31
   [23] include(::String) at ./client.jl:431
   [24] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/cuda.jl:55
   [25] include at ./boot.jl:328 [inlined]
   [26] include_relative(::Module, ::String) at ./loading.jl:1094
   [27] include(::Module, ::String) at ./Base.jl:31
   [28] include(::String) at ./client.jl:431
   [29] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:30
   [30] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
   [31] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:11
   [32] include at ./boot.jl:328 [inlined]
   [33] include_relative(::Module, ::String) at ./loading.jl:1094
   [34] include(::Module, ::String) at ./Base.jl:31
   [35] include(::String) at ./client.jl:431
   [36] top-level scope at none:5
   [37] eval(::Module, ::Any) at ./boot.jl:330
   [38] exec_options(::Base.JLOptions) at ./client.jl:271
   [39] _start() at ./client.jl:464
  
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:26
  Expression: rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
   Evaluated: Float32[0.026422147 0.031162314 … 0.04015291 0.048164792; -0.00580488 -0.006842452 … -0.0006331135 -0.0051276702; … ; -0.48899454 -0.3973075 … -0.5430576 -0.59323865; -1.5976363 -1.7279379 … -1.2472157 -2.0624979] ≈ Float32[0.0163969 0.020081667 … 0.035748813 0.03631903; -0.0038312192 -0.004661015 … 0.00023391606 -0.002795606; … ; -0.6556128 -0.5814663 … -0.616253 -0.79011357; -1.3221657 -1.4234674 … -1.1262015 -1.7370036]
Stacktrace:
 [1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:26
 [2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
 [4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
 [6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:27
  Expression: rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
   Evaluated: Float32[-0.015565421 0.0032692542 … -0.02177089 -0.0079453755; 0.0033148595 -0.00096146535 … 0.002088631 -0.00068250747; … ; 0.060095496 -0.037989482 … 0.0877803 0.05621599; 0.13518983 -0.049921982 … 0.1269001 0.046479825] ≈ Float32[-0.0102084465 0.0006350695 … -0.017785752 -0.0077674175; 0.0022602363 -0.00044287564 … 0.0013040807 -0.00071754144; … ; 0.08005668 -0.047804993 … 0.10262971 0.056879077; 0.12019784 -0.042549968 … 0.11574733 0.0459818]
Stacktrace:
 [1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:27
 [2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
 [4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
 [6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:28
  Expression: rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
   Evaluated: Float32[0.043129362, -0.0052406434, 0.0032806913, 0.03839735, -0.027703835, -0.1904509, 0.3789641, 0.11353702, -1.1046532, -0.10988483, -0.3808116, -0.806893, -0.59165734, -0.8405457, -2.1673152] ≈ Float32[0.030957595, -0.002844398, 0.007849413, 0.050801173, -0.025155623, -0.11647805, 0.19690844, 0.15651214, -1.3034244, -0.10899977, -0.033872187, -0.53707606, -0.8408148, -1.0428388, -1.8328632]
Stacktrace:
 [1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:28
 [2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
 [4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
 [6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
batch_size = 5: Test Failed at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:29
  Expression: rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
   Evaluated: Float32[-0.012821687, -0.7067668, -0.35861096, -1.4159613, -0.54712236] ≈ Float32[0.066714235, -0.12755066, -0.45448342, -1.7976848, -0.41404387]
Stacktrace:
 [1] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:29
 [2] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [3] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:7
 [4] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1186
 [5] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
 [6] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1113
 [7] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/test/cuda/curnn.jl:4
Test Summary:                         | Pass  Fail  Error  Total
Flux                                  |  274     4      1    279
  Throttle                            |   11                  11
  Jacobian                            |    1                   1
  Initialization                      |   12                  12
  Params                              |    4                   4
  Basic Stacking                      |    1                   1
  Precision                           |    6                   6
  Stacking                            |    3                   3
  onecold                             |    4                   4
  onehotbatch indexing                |    2                   2
  Optimise                            |   11                  11
  Optimiser                           |    3                   3
  Training Loop                       |    2                   2
  ExpDecay                            |    3                   3
  basic                               |   27                  27
  Dropout                             |   10                  10
  BatchNorm                           |   14                  14
  InstanceNorm                        |   16                  16
  GroupNorm                           |   16                  16
  losses                              |   30                  30
  Pooling                             |    2                   2
  CNN                                 |    1                   1
  asymmetric padding                  |    7                   7
  Depthwise Conv                      |    3                   3
  ConvTranspose                       |    1                   1
  CrossCor                            |    4                   4
  Conv with non quadratic window #700 |    4                   4
  Tracker                             |    4                   4
  CuArrays                            |    8                   8
  onecold gpu                         |    2                   2
  CUDNN BatchNorm                     |   10                  10
  RNN                                 |   40     4      1     45
    R = RNN                           |   16                  16
    R = GRU                           |    6     4      1     11
      batch_size = 1                  |    2            1      3
      batch_size = 5                  |    4     4             8
    R = LSTM                          |   18                  18
ERROR: LoadError: Some tests did not pass: 274 passed, 4 failed, 1 errored, 0 broken.
in expression starting at /home/johnbb/.julia/packages/Flux/dkJUV/test/runtests.jl:9
ERROR: Package Flux errored during testing


The warning only triggers once per session, so try calling `CuArrays.allowscalar(false)` just before your main training loop. Alternatively, fix the initialization of `B` by keeping a CPU copy of `prob` and only moving the finished matrix to the GPU.
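A sketch of that second option, reusing the names from the original example (assumes Flux and CuArrays are loaded as before): build `B` entirely on the CPU, where scalar indexing is cheap and allowed, and transfer the finished matrix to the device in a single copy.

```julia
using Flux, CuArrays, Statistics

nprob  = 10
degree = 8

# Keep a CPU copy of prob so the comprehension below
# only does scalar indexing on a plain Array.
prob_cpu = Float32.(1:nprob) ./ (nprob + 1)
dbin(x, n, p) = binomial(n, x) * p^x * (1 - p)^(n - x)

# Construct B on the CPU, then move the finished matrix to the GPU.
B    = [dbin(d, degree, p) for p in prob_cpu, d in 0:degree] |> gpu
prob = prob_cpu |> gpu   # GPU copy for use inside the loss

# Now it is safe to disallow scalar indexing for the training loop.
CuArrays.allowscalar(false)
```

With `B` built this way, scalar indexing is never needed on a `CuArray`, so the strict setting can stay on for the whole training run.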

Those test failures are unrelated, and are fixed on the latest master branches (not yet released).

Not sure I fully understand, but setting it just before the `Flux.train!` call and leaving the rest as before resulted in the following error:

julia> CuArrays.allowscalar(false)

julia> @time Flux.@epochs 1 Flux.train!(loss, Flux.params(model), trdata, Flux.ADAM()) 
[ Info: Epoch 1
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/johnbb/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:14
 [3] getindex at /home/johnbb/.julia/packages/GPUArrays/tIMl5/src/indexing.jl:54 [inlined]
 [4] _getindex at ./abstractarray.jl:1004 [inlined]
 [5] getindex at ./abstractarray.jl:981 [inlined]
 [6] getindex at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/adjtrans.jl:178 [inlined]
 [7] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
 [8] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
 [9] getindex at ./reshapedarray.jl:231 [inlined]
 [10] _generic_matmatmul!(::CuArray{Float32,2}, ::Char, ::Char, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::CuArray{Float32,2}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:614
 [11] generic_matmatmul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:584 [inlined]
 [12] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:255 [inlined]
 [13] * at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/matmul.jl:143 [inlined]
 [14] _forward at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:415 [inlined]
 [15] #track#1 at /home/johnbb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
 [16] track at /home/johnbb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
 [17] * at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:379 [inlined]
 [18] #509 at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/array.jl:416 [inlined]
 [19] back_(::Tracker.Call{getfield(Tracker, Symbol("##509#510")){CuArray{Float32,2},TrackedArray{…,CuArray{Float32,2}}},Tuple{Nothing,Tracker.Tracked{CuArray{Float32,2}}}}, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:35
 [20] back(::Tracker.Tracked{CuArray{Float32,2}}, ::Base.ReshapedArray{Float32,2,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [21] foreach at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [22] back_(::Tracker.Call{getfield(Tracker, Symbol("##390#391")){TrackedArray{…,CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
 [23] back(::Tracker.Tracked{LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [24] #13 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [25] foreach at ./abstractarray.jl:1921 [inlined]
 [26] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){5,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)},getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(<)},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)},typeof(*)},Tuple{CuArray{Float32,1},TrackedArray{…,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}},LinearAlgebra.Adjoint{Float32,CuArray{Float32,1}},TrackedArray{…,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}},CuArray{Float32,1}}},Tuple{Nothing,Tracker.Tracked{LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}},Nothing,Tracker.Tracked{LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}},Nothing}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
 [27] back(::Tracker.Tracked{CuArray{Float32,2}}, ::CuArray{Float32,2}, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [28] #13 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [29] foreach at ./abstractarray.jl:1921 [inlined]
 [30] back_(::Tracker.Call{getfield(Tracker, Symbol("##497#498")){Colon,TrackedArray{…,CuArray{Float32,2}}},Tuple{Tracker.Tracked{CuArray{Float32,2}}}}, ::Float32, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:38
 [31] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [32] #back!#15 at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:77 [inlined]
 [33] #back! at ./none:0 [inlined]
 [34] #back!#32 at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/real.jl:16 [inlined]
 [35] back!(::Tracker.TrackedReal{Float32}) at /home/johnbb/.julia/packages/Tracker/SAr25/src/lib/real.jl:14
 [36] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{CuArray{Float32,2},CuArray{Float32,1}}}, ::Tracker.Params) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:4
 [37] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:164
 [38] gradient at /home/johnbb/.julia/packages/Tracker/SAr25/src/back.jl:164 [inlined]
 [39] macro expansion at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
 [40] macro expansion at /home/johnbb/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [41] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,1}},1}, ::ADAM) at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
 [42] train!(::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,1}},1}, ::ADAM) at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:67
 [43] top-level scope at /home/johnbb/.julia/packages/Flux/dkJUV/src/optimise/train.jl:106
 [44] top-level scope at /home/johnbb/.julia/packages/Juno/oLB1d/src/progress.jl:134
 [45] top-level scope at util.jl:156
julia> B
10×9 CuArray{Float32,2}:
 0.466507    0.373206     0.130622     0.0261244    0.00326555  …  1.30622e-5   3.73206e-7   4.66507e-9
 0.200816    0.357006     0.277672     0.12341      0.0342805      0.000677145  4.29933e-5   1.19426e-6
 0.078267    0.234801     0.308176     0.231132     0.108343       0.0060943    0.000652961  3.06076e-5
 0.0268932   0.12294      0.245881     0.281007     0.200719       0.0262164    0.00428022   0.00030573
 0.00783553  0.0522369    0.152358     0.253929     0.26451        0.0734749    0.017494     0.00182229
 0.00182229  0.017494     0.0734749    0.17634      0.26451     …  0.152358     0.0522369    0.00783553
 0.00030573  0.00428022   0.0262164    0.0917573    0.200719       0.245881     0.12294      0.0268932 
 3.06075e-5  0.000652961  0.0060943    0.0325029    0.108343       0.308176     0.234801     0.078267  
 1.19426e-6  4.29933e-5   0.000677145  0.0060943    0.0342805      0.277672     0.357006     0.200816  
 4.66506e-9  3.73205e-7   1.30622e-5   0.000261244  0.00326555     0.130622     0.373206     0.466507  
julia>


You understood correctly! So this reveals a scalar operation within your main loop, caused by calling `mul!(::CuArray, ::ReshapedArray{Adjoint{CuArray}})`. This is an unfortunate case where dispatch does not select the GPU-optimized generic matmul and instead falls back to the Base implementation; our type system falls a little short here.

A workaround would be to figure out where the `ReshapedArray` comes from, since calling `reshape` on a `CuArray` should give a `CuArray` again (avoiding this situation, since `mul!(::CuArray, ::Adjoint{CuArray})` does dispatch to the correct implementation). cc @MikeInnes.

It could possibly be in the `qtloss` function. For example, `y` is of size (100,1) while `(B*b)'` is of size (100, degree+1), and I am not sure what happens in the background when `y .< (B*b)'` is evaluated. On the other hand, `qtloss` itself seems to be very fast.
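For what it's worth, one way to test that hypothesis is to rewrite `qtloss` so that the tracked result `B*b` is never transposed. Broadcasting against `y'` (a row vector) and `prob` as a column vector instead is elementwise equivalent to the original formulation; an untested sketch, assuming `B` and `prob` are defined as in the original example:

```julia
# Equivalent loss without the adjoint on the tracked array B*b.
# u[j, i] is the j-th quantile prediction for case i, so comparing
# against y' and subtracting prob column-wise broadcasts to the same
# elementwise terms as the original ((y .< (B*b)') .- prob') version.
function qtloss(b, y)
    u = B * b                                  # size: nprob × #cases
    mean(((y' .< u) .- prob) .* (u .- y'))
end
```

If the `Adjoint` in the backward pass was coming from `(B*b)'`, this formulation should avoid the `ReshapedArray{Adjoint{CuArray}}` path entirely; the remaining `y'` is an adjoint of untracked data, which takes a different code path.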

Are there any thoughts on a general solution to this class of issues? I've seen this crop up before.

There have been a couple of PRs, but nothing GPU-specific (https://github.com/JuliaLang/julia/pull/31563, https://github.com/JuliaLang/julia/pull/25558). Here's something similar in spirit to Adapt.jl: https://github.com/JuliaGPU/GPUArrays.jl/issues/147#issuecomment-417255267. I know @keno had some thoughts about this too, but I don't think he's had the time to do anything with them.