I just started using Flux.jl. I have a loss function whose gradient is computed much more slowly on the GPU than on the CPU. Since a machine learning project typically involves a large dataset, it is difficult to create an MWE, but I wonder if someone has any suggestions just from looking at the types of the arguments.
Here are the packages I am using:
using Flux
using Flux: onehotbatch, onecold, crossentropy, params
using CUDA
I have a loss function
loss(x, y) = crossentropy(m(x), y)
defined with some model m. On the CPU, I can evaluate the gradient of the loss function in about 8 seconds (including compilation):
julia> @timev gradient(()->loss(t...), ps)
8.022828 seconds (4.67 M allocations: 7.465 GiB, 16.80% gc time, 38.20% compilation time)
elapsed time (ns): 8022828062
gc time (ns): 1347510721
bytes allocated: 8014950633
pool allocs: 4671353
non-pool GC allocs: 512
malloc() calls: 290
free() calls: 100
minor collections: 3
full collections: 3
Grads(...)
Here the type of t is
julia> typeof(t)
Tuple{Array{Float32, 3}, OneHotArrays.OneHotMatrix{UInt32, 10, Vector{UInt32}}}
and the types of the elements of ps are a combination of 1D, 2D, and 3D arrays:
julia> typeof.(ps)
62-element Vector{DataType}:
Array{Float32, 3}
⋮
Vector{Float32} (alias for Array{Float32, 1})
⋮
Matrix{Float32} (alias for Array{Float32, 2})
⋮
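Since I cannot easily share the actual model, here is a structurally similar stand-in (made-up layers and sizes, not my real model) that produces arguments of the same types:
using Flux
using Flux: onehotbatch, crossentropy, params

# Hypothetical stand-in model: 3D Float32 input (width × channels × batch), 10 output classes
m = Chain(
    Conv((3,), 1 => 8, relu),    # 1D convolution; its weight is a 3D array
    Flux.flatten,
    Dense(8 * 26 => 10),         # 2D weight matrix and 1D bias; 26 = 28 - 3 + 1 after the conv
    softmax,
)

x = rand(Float32, 28, 1, 32)             # 3D Float32 input batch
y = onehotbatch(rand(0:9, 32), 0:9)      # OneHotMatrix labels
t = (x, y)
ps = params(m)                           # mix of 1D, 2D, and 3D parameter arrays

loss(x, y) = crossentropy(m(x), y)
gradient(() -> loss(t...), ps)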
Now, to perform the gradient on GPU, I defined a new loss function as follows:
m2 = m |> gpu
loss2(x,y) = crossentropy(m2(x), y)
I also prepared the CUDA versions of t and ps:
julia> typeof(t2)
Tuple{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, OneHotArrays.OneHotMatrix{UInt32, 10, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}}}
julia> typeof.(ps2)
62-element Vector{DataType}:
CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}
⋮
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
⋮
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
⋮
You can see that the standard Arrays in t and ps have been converted to the corresponding CuArrays in t2 and ps2.
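For reference, the conversion itself is just the usual gpu calls, essentially:
t2 = t |> gpu        # moves the Float32 array and the OneHotMatrix indices to the GPU
ps2 = params(m2)     # collects the CuArray parameters of the GPU model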
However, when I perform gradient on loss2 using t2 and ps2, it takes much longer:
julia> @timev gradient(()->loss2(t2...), ps2)
559.368659 seconds (209.64 M allocations: 31.599 GiB, 0.86% gc time, 0.26% compilation time)
elapsed time (ns): 559368659181
gc time (ns): 4785176794
bytes allocated: 33928929401
pool allocs: 209638516
non-pool GC allocs: 61
malloc() calls: 9
free() calls: 223
minor collections: 101
full collections: 0
Grads(...)
So the gradient calculation on the GPU is roughly 70x slower than on the CPU.
Any suggestions on how to tackle this issue? Even if it is not a direct solution, I would appreciate anything that helps me understand the cause.
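For example, I have read that accidental scalar indexing of CuArrays can cause this kind of slowdown; if that is a plausible cause here, I assume a check along these lines would turn any slow fallback into an error:
CUDA.allowscalar(false)              # error instead of silently falling back to slow scalar indexing
gradient(() -> loss2(t2...), ps2)    # re-run the GPU gradient with scalar indexing disabled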