Is there a way to avoid allocations when calling a Flux model, especially on a GPU?

When using Flux, each inference call makes a lot of allocations, particularly on a GPU. Below is an example:

julia> using Flux, CUDA;

julia> model_cpu = Flux.Chain(Flux.Dense(2, 8), Flux.Dense(8,8), Flux.Dense(8,1));

julia> model_gpu = fmap(cu, model_cpu);

julia> dataset_cpu = rand(2, 4096); dataset_gpu = cu(dataset_cpu);

julia> using BenchmarkTools;

julia> @btime model_cpu(dataset_cpu);
  225.000 μs (15 allocations: 1.06 MiB)

julia> @btime model_gpu(dataset_gpu);
  50.700 μs (267 allocations: 5.91 KiB)

Is there a way to cache intermediate arrays to avoid new allocations, or even to pre-allocate the output?
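Here is what "pre-allocating the output" means for a single dense layer: it can be computed into an existing buffer in plain Julia. This is a minimal sketch. It assumes a layer is just σ.(W * x .+ b) on ordinary matrices, and dense! is a hypothetical helper, not a Flux function:

```julia
using LinearAlgebra  # for the in-place mul!

# Hypothetical helper (not part of Flux): one Dense-style layer,
# y = σ.(W * x .+ b), written into the caller-owned buffer `y`.
# The `where {F}` makes Julia specialize on the activation function.
function dense!(y::AbstractMatrix, W::AbstractMatrix, b::AbstractVector,
                x::AbstractMatrix, σ::F) where {F}
    mul!(y, W, x)        # y = W * x, computed in place
    y .= σ.(y .+ b)      # fused broadcast: add the bias to every column, then apply σ
    return y
end

W = randn(Float32, 8, 2)    # weights of a layer with 2 inputs and 8 outputs
b = randn(Float32, 8)       # bias
x = rand(Float32, 2, 4096)  # a batch of 4096 samples, one per column
y = similar(x, 8, 4096)     # output buffer, allocated once and reused

dense!(y, W, b, x, identity)
```

Only mul! and a single fused broadcast are involved. Once the function has been compiled, repeated calls with the same buffer should therefore not allocate on the Julia heap.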

There’s really no way to get around intermediate heap (or device memory) allocations for something as expensive as an ML model, but you can try AutoPreallocation.jl (oxinabox/AutoPreallocation.jl on GitHub). It remembers what memory was needed on the last call and reuses it on every call after that. Also, Flux exports a function gpu (and cpu) that does the fmap(cu, model) for you.

Thanks!

I’ve just tried using preallocate on the model as shown below. However, this crashes the REPL when executed. I’ve tried a few different variations but they all seem to crash.

x, preallocated_cpu_f = preallocate(model_cpu, dataset_cpu);

It crashes with a very long error:

FATAL ERROR: Symbol "__nv_copysign"not found
signal (22): SIGABRT
in expression starting at REPL[11]:1
crt_sig_handler at /cygdrive/c/buildbot/worker/package_win64/build/src\signals-win.c:93
raise at C:\WINDOWS\System32\msvcrt.dll (unknown line)
abort at C:\WINDOWS\System32\msvcrt.dll (unknown line)
addModule at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:772
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1068
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1112
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1097
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1097
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1097
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1097
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1134 [inlined]
_jl_compile_codeinst at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:154
jl_generate_fptr at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:352
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1970
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1924 [inlined]
_jl_invoke at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2229 [inlined]
jl_apply_generic at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2419
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
do_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\builtins.c:670
.
.
.
etc

If there is no way around this, then it’s okay, I can write my own wrapper to do this. AutoPreallocation seems like a very useful package and I’ll come back to it in the future.
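A hand-written wrapper along those lines might look like the following. It is a sketch under several assumptions:

- the batch size is fixed;
- the model is made only of Dense-style layers;
- the weights and biases are copied out of the trained Flux model (how they are accessed depends on the Flux version).

CachedDense and applycached are hypothetical names, and the tanh activations are only illustrative:

```julia
using LinearAlgebra  # for the in-place mul!

# Hypothetical layer type (not part of Flux): the parameters of one Dense layer
# plus an output buffer sized for a fixed batch size.
struct CachedDense{M<:AbstractMatrix, V<:AbstractVector, F}
    W::M
    b::V
    σ::F
    out::M   # reused on every call, so results are overwritten by the next call
end

# Allocate the buffer once, with the same array type as W.
CachedDense(W::AbstractMatrix, b::AbstractVector, σ, batchsize::Integer) =
    CachedDense(W, b, σ, similar(W, size(W, 1), batchsize))

# Forward pass for one layer: out = σ.(W * x .+ b), without creating new arrays.
function (l::CachedDense)(x::AbstractMatrix)
    out, σ = l.out, l.σ
    mul!(out, l.W, x)
    out .= σ.(out .+ l.b)
    return out
end

# Apply a tuple of layers recursively (the same idea as Flux's applychain),
# which keeps each step type-stable.
applycached(::Tuple{}, x) = x
applycached(layers::Tuple, x) = applycached(Base.tail(layers), first(layers)(x))

batch = 4096
layers = (
    CachedDense(randn(Float32, 8, 2), zeros(Float32, 8), tanh, batch),
    CachedDense(randn(Float32, 8, 8), zeros(Float32, 8), tanh, batch),
    CachedDense(randn(Float32, 1, 8), zeros(Float32, 1), identity, batch),
)
x = rand(Float32, 2, batch)
y = applycached(layers, x)   # 1×4096 Matrix{Float32}, which is layers[3].out
```

Each layer returns its own buffer, so the result is overwritten by the next call and should be copied if it needs to be kept. The buffer is created with similar(W, ...), and mul! and broadcasting are defined for CuArrays. The same structure should therefore in principle work with CuArray weights, but that has not been checked here.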

Thanks again for your reply!

That GPU intrinsic function name looks suspect. Is part of the model or input inadvertently using a CuArray? A full stack trace would help here.

I have restricted this so that the model is only on the CPU, without CUDA being loaded.

The full stack trace is:

FATAL ERROR: Symbol "__nv_copysign"not found
signal (22): SIGABRT
in expression starting at REPL[7]:1
crt_sig_handler at /cygdrive/c/buildbot/worker/package_win64/build/src\signals-win.c:93
raise at C:\WINDOWS\System32\msvcrt.dll (unknown line)
abort at C:\WINDOWS\System32\msvcrt.dll (unknown line)
addModule at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:772
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1068
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1112
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1134 [inlined]
_jl_compile_codeinst at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:154
jl_generate_fptr at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:352
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1970
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1924 [inlined]
_jl_invoke at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2229 [inlined]
jl_apply_generic at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2419
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
do_call at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:115
eval_value at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:204
eval_stmt_value at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:155 [inlined]
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:575
jl_interpret_toplevel_thunk at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:669
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:877
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:825
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:825
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:524
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:489
jl_interpret_toplevel_thunk at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:669
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:877
jl_toplevel_eval at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:886 [inlined]
jl_toplevel_eval_in at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:929
eval at .\boot.jl:360 [inlined]
eval_user_input at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:139
repl_backend_loop at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:200
start_repl_backend at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:185
#run_repl#42 at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:317
run_repl at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:305
#874 at .\client.jl:387
jfptr_YY.874_23672.clone_1 at C:\Users\jamie\AppData\Local\Programs\Julia-1.6.0\lib\julia\sys.dll (unknown line)
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
jl_f__call_latest at /cygdrive/c/buildbot/worker/package_win64/build/src\builtins.c:714
#invokelatest#2 at .\essentials.jl:708 [inlined]
invokelatest at .\essentials.jl:706 [inlined]
run_main_repl at .\client.jl:372
exec_options at .\client.jl:302
_start at .\client.jl:485
jfptr__start_43335.clone_1 at C:\Users\jamie\AppData\Local\Programs\Julia-1.6.0\lib\julia\sys.dll (unknown line)
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
true_main at /cygdrive/c/buildbot/worker/package_win64/build/src\jlapi.c:560
repl_entrypoint at /cygdrive/c/buildbot/worker/package_win64/build/src\jlapi.c:702
mainCRTStartup at /cygdrive/c/buildbot/worker/package_win64/build/cli\loader_exe.c:51
BaseThreadInitThunk at C:\WINDOWS\System32\KERNEL32.DLL (unknown line)
RtlUserThreadStart at C:\WINDOWS\SYSTEM32\ntdll.dll (unknown line)
Allocations: 43691142 (Pool: 43647190; Big: 43952); GC: 520

I have also tried it with the GPU version, and there is still an error that crashes the REPL:

FATAL ERROR: Symbol "cudaGetErrorString"not found
signal (22): SIGABRT
in expression starting at REPL[6]:1
crt_sig_handler at /cygdrive/c/buildbot/worker/package_win64/build/src\signals-win.c:93
raise at C:\WINDOWS\System32\msvcrt.dll (unknown line)
abort at C:\WINDOWS\System32\msvcrt.dll (unknown line)
addModule at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:772
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1068
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1112
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1097
jl_add_to_ee at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:1134 [inlined]
_jl_compile_codeinst at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:154
jl_generate_fptr at /cygdrive/c/buildbot/worker/package_win64/build/src\jitlayers.cpp:352
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1970
jl_compile_method_internal at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:1924 [inlined]
_jl_invoke at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2229 [inlined]
jl_apply_generic at /cygdrive/c/buildbot/worker/package_win64/build/src\gf.c:2419
overdub at C:\Users\jamie\.julia\packages\CUDA\k52QH\src\state.jl:71 [inlined]
overdub at C:\Users\jamie\.julia\packages\Cassette\jxIEh\src\overdub.jl:0
device() at C:\Users\jamie\.julia\packages\CUDA\k52QH\src\state.jl:207 [inlined]
overdub at C:\Users\jamie\.julia\packages\CUDA\k52QH\src\state.jl:207 [inlined]
overdub at C:\Users\jamie\.julia\packages\Cassette\jxIEh\src\overdub.jl:0
overdub at C:\Users\jamie\.julia\packages\CUDA\k52QH\lib\cublas\wrappers.jl:740 [inlined]
overdub at C:\Users\jamie\.julia\packages\Cassette\jxIEh\src\overdub.jl:0
overdub at C:\Users\jamie\.julia\packages\CUDA\k52QH\lib\cublas\linalg.jl:220 [inlined]
overdub at C:\Users\jamie\.julia\packages\Cassette\jxIEh\src\overdub.jl:0
mul!(::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::Bool, ::Bool) at C:\Users\jamie\.julia\packages\CUDA\k52QH\lib\cublas\linalg.jl:233 [inlined]
overdub at C:\Users\jamie\.julia\packages\CUDA\k52QH\lib\cublas\linalg.jl:233 [inlined]
mul!(::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::CuArray{Float32, 2}) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\matmul.jl:275 [inlined]
overdub at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\matmul.jl:275 [inlined]
overdub at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\matmul.jl:160 [inlined]
overdub at C:\Users\jamie\.julia\packages\Flux\6o4DQ\src\layers\basic.jl:147 [inlined]
overdub at C:\Users\jamie\.julia\packages\Cassette\jxIEh\src\overdub.jl:0
applychain(::Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, ::CuArray{Float32, 2}) at C:\Users\jamie\.julia\packages\Flux\6o4DQ\src\layers\basic.jl:36 [inlined]
overdub at C:\Users\jamie\.julia\packages\Flux\6o4DQ\src\layers\basic.jl:36 [inlined]
(::Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}})(::CuArray{Float32, 2}) at C:\Users\jamie\.julia\packages\Flux\6o4DQ\src\layers\basic.jl:38 [inlined]
overdub at C:\Users\jamie\.julia\packages\Flux\6o4DQ\src\layers\basic.jl:38 [inlined]
#record_allocations#4 at C:\Users\jamie\.julia\packages\AutoPreallocation\pywkM\src\recording.jl:21
record_allocations at C:\Users\jamie\.julia\packages\AutoPreallocation\pywkM\src\recording.jl:20 [inlined]
preallocate at C:\Users\jamie\.julia\packages\AutoPreallocation\pywkM\src\preallocate.jl:27
unknown function (ip: 000000005eed4f44)
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
do_call at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:115
eval_value at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:204
eval_stmt_value at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:155 [inlined]
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:575
jl_interpret_toplevel_thunk at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:669
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:877
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:825
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:825
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:524
eval_body at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:489
jl_interpret_toplevel_thunk at /cygdrive/c/buildbot/worker/package_win64/build/src\interpreter.c:669
jl_toplevel_eval_flex at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:877
jl_toplevel_eval at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:886 [inlined]
jl_toplevel_eval_in at /cygdrive/c/buildbot/worker/package_win64/build/src\toplevel.c:929
eval at .\boot.jl:360 [inlined]
eval_user_input at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:139
repl_backend_loop at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:200
start_repl_backend at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:185
#run_repl#42 at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:317
run_repl at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\REPL\src\REPL.jl:305
#874 at .\client.jl:387
jfptr_YY.874_23672.clone_1 at C:\Users\jamie\AppData\Local\Programs\Julia-1.6.0\lib\julia\sys.dll (unknown line)
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
jl_f__call_latest at /cygdrive/c/buildbot/worker/package_win64/build/src\builtins.c:714
#invokelatest#2 at .\essentials.jl:708 [inlined]
invokelatest at .\essentials.jl:706 [inlined]
run_main_repl at .\client.jl:372
exec_options at .\client.jl:302
_start at .\client.jl:485
jfptr__start_43335.clone_1 at C:\Users\jamie\AppData\Local\Programs\Julia-1.6.0\lib\julia\sys.dll (unknown line)
jl_apply at /cygdrive/c/buildbot/worker/package_win64/build/src\julia.h:1703 [inlined]
true_main at /cygdrive/c/buildbot/worker/package_win64/build/src\jlapi.c:560
repl_entrypoint at /cygdrive/c/buildbot/worker/package_win64/build/src\jlapi.c:702
mainCRTStartup at /cygdrive/c/buildbot/worker/package_win64/build/cli\loader_exe.c:51
BaseThreadInitThunk at C:\WINDOWS\System32\KERNEL32.DLL (unknown line)
RtlUserThreadStart at C:\WINDOWS\SYSTEM32\ntdll.dll (unknown line)
Allocations: 62763152 (Pool: 62688276; Big: 74876); GC: 67

What versions of Julia, Flux and AutoPreallocation are you using? It may be that the latter depends on an old version of Cassette that doesn’t work on 1.6. Can you check whether the examples in AutoPreallocation.jl (oxinabox/AutoPreallocation.jl on GitHub) work at all?
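You can collect that information from the REPL. This is a short sketch that assumes the packages are installed in the active environment. The manifest-mode listing shows indirect dependencies such as Cassette, which a plain project status would not:

```julia
using InteractiveUtils  # provides versioninfo (already loaded in the REPL)
using Pkg

versioninfo()                              # Julia version, OS and CPU details
Pkg.status(; mode = Pkg.PKGMODE_MANIFEST)  # all installed packages, including indirect ones
```

In the Pkg REPL mode, the same listing is available as `status -m`.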