I have trained a small Flux.jl neural network that I am embedding into a larger model, and I need to use it to make millions (billions?) of predictions as part of running this larger model.
Since the larger model runs on the GPU I also want to evaluate the neural network on the GPU. And even if the network is too small to saturate the GPU, I think queuing up millions of evaluations should saturate the GPU and result in a significant speedup. However, I am not sure how to do this.
I tried calling the neural network in a CUDA kernel so that I could launch many of them, but Chains and Dense layers are not isbits, so I don't think you can use a kernel here. I'm also hesitant to write a custom kernel to evaluate the chain since I plan to try out different chains/architectures, so I'm looking for a more generic solution.
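For reference, the isbits claim can be checked directly. This is only a minimal sketch: it assumes a recent Flux release that accepts the `in => out` layer syntax, and the names `layer` and `small_chain` are made up for illustration.

```julia
using Flux

# Hypothetical stand-ins for the layers used in the benchmarks below.
layer = Dense(32 => 128, relu)
small_chain = Chain(layer, Dense(128 => 31, relu))

# A Dense layer stores its weights in an Array, which is not a bits type,
# so neither the layer nor a Chain containing it is isbits.
layer_is_bits = isbitstype(typeof(layer))
chain_is_bits = isbitstype(typeof(small_chain))

println("Dense isbits: ", layer_is_bits)
println("Chain isbits: ", chain_is_bits)
```

Both checks should print `false`, which is why the chain cannot be passed to a kernel as an argument in its current form.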
Unfortunately, evaluating the chain in a loop, for _ in 1:10^4; G(y); end, doesn't queue up many CUDA kernel launches that can then be executed in parallel. It probably also doesn't help that evaluating the chain on the GPU actually incurs quite a few CPU allocations.
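One alternative to the loop that may be worth comparing against is a single batched call, since Flux's Dense layers accept a matrix whose columns are individual inputs. The sketch below assumes all inputs can be collected up front and that a recent Flux release with the `in => out` syntax is installed; the batch size `N` and the names `Y` and `out` are hypothetical, and I have not benchmarked this.

```julia
using CUDA
using Flux

C = Chain(
    Dense(32 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 31, relu),
)
G = gpu(C)                       # stays on the CPU if no functional GPU is found

N = 10^4                         # hypothetical number of inputs to evaluate
Y = gpu(ones(Float32, 32, N))    # each column is one 32-element input

out = G(Y)                       # one call evaluates all N inputs at once

# Compare the first column against a single-input evaluation.
y = gpu(ones(Float32, 32))
single = Array(G(y))
first_col = vec(Array(out[:, 1:1]))
println(size(out), " ", first_col ≈ single)
```

Each layer then runs as one large matrix multiplication rather than N small matrix-vector products, so the per-call launch and allocation overhead is paid once per batch instead of once per input. Whether this fits depends on whether the larger model can supply its inputs in batches.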
I'd appreciate any tips for speeding up batch chain evaluations on the GPU if anyone else has tried doing something similar!
CPU benchmark
using BenchmarkTools
using CUDA
using Flux
x = ones(Float32, 32)
C = Chain(
Dense(32, 128, relu),
Dense(128, 128, relu),
Dense(128, 31, relu)
)
@benchmark C(x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  11.436 μs … 395.964 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.751 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   17.526 μs ±   6.871 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
 Histogram (frequency by time): 11.4 μs to 26.5 μs
 Memory estimate: 2.62 KiB, allocs estimate: 6.
GPU benchmark
y = CUDA.ones(32)
G = gpu(C)
CUDA.@time CUDA.@sync G(y)
@benchmark CUDA.@sync G(y)
0.000273 seconds (102 CPU allocations: 5.766 KiB) (6 GPU allocations: 2.242 KiB, 9.46% memmgmt time)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  43.629 μs …  2.288 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     50.206 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   57.249 μs ± 27.861 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
 Histogram (frequency by time): 43.6 μs to 97.7 μs
 Memory estimate: 5.77 KiB, allocs estimate: 102.