Memory allocations in Flux evaluation/training (w/MVE)

I’m currently using Flux for some GAN experiments, and training is painfully slow. On the CPU, a single epoch can take ~30 min for a 286x32 sample dataloader. GPU evaluation is faster, but still feels slower than it should be. While trying to track down the cause, I noticed that the model evaluations allocate a significant amount of memory. The MVE below illustrates this:

using Flux, BenchmarkTools, StatsBase

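# Shape comments assume the 100x100x12x16 input used below (W x H x C x N).
function init_cvn()
    cv1 = Conv((5,5), 12=>20, leakyrelu; bias=false)   # -> 96x96x20x16
    cv2 = Conv((5,5), 20=>3, leakyrelu; bias=false)    # -> 20x20x3x16 (after pooling)
    mp1 = MeanPool((4,4))                              # -> 24x24x20x16

    # Dense acts on the first dimension, so each spatial dimension is
    # permuted to the front and collapsed in turn.
    cval = Chain(
        Dense(20=>1, leakyrelu),         # -> 1x20x3x16
        x->permutedims(x,[2,1,3,4]),     # -> 20x1x3x16
        Dense(20=>1, leakyrelu),         # -> 1x1x3x16
        x->permutedims(x,[3,1,2,4]),     # -> 3x1x1x16
        Dense(3=>10, tanh),              # -> 10x1x1x16
    )

    Chain(cv1, mp1, cv2, cval)
end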

function init_Dₓ()
	Chain(Dense(10=>10, leakyrelu), Dense(10=>1, sigmoid))
end

Dₙ=Chain(init_cvn(), init_Dₓ())

x=rand(Float32, 100,100,12,16)
@benchmark Dₙ($(x))

For me, @benchmark reports an estimated 125.15 MiB across 171 allocations per call. Since this is a GAN, there are 5 models in the set, some with nested calls, so these allocations add up quickly.
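To see which layers account for that total, one option is to measure each layer's forward pass separately. The sketch below is only an illustration: `layer_allocs` is a hypothetical helper (not part of Flux), and the small `demo` model is a stand-in so the snippet runs on its own. Because the layers are called one at a time from untyped code, the numbers may include a little dispatch overhead and should be read as rough.

```julia
using Flux

# Hypothetical helper (not part of Flux): walk a possibly nested Chain and
# print the bytes allocated by each leaf layer's forward pass.
function layer_allocs(layer, x)
    if layer isa Chain
        for inner in layer.layers
            x = layer_allocs(inner, x)    # pass the activation on to the next layer
        end
        return x
    end
    y = layer(x)                          # warm-up call; also the next layer's input
    bytes = @allocated layer(x)           # second call, so compilation is excluded
    println(rpad(first(string(layer), 45), 47),
            round(bytes / 2^20; digits = 3), " MiB")
    return y
end

# Small stand-in model so the snippet runs on its own.
demo = Chain(Chain(Dense(4 => 8, relu), Dense(8 => 2)), Dense(2 => 1))
out = layer_allocs(demo, rand(Float32, 4, 16))
```

For the model in this post the call would be `layer_allocs(Dₙ, x)`, which prints one line per `Conv`, `MeanPool`, `Dense`, and `permutedims` closure.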

This all boils down to about three questions:

  1. Is there something weird about how I’ve set up the model that’s causing allocations? I know I haven’t provided any context on the problem, but if there’s also something clearly wrong outside of the allocation issue let me know.
  2. Am I going down the right rabbit hole in thinking this level of allocation is likely to cause noticeable slowdowns?
  3. What options are available to address these allocations? I’ve looked into AllocatedArrays and SimpleChains, but I don’t think either meets my needs: with the former I need to be able to train the models, and with the latter I want to be able to move the model onto the GPU.
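For question 2, a rough way to check whether allocations translate into lost time is to look at the share of wall time spent in garbage collection. The sketch below uses Base's `@timed`; `gc_fraction` is a hypothetical helper name, and the workload is a stand-in. It gives an estimate under the assumption that GC time is the main cost of allocating, which ignores effects such as cache pressure.

```julia
# Hypothetical helper: fraction of wall time spent in GC over `n` calls of `f`.
function gc_fraction(f; n = 20)
    f()                                   # warm up so compilation is not timed
    total = 0.0
    gc = 0.0
    for _ in 1:n
        stats = @timed f()                # NamedTuple with :time and :gctime (seconds)
        total += stats.time
        gc += stats.gctime
    end
    return gc / total
end

# Stand-in workload that allocates ~4 MiB per call.
frac = gc_fraction(() -> sum(rand(Float32, 1024, 1024)))
println("GC share of runtime: ", round(100 * frac; digits = 1), "%")
```

With the model from this post, the call would be `gc_fraction(() -> Dₙ(x))`. A share of a few percent would suggest the allocations are not the main bottleneck for the forward pass.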

Any help on this would be greatly appreciated!

Setup information

Julia version: 1.12.6

Package data:
BenchmarkTools v1.8.0
Flux v0.16.10
StatsBase v0.34.10
Pkg v1.12.1

System data:
Intel i9-12900H
16GB RAM
GeForce RTX 3070Ti Laptop (not used here, but just in case it’s relevant)
Windows 11 Home

I get

julia> @benchmark Dₙ($(x))
BenchmarkTools.Trial: 364 samples with 1 evaluation per sample.
 Range (min … max):  11.956 ms … 157.352 ms  ┊ GC (min … max): 0.00% … 91.87%
 Time  (median):     13.070 ms               ┊ GC (median):    3.13%
 Time  (mean ± σ):   14.027 ms ±  10.759 ms  ┊ GC (mean ± σ):  8.59% ±  7.90%

     ▄ █▃
  █▁▄████▆▄▆▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▆
  12 ms         Histogram: log(frequency) by time      23.6 ms <

 Memory estimate: 23.36 MiB, allocs estimate: 71.

My hardware is more generous (Mac Studio M3 Ultra, 96 GB unified memory), yet the allocations are quite a bit lower. Just a data point; I’m not entirely sure what to make of it.

Interesting, here are my detailed benchmark results:

I should probably also have mentioned that I’m running all of this in Pluto. I doubt that would cause an issue, but you never know.
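Since the two machines report different allocation totals, it may help to record the same environment facts on both before comparing. The sketch below only collects values from Base and LinearAlgebra. The profile further down shows `Task` allocations under `conv_im2col!`, which suggests the CPU convolution is split across tasks; whether its buffers grow with the thread count is an assumption worth checking, not something established here. The Pluto check is a heuristic that assumes the notebook process defines `Main.PlutoRunner`.

```julia
using LinearAlgebra

# Environment facts that could plausibly differ between the two setups.
env = (
    julia_version = VERSION,
    julia_threads = Threads.nthreads(),       # threads available to Julia tasks
    cpu_threads   = Sys.CPU_THREADS,          # logical cores reported by the OS
    blas_threads  = BLAS.get_num_threads(),   # threads used by matrix multiplies
    in_pluto      = isdefined(Main, :PlutoRunner),  # heuristic, see note above
)

for (name, value) in pairs(env)
    println(rpad(string(name), 16), value)
end
```

Running the benchmark again after starting Julia with a different `--threads` value would show directly whether the memory estimate depends on it.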

I ran the allocation profiler as well, although I don’t know how to interpret the results in a helpful way. Unsurprisingly, there seem to be a few things going on, but I’m not sure which is the most significant.

unknown allocs
Flat	Flat%	Sum%	Cum	Cum%	Name	Inlined?
22	12.87%	12.87%	22	12.87%	Alloc: Base.IntrusiveLinkedList{Task}	
18	10.53%	23.39%	18	10.53%	Alloc: Base.Threads.SpinLock	
16	9.36%	32.75%	16	9.36%	Alloc: Task	
16	9.36%	42.11%	16	9.36%	Alloc: NNlib.var"#539#540"{NNlib.var"#conv_part#538"{Array{Float32, 3}, Float32, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, NNlib.DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64}, UnitRange{Int64}, Int64}	
16	9.36%	51.46%	16	9.36%	Alloc: Base.GenericCondition{Base.Threads.SpinLock}	
12	7.02%	58.48%	12	7.02%	Alloc: Memory{Float32}	
10	5.85%	64.33%	10	5.85%	Alloc: Matrix{Float32}	
10	5.85%	70.18%	10	5.85%	Alloc: Array{Float32, 4}	
8	4.68%	74.85%	8	4.68%	Alloc: Array{Float32, 5}	
7	4.09%	78.95%	7	4.09%	Alloc: Profile.Allocs.BufferType	
4	2.34%	81.29%	4	2.34%	Alloc: Vector{UInt64}	
4	2.34%	83.63%	4	2.34%	Alloc: Memory{UInt64}	
4	2.34%	85.96%	4	2.34%	Alloc: Memory{Any}	
4	2.34%	88.30%	4	2.34%	Alloc: BitVector	
3	1.75%	90.06%	3	1.75%	Alloc: Tuple{Int64, Int64, Int64}	
2	1.17%	91.23%	2	1.17%	Alloc: Vector{Int64}	
2	1.17%	92.40%	2	1.17%	Alloc: Vector{Any}	
2	1.17%	93.57%	2	1.17%	Alloc: ReentrantLock	
2	1.17%	94.74%	2	1.17%	Alloc: Memory{Int64}	
2	1.17%	95.91%	2	1.17%	Alloc: InvalidStateException	
2	1.17%	97.08%	2	1.17%	Alloc: Channel{Any}	
2	1.17%	98.25%	2	1.17%	Alloc: Array{Float32, 3}	
1	0.58%	98.83%	1	0.58%	Alloc: NTuple{6, Int64}	
1	0.58%	99.42%	1	0.58%	Alloc: NNlib.PoolDims{3, 3, 3, 6, 3}	
1	0.58%	100.00%	1	0.58%	Alloc: @NamedTuple{alpha::Int64, beta::Int64}	
0	0.00%	100.00%	171	100.00%	with_logstate	
0	0.00%	100.00%	171	100.00%	with_logger_and_io_to_logs	
0	0.00%	100.00%	171	100.00%	with_logger	
0	0.00%	100.00%	171	100.00%	with_io_to_logs	
0	0.00%	100.00%	4	2.34%	vect	
0	0.00%	100.00%	2	1.17%	sync_end(::Channel{Any})	
0	0.00%	100.00%	171	100.00%	start_task	
0	0.00%	100.00%	31	18.13%	similar	
0	0.00%	100.00%	171	100.00%	run_inside_trycatch	
0	0.00%	100.00%	171	100.00%	run_expression	
0	0.00%	100.00%	18	10.53%	reshape	
0	0.00%	100.00%	4	2.34%	put_buffered(::Channel{Any}, ::Task)	
0	0.00%	100.00%	4	2.34%	put!	
0	0.00%	100.00%	4	2.34%	push!	
0	0.00%	100.00%	6	3.51%	permutedims!	
0	0.00%	100.00%	17	9.94%	permutedims	
0	0.00%	100.00%	19	11.11%	new_as_memoryref	
0	0.00%	100.00%	6	3.51%	meanpool_direct!	
0	0.00%	100.00%	8	4.68%	meanpool!	
0	0.00%	100.00%	11	6.43%	meanpool	
0	0.00%	100.00%	171	100.00%	maybe_record_alloc_to_profile	
0	0.00%	100.00%	171	100.00%	macro expansion	
0	0.00%	100.00%	171	100.00%	jl_toplevel_eval_flex	
0	0.00%	100.00%	171	100.00%	jl_interpret_toplevel_thunk	
0	0.00%	100.00%	36	21.05%	jl_gc_alloc_	
0	0.00%	100.00%	171	100.00%	jl_f_invokelatest	
0	0.00%	100.00%	171	100.00%	jl_f__apply_iterate	
0	0.00%	100.00%	171	100.00%	jl_apply	
0	0.00%	100.00%	27	15.79%	jl_alloc_genericmemory_unchecked	
0	0.00%	100.00%	12	7.02%	isperm	
0	0.00%	100.00%	8	4.68%	insert_singleton_spatial_dimension	
0	0.00%	100.00%	171	100.00%	ijl_toplevel_eval_in	
0	0.00%	100.00%	171	100.00%	ijl_toplevel_eval	
0	0.00%	100.00%	16	9.36%	ijl_new_task	
0	0.00%	100.00%	128	74.85%	ijl_gc_small_alloc	
0	0.00%	100.00%	7	4.09%	ijl_gc_managed_malloc	
0	0.00%	100.00%	12	7.02%	falses	
0	0.00%	100.00%	171	100.00%	eval_value	
0	0.00%	100.00%	171	100.00%	eval_stmt_value	
0	0.00%	100.00%	171	100.00%	eval_body	
0	0.00%	100.00%	171	100.00%	eval(::Module, ::Any)	
0	0.00%	100.00%	171	100.00%	do_call	
0	0.00%	100.00%	106	61.99%	conv_im2col!	
0	0.00%	100.00%	106	61.99%	conv_group	
0	0.00%	100.00%	112	65.50%	conv!	
0	0.00%	100.00%	118	69.01%	conv	
0	0.00%	100.00%	171	100.00%	compute	
0	0.00%	100.00%	2	1.17%	close	
0	0.00%	100.00%	6	3.51%	checkdims_perm	
0	0.00%	100.00%	4	2.34%	array_new_memory	
0	0.00%	100.00%	12	7.02%	_isperm	
0	0.00%	100.00%	4	2.34%	_growend!	
0	0.00%	100.00%	171	100.00%	_applychain	
0	0.00%	100.00%	48	28.07%	_Task	
0	0.00%	100.00%	171	100.00%	[unknown function]	
0	0.00%	100.00%	4	2.34%	Val	
0	0.00%	100.00%	80	46.78%	Task	
0	0.00%	100.00%	18	10.53%	SpinLock	
0	0.00%	100.00%	6	3.51%	ReentrantLock	
0	0.00%	100.00%	11	6.43%	MeanPool	
0	0.00%	100.00%	22	12.87%	IntrusiveLinkedList	
0	0.00%	100.00%	29	16.96%	GenericMemory	
0	0.00%	100.00%	40	23.39%	GenericCondition	
0	0.00%	100.00%	21	12.28%	Dense	
0	0.00%	100.00%	118	69.01%	Conv	
0	0.00%	100.00%	14	8.19%	Channel	
0	0.00%	100.00%	171	100.00%	Chain	
0	0.00%	100.00%	12	7.02%	BitArray	
0	0.00%	100.00%	45	26.32%	Array	
0	0.00%	100.00%	11	6.43%	*	
0	0.00%	100.00%	4	2.34%	(::Base.var"#_growend!##0#_growend!##1"{Vector{Any}, Int64, Int64, Int64, Int64, Int64, Memory{Any}, MemoryRef{Any}})()	
0	0.00%	100.00%	171	100.00%	#with_logger_and_io_to_logs#121	
0	0.00%	100.00%	171	100.00%	#with_io_to_logs#125	
0	0.00%	100.00%	171	100.00%	#run_expression#28	
0	0.00%	100.00%	6	3.51%	#meanpool_direct!#564	
0	0.00%	100.00%	11	6.43%	#meanpool#377	
0	0.00%	100.00%	8	4.68%	#meanpool!#361	
0	0.00%	100.00%	6	3.51%	#meanpool!#346	
0	0.00%	100.00%	10	5.85%	#init_cvn##2	
0	0.00%	100.00%	11	6.43%	#init_cvn##0	
0	0.00%	100.00%	171	100.00%	#handle##0	
0	0.00%	100.00%	100	58.48%	#conv_im2col!#536	
0	0.00%	100.00%	118	69.01%	#conv#124	
0	0.00%	100.00%	106	61.99%	#conv!#181	
0	0.00%	100.00%	112	65.50%	#conv!#143	
0	0.00%	100.00%	171	100.00%	#36	
0	0.00%	100.00%	171	100.00%	#32	
0	0.00%	100.00%	171	100.00%	#123	
0	0.00%	100.00%	171	100.00%	##function_wrapped_cell#632
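
The table above counts allocations, not bytes, so the many small `Task` and lock objects look more important than they probably are for the memory estimate. One way to rank by size instead is to total the sampled bytes per type with the `Profile.Allocs` standard library. This is a sketch only: `alloc_bytes_by_type` is a hypothetical helper, and the workload is a stand-in.

```julia
using Profile

# Hypothetical helper: run `f` under the allocation profiler and
# total the recorded bytes per allocated type.
function alloc_bytes_by_type(f)
    f()                                        # warm up so compilation is not profiled
    Profile.Allocs.clear()
    Profile.Allocs.@profile sample_rate=1.0 f()  # 1.0 records every allocation
    results = Profile.Allocs.fetch()
    totals = Dict{Any,Int}()
    for a in results.allocs
        totals[a.type] = get(totals, a.type, 0) + a.size
    end
    # Largest contributors first.
    return sort!(collect(totals); by = last, rev = true)
end

# Stand-in workload; swap in `() -> Dₙ(x)` for the model in this post.
ranked = alloc_bytes_by_type(() -> rand(Float32, 256, 256))
for (T, bytes) in ranked
    println(rpad(string(T), 40), round(bytes / 2^20; digits = 3), " MiB")
end
```

If the byte totals turn out to be dominated by a handful of `Float32` arrays, those few allocations matter far more than the 171 count suggests.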