Memory allocations in Flux evaluation/training (w/MVE)

I’m currently using Flux to do some GAN experiments and I’m finding that the training is painfully slow. On the CPU, a single epoch can take ~30min for a 286x32 sample dataloader. GPU evaluation is faster, but still feels slower than it should be. In trying to track down why this could be, I noticed that the model evaluations were allocating a significant amount of memory. The MVE to illustrate this is as follows:

using Flux, BenchmarkTools, StatsBase

function init_cvn()
	cv1=Conv((5,5), 12=>20, leakyrelu; bias=false)
	cv2=Conv((5,5), 20=>3, leakyrelu; bias=false)
	mp1=MeanPool((4,4)) 

	cval = Chain(
			Dense(20=>1, leakyrelu),
			x->permutedims(x,[2,1,3,4]),
			Dense(20=>1, leakyrelu),
			x->permutedims(x,[3,1,2,4]),
			Dense(3=>10, tanh),
	)
	
	Chain(cv1, mp1 ,cv2, cval)
end

function init_Dₓ()
	Chain(Dense(10=>10, leakyrelu), Dense(10=>1, sigmoid))
end

Dₙ=Chain(init_cvn(), init_Dₓ())

x=rand(Float32, 100,100,12,16)
@benchmark Dₙ($(x))

For me, @benchmark reports an estimated 125.15 MiB across 171 allocations. The full GAN has 5 models in the set, some with nested calls, so these allocations add up quickly.
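For reference, here is a sketch of the array shapes I'd expect at each stage of the model above. The helpers `conv_shape`, `pool_shape`, `dense_shape` and `perm_shape` are made up for this illustration and are not Flux functions. It assumes unpadded, stride-1 convolutions, a pooling stride equal to the window, and that Dense acts only on the first dimension of an N-d input.

```julia
# Shape-only stand-ins for the layers; no Flux needed.
conv_shape(sz, k, cout) = (sz[1] - k + 1, sz[2] - k + 1, cout, sz[4])  # no padding, stride 1
pool_shape(sz, k) = (sz[1] ÷ k, sz[2] ÷ k, sz[3], sz[4])               # stride == window

function dense_shape(sz, nin, nout)
    # Assumption: Dense maps dim 1 and leaves the remaining dims untouched.
    sz[1] == nin || error("Dense expects $nin features in dim 1, got $(sz[1])")
    return (nout, sz[2:end]...)
end

# Dimension i of the result is dimension p[i] of the input, as in permutedims.
perm_shape(sz, p) = ntuple(i -> sz[p[i]], length(p))

s = (100, 100, 12, 16)
s = conv_shape(s, 5, 20)            # cv1
s = pool_shape(s, 4)                # mp1
s = conv_shape(s, 5, 3)             # cv2
s = dense_shape(s, 20, 1)           # Dense(20=>1)
s = perm_shape(s, (2, 1, 3, 4))
s = dense_shape(s, 20, 1)           # Dense(20=>1)
s = perm_shape(s, (3, 1, 2, 4))
s = dense_shape(s, 3, 10)           # Dense(3=>10)
s = dense_shape(s, 10, 10)          # init_Dₓ: Dense(10=>10)
s = dense_shape(s, 10, 1)           # init_Dₓ: Dense(10=>1)
println(s)
```

Under those assumptions the two Dense(20=>1) layers collapse the 20×20 spatial grid left by cv2 one axis at a time, and the final output has one value per sample in the batch.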

This all boils down to about three questions:

  1. Is there something weird about how I’ve set up the model that’s causing allocations? I know I haven’t provided any context on the problem, but if there’s also something clearly wrong outside of the allocation issue let me know.
  2. Am I going down the right rabbit hole in thinking this level of allocation is likely to be causing noticeable slowdowns?
  3. What options are available to address these allocations? I’ve looked into AllocatedArrays and SimpleChains, but I don’t think they will meet my needs because - respectively - I need to be able to train the models and I want to be able to move the model onto the GPU.

Any help on this would be greatly appreciated!

Setup information

Julia version: 1.12.6

Package data:
BenchmarkTools v1.8.0
Flux v0.16.10
StatsBase v0.34.10
Pkg v1.12.1

System data:
Intel i9-12900H
16GB RAM
GeForce RTX 3070Ti Laptop (not used here, but just in case it’s relevant)
Windows 11 Home

I get

julia> @benchmark Dₙ($(x))
BenchmarkTools.Trial: 364 samples with 1 evaluation per sample.
 Range (min … max):  11.956 ms … 157.352 ms  ┊ GC (min … max): 0.00% … 91.87%
 Time  (median):     13.070 ms               ┊ GC (median):    3.13%
 Time  (mean ± σ):   14.027 ms ±  10.759 ms  ┊ GC (mean ± σ):  8.59% ±  7.90%

     ▄ █▃
  █▁▄████▆▄▆▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▆
  12 ms         Histogram: log(frequency) by time      23.6 ms <

 Memory estimate: 23.36 MiB, allocs estimate: 71.

Although my hardware is more generous (Studio M3 Ultra 96GB combined), the allocations are quite a bit lower. Just a data point, not entirely sure what to make of it.

Interesting, here are my detailed benchmark results

I should probably also have mentioned that I’m running all this in Pluto. I doubt that would cause an issue, but you never know.

I ran Profiler as well, although I don’t know how to helpfully interpret the results. Unsurprisingly, there seem to be a few things going on, but I’m not sure which might be the most significant.

unknown allocs
Flat	Flat%	Sum%	Cum	Cum%	Name	Inlined?
22	12.87%	12.87%	22	12.87%	Alloc: Base.IntrusiveLinkedList{Task}	
18	10.53%	23.39%	18	10.53%	Alloc: Base.Threads.SpinLock	
16	9.36%	32.75%	16	9.36%	Alloc: Task	
16	9.36%	42.11%	16	9.36%	Alloc: NNlib.var"#539#540"{NNlib.var"#conv_part#538"{Array{Float32, 3}, Float32, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, NNlib.DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64}, UnitRange{Int64}, Int64}	
16	9.36%	51.46%	16	9.36%	Alloc: Base.GenericCondition{Base.Threads.SpinLock}	
12	7.02%	58.48%	12	7.02%	Alloc: Memory{Float32}	
10	5.85%	64.33%	10	5.85%	Alloc: Matrix{Float32}	
10	5.85%	70.18%	10	5.85%	Alloc: Array{Float32, 4}	
8	4.68%	74.85%	8	4.68%	Alloc: Array{Float32, 5}	
7	4.09%	78.95%	7	4.09%	Alloc: Profile.Allocs.BufferType	
4	2.34%	81.29%	4	2.34%	Alloc: Vector{UInt64}	
4	2.34%	83.63%	4	2.34%	Alloc: Memory{UInt64}	
4	2.34%	85.96%	4	2.34%	Alloc: Memory{Any}	
4	2.34%	88.30%	4	2.34%	Alloc: BitVector	
3	1.75%	90.06%	3	1.75%	Alloc: Tuple{Int64, Int64, Int64}	
2	1.17%	91.23%	2	1.17%	Alloc: Vector{Int64}	
2	1.17%	92.40%	2	1.17%	Alloc: Vector{Any}	
2	1.17%	93.57%	2	1.17%	Alloc: ReentrantLock	
2	1.17%	94.74%	2	1.17%	Alloc: Memory{Int64}	
2	1.17%	95.91%	2	1.17%	Alloc: InvalidStateException	
2	1.17%	97.08%	2	1.17%	Alloc: Channel{Any}	
2	1.17%	98.25%	2	1.17%	Alloc: Array{Float32, 3}	
1	0.58%	98.83%	1	0.58%	Alloc: NTuple{6, Int64}	
1	0.58%	99.42%	1	0.58%	Alloc: NNlib.PoolDims{3, 3, 3, 6, 3}	
1	0.58%	100.00%	1	0.58%	Alloc: @NamedTuple{alpha::Int64, beta::Int64}	
0	0.00%	100.00%	171	100.00%	with_logstate	
0	0.00%	100.00%	171	100.00%	with_logger_and_io_to_logs	
0	0.00%	100.00%	171	100.00%	with_logger	
0	0.00%	100.00%	171	100.00%	with_io_to_logs	
0	0.00%	100.00%	4	2.34%	vect	
0	0.00%	100.00%	2	1.17%	sync_end(::Channel{Any})	
0	0.00%	100.00%	171	100.00%	start_task	
0	0.00%	100.00%	31	18.13%	similar	
0	0.00%	100.00%	171	100.00%	run_inside_trycatch	
0	0.00%	100.00%	171	100.00%	run_expression	
0	0.00%	100.00%	18	10.53%	reshape	
0	0.00%	100.00%	4	2.34%	put_buffered(::Channel{Any}, ::Task)	
0	0.00%	100.00%	4	2.34%	put!	
0	0.00%	100.00%	4	2.34%	push!	
0	0.00%	100.00%	6	3.51%	permutedims!	
0	0.00%	100.00%	17	9.94%	permutedims	
0	0.00%	100.00%	19	11.11%	new_as_memoryref	
0	0.00%	100.00%	6	3.51%	meanpool_direct!	
0	0.00%	100.00%	8	4.68%	meanpool!	
0	0.00%	100.00%	11	6.43%	meanpool	
0	0.00%	100.00%	171	100.00%	maybe_record_alloc_to_profile	
0	0.00%	100.00%	171	100.00%	macro expansion	
0	0.00%	100.00%	171	100.00%	jl_toplevel_eval_flex	
0	0.00%	100.00%	171	100.00%	jl_interpret_toplevel_thunk	
0	0.00%	100.00%	36	21.05%	jl_gc_alloc_	
0	0.00%	100.00%	171	100.00%	jl_f_invokelatest	
0	0.00%	100.00%	171	100.00%	jl_f__apply_iterate	
0	0.00%	100.00%	171	100.00%	jl_apply	
0	0.00%	100.00%	27	15.79%	jl_alloc_genericmemory_unchecked	
0	0.00%	100.00%	12	7.02%	isperm	
0	0.00%	100.00%	8	4.68%	insert_singleton_spatial_dimension	
0	0.00%	100.00%	171	100.00%	ijl_toplevel_eval_in	
0	0.00%	100.00%	171	100.00%	ijl_toplevel_eval	
0	0.00%	100.00%	16	9.36%	ijl_new_task	
0	0.00%	100.00%	128	74.85%	ijl_gc_small_alloc	
0	0.00%	100.00%	7	4.09%	ijl_gc_managed_malloc	
0	0.00%	100.00%	12	7.02%	falses	
0	0.00%	100.00%	171	100.00%	eval_value	
0	0.00%	100.00%	171	100.00%	eval_stmt_value	
0	0.00%	100.00%	171	100.00%	eval_body	
0	0.00%	100.00%	171	100.00%	eval(::Module, ::Any)	
0	0.00%	100.00%	171	100.00%	do_call	
0	0.00%	100.00%	106	61.99%	conv_im2col!	
0	0.00%	100.00%	106	61.99%	conv_group	
0	0.00%	100.00%	112	65.50%	conv!	
0	0.00%	100.00%	118	69.01%	conv	
0	0.00%	100.00%	171	100.00%	compute	
0	0.00%	100.00%	2	1.17%	close	
0	0.00%	100.00%	6	3.51%	checkdims_perm	
0	0.00%	100.00%	4	2.34%	array_new_memory	
0	0.00%	100.00%	12	7.02%	_isperm	
0	0.00%	100.00%	4	2.34%	_growend!	
0	0.00%	100.00%	171	100.00%	_applychain	
0	0.00%	100.00%	48	28.07%	_Task	
0	0.00%	100.00%	171	100.00%	[unknown function]	
0	0.00%	100.00%	4	2.34%	Val	
0	0.00%	100.00%	80	46.78%	Task	
0	0.00%	100.00%	18	10.53%	SpinLock	
0	0.00%	100.00%	6	3.51%	ReentrantLock	
0	0.00%	100.00%	11	6.43%	MeanPool	
0	0.00%	100.00%	22	12.87%	IntrusiveLinkedList	
0	0.00%	100.00%	29	16.96%	GenericMemory	
0	0.00%	100.00%	40	23.39%	GenericCondition	
0	0.00%	100.00%	21	12.28%	Dense	
0	0.00%	100.00%	118	69.01%	Conv	
0	0.00%	100.00%	14	8.19%	Channel	
0	0.00%	100.00%	171	100.00%	Chain	
0	0.00%	100.00%	12	7.02%	BitArray	
0	0.00%	100.00%	45	26.32%	Array	
0	0.00%	100.00%	11	6.43%	*	
0	0.00%	100.00%	4	2.34%	(::Base.var"#_growend!##0#_growend!##1"{Vector{Any}, Int64, Int64, Int64, Int64, Int64, Memory{Any}, MemoryRef{Any}})()	
0	0.00%	100.00%	171	100.00%	#with_logger_and_io_to_logs#121	
0	0.00%	100.00%	171	100.00%	#with_io_to_logs#125	
0	0.00%	100.00%	171	100.00%	#run_expression#28	
0	0.00%	100.00%	6	3.51%	#meanpool_direct!#564	
0	0.00%	100.00%	11	6.43%	#meanpool#377	
0	0.00%	100.00%	8	4.68%	#meanpool!#361	
0	0.00%	100.00%	6	3.51%	#meanpool!#346	
0	0.00%	100.00%	10	5.85%	#init_cvn##2	
0	0.00%	100.00%	11	6.43%	#init_cvn##0	
0	0.00%	100.00%	171	100.00%	#handle##0	
0	0.00%	100.00%	100	58.48%	#conv_im2col!#536	
0	0.00%	100.00%	118	69.01%	#conv#124	
0	0.00%	100.00%	106	61.99%	#conv!#181	
0	0.00%	100.00%	112	65.50%	#conv!#143	
0	0.00%	100.00%	171	100.00%	#36	
0	0.00%	100.00%	171	100.00%	#32	
0	0.00%	100.00%	171	100.00%	#123	
0	0.00%	100.00%	171	100.00%	##function_wrapped_cell#632	

If I’m reading this correctly, about 33% of your allocations (IntrusiveLinkedList, SpinLock, Task) come from the Julia task scheduler rather than from the model itself, and about 75% of the recorded allocations go through the small-object allocator (ijl_gc_small_alloc). Note that these are allocation counts, not time, so on their own they don’t show how long is spent in garbage collection. The code is asking for high-performance neural net operations (like im2col), and my guess is that your hardware is struggling with memory pressure and possibly thermal throttling. Try running a smaller version of the MWE (half the batch size). If the IntrusiveLinkedList and SpinLock percentages drop significantly, that’s a pretty good indication that your laptop is underpowered for what you’re trying to do.

I ran it with 8 samples instead of the 16 from the original MVE and got nearly identical results, at least for those two allocation types. As far as the laptop being underpowered or overtaxed, it’s certainly possible, but I have run other models on it that seemed to perform reasonably and weren’t too different from what I’m doing here. At least, didn’t seem all that different…

Flat	Flat%	Sum%	Cum	Cum%	Name	Inlined?
22	12.94%	12.94%	22	12.94%	Alloc: Base.IntrusiveLinkedList{Task}	
18	10.59%	23.53%	18	10.59%	Alloc: Base.Threads.SpinLock	
16	9.41%	32.94%	16	9.41%	Alloc: Task	
16	9.41%	42.35%	16	9.41%	Alloc: NNlib.var"#539#540"{NNlib.var"#conv_part#538"{Array{Float32, 3}, Float32, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, NNlib.DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64}, UnitRange{Int64}, Int64}	
16	9.41%	51.76%	16	9.41%	Alloc: Base.GenericCondition{Base.Threads.SpinLock}	
12	7.06%	58.82%	12	7.06%	Alloc: Memory{Float32}	
10	5.88%	64.71%	10	5.88%	Alloc: Matrix{Float32}	
10	5.88%	70.59%	10	5.88%	Alloc: Array{Float32, 4}	
8	4.71%	75.29%	8	4.71%	Alloc: Array{Float32, 5}	
5	2.94%	78.24%	5	2.94%	Alloc: Profile.Allocs.BufferType	
4	2.35%	80.59%	4	2.35%	Alloc: Vector{UInt64}	
4	2.35%	82.94%	4	2.35%	Alloc: Memory{UInt64}	
4	2.35%	85.29%	4	2.35%	Alloc: Memory{Any}	
4	2.35%	87.65%	4	2.35%	Alloc: BitVector	
3	1.76%	89.41%	3	1.76%	Alloc: Tuple{Int64, Int64, Int64}	
2	1.18%	90.59%	2	1.18%	Alloc: Vector{Int64}	
2	1.18%	91.76%	2	1.18%	Alloc: Vector{Any}	
2	1.18%	92.94%	2	1.18%	Alloc: ReentrantLock	
2	1.18%	94.12%	2	1.18%	Alloc: Memory{Int64}	
2	1.18%	95.29%	2	1.18%	Alloc: InvalidStateException	
2	1.18%	96.47%	2	1.18%	Alloc: Channel{Any}	
2	1.18%	97.65%	2	1.18%	Alloc: Array{Float32, 3}	
1	0.59%	98.24%	1	0.59%	Alloc: NTuple{6, Int64}	
1	0.59%	98.82%	1	0.59%	Alloc: NNlib.PoolDims{3, 3, 3, 6, 3}	
1	0.59%	99.41%	1	0.59%	Alloc: Flux.Chain{Tuple{Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(NNlib.leakyrelu), Array{Float32, 4}, Bool}, Flux.MeanPool{2, 4}, Flux.Conv{2, 4, typeof(NNlib.leakyrelu), Array{Float32, 4}, Bool}, Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.leakyrelu), Matrix{Float32}, Vector{Float32}}, Main.var"workspace#22".var"#init_cvn##0#init_cvn##1", Flux.Dense{typeof(NNlib.leakyrelu), Matrix{Float32}, Vector{Float32}}, Main.var"workspace#22".var"#init_cvn##2#init_cvn##3", Flux.Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.leakyrelu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(NNlib.σ), Matrix{Float32}, Vector{Float32}}}}}}	
1	0.59%	100.00%	1	0.59%	Alloc: @NamedTuple{alpha::Int64, beta::Int64}	
0	0.00%	100.00%	170	100.00%	with_logstate	
0	0.00%	100.00%	170	100.00%	with_logger_and_io_to_logs	
0	0.00%	100.00%	170	100.00%	with_logger	
0	0.00%	100.00%	170	100.00%	with_io_to_logs	
0	0.00%	100.00%	4	2.35%	vect	
0	0.00%	100.00%	2	1.18%	sync_end(::Channel{Any})	
0	0.00%	100.00%	170	100.00%	start_task	
0	0.00%	100.00%	29	17.06%	similar	
0	0.00%	100.00%	170	100.00%	run_inside_trycatch	
0	0.00%	100.00%	170	100.00%	run_expression	
0	0.00%	100.00%	18	10.59%	reshape	
0	0.00%	100.00%	4	2.35%	put_buffered(::Channel{Any}, ::Task)	
0	0.00%	100.00%	4	2.35%	put!	
0	0.00%	100.00%	4	2.35%	push!	
0	0.00%	100.00%	6	3.53%	permutedims!	
0	0.00%	100.00%	16	9.41%	permutedims	
0	0.00%	100.00%	17	10.00%	new_as_memoryref	
0	0.00%	100.00%	6	3.53%	meanpool_direct!	
0	0.00%	100.00%	8	4.71%	meanpool!	
0	0.00%	100.00%	11	6.47%	meanpool	
0	0.00%	100.00%	170	100.00%	maybe_record_alloc_to_profile	
0	0.00%	100.00%	170	100.00%	macro expansion	
0	0.00%	100.00%	170	100.00%	jl_toplevel_eval_flex	
0	0.00%	100.00%	170	100.00%	jl_interpret_toplevel_thunk	
0	0.00%	100.00%	36	21.18%	jl_gc_alloc_	
0	0.00%	100.00%	170	100.00%	jl_f_invokelatest	
0	0.00%	100.00%	170	100.00%	jl_f__apply_iterate	
0	0.00%	100.00%	170	100.00%	jl_apply	
0	0.00%	100.00%	25	14.71%	jl_alloc_genericmemory_unchecked	
0	0.00%	100.00%	12	7.06%	isperm	
0	0.00%	100.00%	8	4.71%	insert_singleton_spatial_dimension	
0	0.00%	100.00%	170	100.00%	ijl_toplevel_eval_in	
0	0.00%	100.00%	170	100.00%	ijl_toplevel_eval	
0	0.00%	100.00%	16	9.41%	ijl_new_task	
0	0.00%	100.00%	129	75.88%	ijl_gc_small_alloc	
0	0.00%	100.00%	5	2.94%	ijl_gc_managed_malloc	
0	0.00%	100.00%	12	7.06%	falses	
0	0.00%	100.00%	170	100.00%	eval_value	
0	0.00%	100.00%	170	100.00%	eval_stmt_value	
0	0.00%	100.00%	170	100.00%	eval_body	
0	0.00%	100.00%	170	100.00%	eval(::Module, ::Any)	
0	0.00%	100.00%	170	100.00%	do_call	
0	0.00%	100.00%	106	62.35%	conv_im2col!	
0	0.00%	100.00%	106	62.35%	conv_group	
0	0.00%	100.00%	112	65.88%	conv!	
0	0.00%	100.00%	118	69.41%	conv	
0	0.00%	100.00%	170	100.00%	compute	
0	0.00%	100.00%	2	1.18%	close	
0	0.00%	100.00%	6	3.53%	checkdims_perm	
0	0.00%	100.00%	4	2.35%	array_new_memory	
0	0.00%	100.00%	12	7.06%	_isperm	
0	0.00%	100.00%	4	2.35%	_growend!	
0	0.00%	100.00%	169	99.41%	_applychain	
0	0.00%	100.00%	48	28.24%	_Task	
0	0.00%	100.00%	170	100.00%	[unknown function]	
0	0.00%	100.00%	4	2.35%	Val	
0	0.00%	100.00%	80	47.06%	Task	
0	0.00%	100.00%	18	10.59%	SpinLock	
0	0.00%	100.00%	6	3.53%	ReentrantLock	
0	0.00%	100.00%	11	6.47%	MeanPool	
0	0.00%	100.00%	22	12.94%	IntrusiveLinkedList	
0	0.00%	100.00%	27	15.88%	GenericMemory	
0	0.00%	100.00%	40	23.53%	GenericCondition	
0	0.00%	100.00%	20	11.76%	Dense	
0	0.00%	100.00%	118	69.41%	Conv	
0	0.00%	100.00%	14	8.24%	Channel	
0	0.00%	100.00%	169	99.41%	Chain	
0	0.00%	100.00%	12	7.06%	BitArray	
0	0.00%	100.00%	43	25.29%	Array	
0	0.00%	100.00%	10	5.88%	*	
0	0.00%	100.00%	4	2.35%	(::Base.var"#_growend!##0#_growend!##1"{Vector{Any}, Int64, Int64, Int64, Int64, Int64, Memory{Any}, MemoryRef{Any}})()	
0	0.00%	100.00%	170	100.00%	#with_logger_and_io_to_logs#121	
0	0.00%	100.00%	170	100.00%	#with_io_to_logs#125	
0	0.00%	100.00%	170	100.00%	#run_expression#28	
0	0.00%	100.00%	6	3.53%	#meanpool_direct!#564	
0	0.00%	100.00%	11	6.47%	#meanpool#377	
0	0.00%	100.00%	8	4.71%	#meanpool!#361	
0	0.00%	100.00%	6	3.53%	#meanpool!#346	
0	0.00%	100.00%	10	5.88%	#init_cvn##2	
0	0.00%	100.00%	10	5.88%	#init_cvn##0	
0	0.00%	100.00%	170	100.00%	#handle##0	
0	0.00%	100.00%	100	58.82%	#conv_im2col!#536	
0	0.00%	100.00%	118	69.41%	#conv#124	
0	0.00%	100.00%	106	62.35%	#conv!#181	
0	0.00%	100.00%	112	65.88%	#conv!#143	
0	0.00%	100.00%	170	100.00%	#36	
0	0.00%	100.00%	170	100.00%	#32	
0	0.00%	100.00%	170	100.00%	#123	
0	0.00%	100.00%	170	100.00%	##function_wrapped_cell#665	

Isn’t permutedims the culprit for the allocation?

Look at my profile on the original MWE

Overhead ╎ [+additional indent] Count File:Line  Function
=========================================================
        ╎24482328 @Base/client.jl:561  _start()
        ╎ 24482328 @Base/client.jl:586  repl_main
        ╎  24482328 @Base/client.jl:499  run_main_repl(interactive::Bool, quiet…
        ╎   24482328 @Base/client.jl:478  run_std_repl(REPL::Module, quiet::Boo…
        ╎    24482328 @REPL/src/REPL.jl:639  run_repl(repl::REPL.AbstractREPL, …
        ╎     24482328 @REPL/…rc/REPL.jl:653  run_repl(repl::REPL.AbstractREPL,…
        ╎    ╎ 24482328 @REPL/…rc/REPL.jl:424  start_repl_backend
        ╎    ╎  24482328 @REPL/…c/REPL.jl:427  start_repl_backend(backend::REPL…
        ╎    ╎   24482328 @REPL/…c/REPL.jl:452  repl_backend_loop(backend::REPL…
        ╎    ╎    24482328 @REPL/…c/REPL.jl:330  eval_user_input(ast::Any, back…
        ╎    ╎     24482328 @REPL/…/REPL.jl:305  toplevel_eval_with_hooks
        ╎    ╎    ╎ 24482328 @REPL/…/REPL.jl:312  toplevel_eval_with_hooks(mod:…
        ╎    ╎    ╎  24482328 @REPL/…REPL.jl:312  toplevel_eval_with_hooks(mod:…
        ╎    ╎    ╎   24482328 @REPL/…REPL.jl:312  toplevel_eval_with_hooks(mod…
        ╎    ╎    ╎    24482328 @REPL/…REPL.jl:308  toplevel_eval_with_hooks(mo…
     144╎    ╎    ╎     24482328 @REPL/…EPL.jl:301  __repl_entry_eval_expanded_…
        ╎    ╎    ╎    ╎ 24482184 @Flux/…sic.jl:65  (::Chain{Tuple{Chain{Tuple{…
        ╎    ╎    ╎    ╎  24482184 @Flux/…sic.jl:68  _applychain
        ╎    ╎    ╎    ╎   24482184 @Flux/…ic.jl:68  macro expansion
        ╎    ╎    ╎    ╎    24482184 @Flux/…ic.jl:65  Chain
        ╎    ╎    ╎    ╎     1056     @Flux/…ic.jl:68  _applychain
        ╎    ╎    ╎    ╎    ╎ 1056     @Flux/…ic.jl:68  macro expansion
        ╎    ╎    ╎    ╎    ╎  1056     @Flux/…ic.jl:204  (::Dense{typeof(σ), M…
        ╎    ╎    ╎    ╎    ╎   224      @Base/…ay.jl:132  reshape
        ╎    ╎    ╎    ╎    ╎    224      @Base/…ay.jl:133  reshape(parent::Mat…
     224╎    ╎    ╎    ╎    ╎     224      @Base/…ay.jl:50  reshape
        ╎    ╎    ╎    ╎    ╎   832      @Flux/…ic.jl:199  (::Dense{typeof(σ), …
        ╎    ╎    ╎    ╎    ╎    832      @LinearAlgebra/…:136  *
        ╎    ╎    ╎    ╎    ╎     832      @Base/…ay.jl:377  similar
        ╎    ╎    ╎    ╎    ╎    ╎ 832      @Base/…ot.jl:661  Array
      96╎    ╎    ╎    ╎    ╎    ╎  832      @Base/…ot.jl:651  Array
        ╎    ╎    ╎    ╎    ╎    ╎   736      @Base/…ot.jl:604  new_as_memoryref
     736╎    ╎    ╎    ╎    ╎    ╎    736      @Base/…ot.jl:588  GenericMemory
        ╎    ╎    ╎    ╎     24481128 @Flux/…ic.jl:68  _applychain(layers::Tupl…
        ╎    ╎    ╎    ╎    ╎ 24481128 @Flux/…ic.jl:68  macro expansion
        ╎    ╎    ╎    ╎    ╎  9952     @Flux/…ic.jl:65  Chain
        ╎    ╎    ╎    ╎    ╎   9952     @Flux/…ic.jl:68  _applychain(layers::T…
        ╎    ╎    ╎    ╎    ╎    9952     @Flux/…ic.jl:68  macro expansion
        ╎    ╎    ╎    ╎    ╎     4200     REPL[2]:8  #init_cvn##0
        ╎    ╎    ╎    ╎    ╎    ╎ 96       @Base/…ay.jl:161  vect
        ╎    ╎    ╎    ╎    ╎    ╎  64       @Base/…ot.jl:647  Array
      64╎    ╎    ╎    ╎    ╎    ╎   64       @Base/…ot.jl:588  GenericMemory
      32╎    ╎    ╎    ╎    ╎    ╎  32       @Base/…ot.jl:648  Array
        ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…al.jl:1672  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…cs.jl:76  isperm
        ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…cs.jl:80  _isperm(A::Vecto…
        ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…ay.jl:398  falses
        ╎    ╎    ╎    ╎    ╎    ╎     88       @Base/…ay.jl:400  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…ay.jl:71  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎  56       @Base/…ay.jl:37  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎   24       @Base/…ot.jl:647  Array
      24╎    ╎    ╎    ╎    ╎    ╎    ╎    24       @Base/…ot.jl:588  GenericMe…
      32╎    ╎    ╎    ╎    ╎    ╎    ╎   32       @Base/…ot.jl:648  Array
      32╎    ╎    ╎    ╎    ╎    ╎    ╎  32       @Base/…ay.jl:39  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎ 3928     @Base/…al.jl:1674  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  3928     @Base/…ay.jl:378  similar
        ╎    ╎    ╎    ╎    ╎    ╎   3928     @Base/…ot.jl:663  Array
      64╎    ╎    ╎    ╎    ╎    ╎    3928     @Base/…ot.jl:657  Array
        ╎    ╎    ╎    ╎    ╎    ╎     3864     @Base/…ot.jl:604  new_as_memory…
    3864╎    ╎    ╎    ╎    ╎    ╎    ╎ 3864     @Base/…ot.jl:588  GenericMemory
        ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…al.jl:1675  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…al.jl:1689  permutedims!(P:…
        ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…al.jl:1691  macro expansion
        ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…al.jl:1681  checkdims_per…
        ╎    ╎    ╎    ╎    ╎    ╎     88       @Base/…cs.jl:76  isperm
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…cs.jl:80  _isperm(A::Ve…
        ╎    ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…ay.jl:398  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…ay.jl:400  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…ay.jl:71  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎     56       @Base/…ay.jl:37  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 24       @Base/…ot.jl:647  Array
      24╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  24       @Base/…ot.jl:588  Generi…
      32╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 32       @Base/…ot.jl:648  Array
      32╎    ╎    ╎    ╎    ╎    ╎    ╎     32       @Base/…ay.jl:39  BitArray
        ╎    ╎    ╎    ╎    ╎     544      REPL[2]:10  #init_cvn##2
        ╎    ╎    ╎    ╎    ╎    ╎ 96       @Base/…ay.jl:161  vect
        ╎    ╎    ╎    ╎    ╎    ╎  64       @Base/…ot.jl:647  Array
      64╎    ╎    ╎    ╎    ╎    ╎   64       @Base/…ot.jl:588  GenericMemory
      32╎    ╎    ╎    ╎    ╎    ╎  32       @Base/…ot.jl:648  Array
        ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…al.jl:1672  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…cs.jl:76  isperm
        ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…cs.jl:80  _isperm(A::Vecto…
        ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…ay.jl:398  falses
        ╎    ╎    ╎    ╎    ╎    ╎     88       @Base/…ay.jl:400  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…ay.jl:71  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎  56       @Base/…ay.jl:37  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎   24       @Base/…ot.jl:647  Array
      24╎    ╎    ╎    ╎    ╎    ╎    ╎    24       @Base/…ot.jl:588  GenericMe…
      32╎    ╎    ╎    ╎    ╎    ╎    ╎   32       @Base/…ot.jl:648  Array
      32╎    ╎    ╎    ╎    ╎    ╎    ╎  32       @Base/…ay.jl:39  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎ 272      @Base/…al.jl:1674  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  272      @Base/…ay.jl:378  similar
        ╎    ╎    ╎    ╎    ╎    ╎   272      @Base/…ot.jl:663  Array
      64╎    ╎    ╎    ╎    ╎    ╎    272      @Base/…ot.jl:657  Array
        ╎    ╎    ╎    ╎    ╎    ╎     208      @Base/…ot.jl:604  new_as_memory…
     208╎    ╎    ╎    ╎    ╎    ╎    ╎ 208      @Base/…ot.jl:588  GenericMemory
        ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…al.jl:1675  permutedims(B::A…
        ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…al.jl:1689  permutedims!(P:…
        ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…al.jl:1691  macro expansion
        ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…al.jl:1681  checkdims_per…
        ╎    ╎    ╎    ╎    ╎    ╎     88       @Base/…cs.jl:76  isperm
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 88       @Base/…cs.jl:80  _isperm(A::Ve…
        ╎    ╎    ╎    ╎    ╎    ╎    ╎  88       @Base/…ay.jl:398  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎   88       @Base/…ay.jl:400  falses
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    88       @Base/…ay.jl:71  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎     56       @Base/…ay.jl:37  BitArray
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 24       @Base/…ot.jl:647  Array
      24╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  24       @Base/…ot.jl:588  Generi…
      32╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 32       @Base/…ot.jl:648  Array
      32╎    ╎    ╎    ╎    ╎    ╎    ╎     32       @Base/…ay.jl:39  BitArray
        ╎    ╎    ╎    ╎    ╎     5208     @Flux/…ic.jl:204  (::Dense{typeof(ta…
        ╎    ╎    ╎    ╎    ╎    ╎ 336      @Base/…ay.jl:132  reshape
        ╎    ╎    ╎    ╎    ╎    ╎  336      @Base/…ay.jl:133  reshape(parent::…
     336╎    ╎    ╎    ╎    ╎    ╎   336      @Base/…ay.jl:50  reshape
        ╎    ╎    ╎    ╎    ╎    ╎ 4872     @Flux/…ic.jl:199  (::Dense{typeof(t…
        ╎    ╎    ╎    ╎    ╎    ╎  4872     @LinearAlgebra/…:136  *
        ╎    ╎    ╎    ╎    ╎    ╎   4872     @Base/…ay.jl:377  similar
        ╎    ╎    ╎    ╎    ╎    ╎    4872     @Base/…ot.jl:661  Array
     144╎    ╎    ╎    ╎    ╎    ╎     4872     @Base/…ot.jl:651  Array
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 4728     @Base/…ot.jl:604  new_as_memor…
    4728╎    ╎    ╎    ╎    ╎    ╎    ╎  4728     @Base/…ot.jl:588  GenericMemo…
        ╎    ╎    ╎    ╎    ╎  23733328 @Flux/…nv.jl:201  Conv
        ╎    ╎    ╎    ╎    ╎   23733328 @NNlib/…v.jl:83  conv
        ╎    ╎    ╎    ╎    ╎    11873600 @NNlib/…v.jl:86  conv(x::Array{Float3…
        ╎    ╎    ╎    ╎    ╎     11873600 @Base/…ay.jl:824  similar
        ╎    ╎    ╎    ╎    ╎    ╎ 11873600 @Base/…ay.jl:377  similar
        ╎    ╎    ╎    ╎    ╎    ╎  11873600 @Base/…ot.jl:663  Array
     128╎    ╎    ╎    ╎    ╎    ╎   11873600 @Base/…ot.jl:657  Array
        ╎    ╎    ╎    ╎    ╎    ╎    11873472 @Base/…ot.jl:604  new_as_memoryr…
11873472╎    ╎    ╎    ╎    ╎    ╎     11873472 @Base/…ot.jl:588  GenericMemory
        ╎    ╎    ╎    ╎    ╎    11859728 @NNlib/…v.jl:88  conv(x::Array{Float3…
        ╎    ╎    ╎    ╎    ╎     11859728 @NNlib/…v.jl:140  conv!
        ╎    ╎    ╎    ╎    ╎    ╎ 11859728 @NNlib/…v.jl:145  conv!(y::Array{Fl…
        ╎    ╎    ╎    ╎    ╎    ╎  11859344 @NNlib/…v.jl:185  conv!
        ╎    ╎    ╎    ╎    ╎    ╎   11859344 @NNlib/…v.jl:218  conv!(out::Arra…
        ╎    ╎    ╎    ╎    ╎    ╎    11859344 @NNlib/…v.jl:209  conv_group
        ╎    ╎    ╎    ╎    ╎    ╎     11859344 @NNlib/…l.jl:23  conv_im2col!(y…
        ╎    ╎    ╎    ╎    ╎    ╎    ╎ 11859344 @Base/…ay.jl:822  similar
        ╎    ╎    ╎    ╎    ╎    ╎    ╎  11859344 @Base/…ay.jl:67  similar
        ╎    ╎    ╎    ╎    ╎    ╎    ╎   11859344 @Base/…ay.jl:377  similar
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    11859344 @Base/…ot.jl:662  Array
      96╎    ╎    ╎    ╎    ╎    ╎    ╎     11859344 @Base/…ot.jl:654  Array
        ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 11859248 @Base/…ot.jl:604  new_as_…
11859248╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  11859248 @Base/…ot.jl:588  Generi…
        ╎    ╎    ╎    ╎    ╎    ╎  384      @NNlib/…s.jl:75  insert_singleton_…
        ╎    ╎    ╎    ╎    ╎    ╎   384      @NNlib/…s.jl:69  insert_singleton…
        ╎    ╎    ╎    ╎    ╎    ╎    384      @Base/…ay.jl:130  reshape
     384╎    ╎    ╎    ╎    ╎    ╎     384      @Base/…ay.jl:50  reshape
        ╎    ╎    ╎    ╎    ╎  737848   @Flux/…nv.jl:797  MeanPool
        ╎    ╎    ╎    ╎    ╎   737848   @NNlib/…g.jl:114  meanpool
        ╎    ╎    ╎    ╎    ╎    737368   @NNlib/…g.jl:117  meanpool(x::Array{F…
        ╎    ╎    ╎    ╎    ╎     737368   @Base/…ay.jl:823  similar
        ╎    ╎    ╎    ╎    ╎    ╎ 737368   @Base/…ay.jl:377  similar
        ╎    ╎    ╎    ╎    ╎    ╎  737368   @Base/…ot.jl:663  Array
      64╎    ╎    ╎    ╎    ╎    ╎   737368   @Base/…ot.jl:657  Array
        ╎    ╎    ╎    ╎    ╎    ╎    737304   @Base/…ot.jl:604  new_as_memoryr…
  737304╎    ╎    ╎    ╎    ╎    ╎     737304   @Base/…ot.jl:588  GenericMemory
        ╎    ╎    ╎    ╎    ╎    480      @NNlib/…g.jl:119  meanpool(x::Array{F…
        ╎    ╎    ╎    ╎    ╎     480      @NNlib/…g.jl:70  meanpool!
        ╎    ╎    ╎    ╎    ╎    ╎ 480      @NNlib/…g.jl:73  meanpool!(y::Array…
        ╎    ╎    ╎    ╎    ╎    ╎  128      @NNlib/…s.jl:75  insert_singleton_…
        ╎    ╎    ╎    ╎    ╎    ╎   128      @NNlib/…s.jl:69  insert_singleton…
        ╎    ╎    ╎    ╎    ╎    ╎    128      @Base/…ay.jl:130  reshape
     128╎    ╎    ╎    ╎    ╎    ╎     128      @Base/…ay.jl:50  reshape
        ╎    ╎    ╎    ╎    ╎    ╎  352      @NNlib/…g.jl:38  meanpool!
        ╎    ╎    ╎    ╎    ╎    ╎   352      @NNlib/…g.jl:41  #meanpool!#346
        ╎    ╎    ╎    ╎    ╎    ╎    352      @NNlib/…t.jl:4  meanpool_direct!
     192╎    ╎    ╎    ╎    ╎    ╎     352      @NNlib/…t.jl:7  meanpool_direct…
     160╎    ╎    ╎    ╎    ╎    ╎    ╎ 160      @Base/…ls.jl:1040  Val
Total snapshots: 79
Total bytes: 24482328

Just about all of the effort is going into the math. No SpinLock or Task heavy overhead. I am able to grab big blocks of RAM from the OS, but with your setup, the OS is hogging half of the RAM and the allocations trigger partial garbage collection frequently (ijl_gc_small_alloc).
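As a rough check on the byte counts in that profile, the sizes of the large buffers can be reproduced from the layer shapes alone. This is only a sketch: it assumes each convolution's im2col workspace holds (output pixels) × (kernel elements × input channels) Float32 values, and the helper names `conv_out`, `out_bytes` and `col_bytes` are made up for the illustration.

```julia
# Spatial output size of an unpadded, stride-1 convolution.
conv_out(n, k) = n - k + 1

# Bytes in the Float32 output of a conv layer.
out_bytes(h, w, cout, batch) = h * w * cout * batch * sizeof(Float32)

# Bytes in one im2col workspace (assumed layout: out_pixels × (k*k*cin)).
col_bytes(h, w, k, cin) = h * w * k * k * cin * sizeof(Float32)

batch = 16
h1 = conv_out(100, 5)      # after cv1
hp = h1 ÷ 4                # after MeanPool((4,4))
h2 = conv_out(hp, 5)       # after cv2

y_bytes    = out_bytes(h1, h1, 20, batch) + out_bytes(h2, h2, 3, batch)
col_total  = col_bytes(h1, h1, 5, 12) + col_bytes(h2, h2, 5, 20)
pool_bytes = hp * hp * 20 * batch * sizeof(Float32)

println((y_bytes, col_total, pool_bytes))
```

These come out to 11873280, 11859200 and 737280 bytes, within a few hundred bytes of the 11873472, 11859248 and 737304 reported above for the conv output, conv_im2col! and meanpool. If that workspace is allocated once per worker thread (a guess I have not checked against the NNlib source), it would also explain why machines running more threads report larger totals; comparing `Threads.nthreads()` across the setups would be a quick way to test that.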

For a CPU-heavy workload I would go with SimpleChains, either directly or as a backend for Lux. Also remember to interpolate Dₙ when you benchmark.

In permutedims use a tuple for the dimensions.
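A small check, using only Base, that the tuple form gives the same result as the vector form. The array size is an assumption chosen to mirror the 1×20×3×16 activation that should reach the first permutedims in the model.

```julia
x = rand(Float32, 1, 20, 3, 16)

a = permutedims(x, [2, 1, 3, 4])   # permutation given as a heap-allocated Vector
b = permutedims(x, (2, 1, 3, 4))   # same permutation given as a Tuple

println(a == b)      # the two forms produce identical arrays
println(size(b))     # (20, 1, 3, 16)
```

The tuple only avoids the small allocations for the permutation vector and its validity check (the vect, isperm and BitArray entries in the profile); the permuted copy itself is still allocated either way.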

You could also use Base.PermutedDimsArray, but I’m not sure it helps.
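For illustration, a minimal Base-only sketch of the difference: PermutedDimsArray wraps the original array lazily instead of copying it. The array size is again an assumption chosen to resemble the activation in the model.

```julia
x = rand(Float32, 1, 20, 3, 16)

v = PermutedDimsArray(x, (2, 1, 3, 4))       # lazy wrapper, shares memory with x
println(size(v))                             # (20, 1, 3, 16)
println(v == permutedims(x, (2, 1, 3, 4)))   # same values as the eager copy

# Writing to the parent is visible through the wrapper, so no copy was made.
x[1, 5, 2, 7] = 42f0
println(v[5, 1, 2, 7])
```

Whether this helps in the model is a separate question: the following Dense layer reshapes its input and multiplies it by a matrix, and a wrapped, non-contiguous array may push that onto slower generic code paths, which could cancel out the saved copy.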

If you want to end up on the GPU, go directly with Reactant; you won’t be able to beat that in plain Julia without a lot of effort.

Benchmarks I get:
initial code:

BenchmarkTools.Trial: 46 samples with 1 evaluation per sample.
 Range (min … max):   36.676 ms … 237.198 ms  ┊ GC (min … max):  0.00% … 66.67%
 Time  (median):     114.694 ms               ┊ GC (median):    57.06%
 Time  (mean ± σ):   109.263 ms ±  54.403 ms  ┊ GC (mean ± σ):  47.93% ± 29.71%

  ▁ ▁  █                    ▁▁                                   
  █▇█▁▄█▁▁▁▄▄▁▁▄▁▁▄▁▁▄▄▁▄▇▁▄██▁▁▁▁▄▇▁▇▄▄▁▄▁▁▇▄▇▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▄ ▁
  36.7 ms          Histogram: frequency by time          237 ms <

 Memory estimate: 238.26 MiB, allocs estimate: 253.

interpolating Dₙ:

BenchmarkTools.Trial: 45 samples with 1 evaluation per sample.
 Range (min … max):   38.248 ms … 264.310 ms  ┊ GC (min … max):  0.00% … 57.08%
 Time  (median):     112.183 ms               ┊ GC (median):    55.26%
 Time  (mean ± σ):   113.435 ms ±  57.543 ms  ┊ GC (mean ± σ):  51.51% ± 27.26%

  ▄█▁▄▁           ▄  █▁ ▄           ▄▁▁                          
  █████▁▁▆▁▁▁▁▁▁▁▁█▆▁██▆█▆▁▁▁▁▆▁▁▁▆▁███▆▆▆▁▁▁▆▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  38.2 ms          Histogram: frequency by time          264 ms <

 Memory estimate: 238.26 MiB, allocs estimate: 253.

using a Tuple in permutedims:

BenchmarkTools.Trial: 48 samples with 1 evaluation per sample.
 Range (min … max):   38.291 ms … 225.602 ms  ┊ GC (min … max):  0.00% … 52.38%
 Time  (median):      98.350 ms               ┊ GC (median):    32.88%
 Time  (mean ± σ):   104.278 ms ±  60.259 ms  ┊ GC (mean ± σ):  44.97% ± 30.25%

  █▁▁▁▁                                                          
  █████▄▁▇▄▁▄▁▁▁▁▁▄▁▁▁▁▁▄▄▄▁▇▁▁▇▄▄▄▄▁▄▁▁▄▁▁▄▁▄▁▇▁▁▁▄▇▁▄▄▁▁▄▁▁▁▄ ▁
  38.3 ms          Histogram: frequency by time          226 ms <

 Memory estimate: 238.26 MiB, allocs estimate: 237.

using Base.PermutedDimsArray:

BenchmarkTools.Trial: 43 samples with 1 evaluation per sample.
 Range (min … max):   40.265 ms … 244.956 ms  ┊ GC (min … max):  0.00% … 53.89%
 Time  (median):     118.640 ms               ┊ GC (median):    58.99%
 Time  (mean ± σ):   117.793 ms ±  63.335 ms  ┊ GC (mean ± σ):  51.61% ± 30.31%

  █  ▂                           ▂                               
  █▄▄█▁▄▄▄▄▁▁▁▁▁▁▁▁▄▁▄▁▄▄▄▄▄▁▁▁▄▁█▁▁▁▄▄▁▁▄█▁▁▄▆▁▁▄▁▁▁▄▁▄▄▁▁▁▁▁▄ ▁
  40.3 ms          Histogram: frequency by time          245 ms <

 Memory estimate: 238.26 MiB, allocs estimate: 230.

Reactant GPU
Code:

using Flux, BenchmarkTools, StatsBase, Reactant

function init_cvn()
	cv1=Conv((5,5), 12=>20, leakyrelu; bias=false)
	cv2=Conv((5,5), 20=>3, leakyrelu; bias=false)
	mp1=MeanPool((4,4)) 

	cval = Chain(
			Dense(20=>1, leakyrelu),
			x->permutedims(x,[2,1,3,4]),
			Dense(20=>1, leakyrelu),
			x->permutedims(x,[3,1,2,4]),
			Dense(3=>10, tanh),
	)
	
	Chain(cv1, mp1 ,cv2, cval)
end

function init_Dₓ()
	Chain(Dense(10=>10, leakyrelu), Dense(10=>1, sigmoid))
end

Dₙ=Chain(init_cvn(), init_Dₓ()) |> Reactant.to_rarray

x=rand(Float32, 100,100,12,16) |> Reactant.to_rarray
Dₙ_compile = @compile sync=true Dₙ(x)
@benchmark $Dₙ_compile($x)

Result:

BenchmarkTools.Trial: 3573 samples with 1 evaluation per sample.
 Range (min … max):  963.073 μs …  11.249 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):       1.165 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.396 ms ± 665.036 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅██▆▅▆▅▄▃▄▃▃▂▃▁▂▁ ▁                                           ▂
  ███████████████████████▇█▆█▇▇▇▇▅▆▆▆▄▃▄▅▅▅▆▆▆▆▄▄▁▄▄▃▄▁▃▅▄▁▃▄▅▅ █
  963 μs        Histogram: log(frequency) by time       4.69 ms <

 Memory estimate: 496 bytes, allocs estimate: 14.

Reactant CPU:

BenchmarkTools.Trial: 669 samples with 1 evaluation per sample.
 Range (min … max):  4.865 ms … 11.781 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     7.342 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.463 ms ±  1.013 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

               ▂ ▁▄▃▆▃▆▇▄▆█▅▄▄▃▄▂▃▂ ▁                         
  ▃▂▁▂▃▁▃▃▃▄▅▆▆█▆██████████████████▇█▄▅▄▆▅▅▄▅▄▅▃▃▄▃▁▃▁▁▁▂▁▁▂ ▄
  4.86 ms        Histogram: frequency by time        10.8 ms <

 Memory estimate: 496 bytes, allocs estimate: 14.

Reactant looks promising, I’ll give it a try. From what I can understand at first glance, the performance improvement comes more from the compiler than from memory management per se. Hopefully that means the model can still be trained in Flux?

Flux training using Reactant doesn’t seem to be working for me. I get a “llvmcall requires the compiler” error after a bunch of Zygote pullbacks.

I tried looking to see if there was anything available for Flux and Reactant, but all the material seems to be for Lux. Any idea if Reactant can be made to work with Flux.train! or do I need to switch to something like Lux?

Edit: Lux seemed to need the loss function AD’ed and compiled as well. I’ve only compiled the model and converted the data into a Reactant array, so might I need to do additional work on the loss function?

You should generally use Enzyme for autodiff within Reactant, not Zygote.

It is compatible with Flux and pretty much 80% of Julia array workloads too.

For Flux.train! I’m not sure, but you can use Optimisers by itself and calculate the gradient with Enzyme; it should just work.
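To illustrate the Optimisers plus Enzyme combination, here is a minimal sketch on a toy linear model. The `loss` and `train` names, the plain weight matrix, and the step size are all assumptions made for this example; it does not use Flux layers or Reactant, and a Reactant version would additionally need the step compiled.

```julia
using Enzyme, Optimisers

# Toy squared-error loss for a linear map; a stand-in for a real network.
loss(w, x, y) = sum(abs2, w * x .- y)

# Hypothetical training helper: plain gradient descent driven by Optimisers.
function train(w, x, y; steps=20)
    opt_state = Optimisers.setup(Optimisers.Descent(0.01f0), w)
    for _ in 1:steps
        dw = zero(w)  # gradient buffer, must start at zero
        # Reverse-mode AD: the gradient w.r.t. `w` is accumulated into `dw`.
        Enzyme.autodiff(Reverse, loss, Active, Duplicated(w, dw), Const(x), Const(y))
        # `update!` returns the new optimiser state and the updated parameters.
        opt_state, w = Optimisers.update!(opt_state, w, dw)
    end
    return w
end

w0 = rand(Float32, 1, 3)   # trainable parameters
x = rand(Float32, 3, 8)    # inputs
y = rand(Float32, 1, 8)    # targets

loss_before = loss(w0, x, y)
w1 = train(copy(w0), x, y)
loss_after = loss(w1, x, y)
```

The same structure (compute a gradient, then call `Optimisers.update!`) is what replaces Flux.train! in a hand-written loop.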

When calling any function on a Reactant array, you need to compile it first, or use @jit if you don’t need it to be fast or if it’s a one-time calculation. The exception is Lux.Training, which wraps this in syntactic sugar.
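A minimal sketch of the two call styles, assuming a trivial hypothetical function `double` in place of a real model:

```julia
using Reactant

# Hypothetical function used only for illustration.
double(x) = x .* 2f0

x_host = rand(Float32, 4, 4)
x = Reactant.to_rarray(x_host)   # move the data into a Reactant array

# One-off call: compiles and runs in a single step.
y_once = @jit double(x)

# Repeated calls: compile once, then reuse the compiled function.
double_compiled = @compile double(x)
y_reused = double_compiled(x)
```

In a training loop the second form is the one to use, since the compilation cost is paid only once.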

I’ve been experimenting with Reactant and I’ve found an unexpected behavior that I’m not sure how to deal with. The GAN training involves several terms where a generator feeds its output directly into a discriminator, so D(G(z)), where z is some random value. When I @benchmark D or G using Reactant, I get about 550B over 15 allocations for either. However, when I @benchmark the composite function compiled using Reactant I get 21.5KiB over 205 allocations. This is still orders of magnitude less allocated memory than the uncompiled version, but “only” about 2/3 of the number of allocations.

Any ideas why the composite function would make so many more allocations than the two functions individually? Happy to post my code if this is a non-obvious behavior I’ve stumbled onto.

You need to use Reactant.@timed or Reactant.@profile instead, and handle the repeated runs by hand for now.
See Profiling | Reactant.jl for everything related to profiling.
Also be careful about which device Reactant is running on. And the number of allocations is far less important than their size.

Yes, the Reactant profiling did show that some of the issues were overestimates on the part of @benchmark. I’ve got what seems to be the whole training loop going, but it’s still taking me ~700s to do a single epoch. This is about 55% of the vanilla Flux CPU training time, but is still longer than I was hoping (perhaps unreasonably) that it would be, especially since the compiled functions seem to be allocating about 2,000x less memory than the vanilla version.

So first, is this the level of improvement you’d expect? Second, I’ve noticed that the single largest loss of speed seems to be the Enzyme gradient step. The loss function itself takes about 50ms to evaluate but the gradient calculation takes about 1-1.2 seconds, which seems like a big hit.

I also can’t use the GPU backend because I’m on Windows. @wsmoses I’d like to submit my vote for Windows support too :wink:

On CPU, for a single device, Reactant won’t help much besides better caching and the gradient. I would go with WSL for now, or use Colab and go full TPU, but that may not be what you want.

Fair enough, I’ll give WSL a try. When you say Reactant helps with the gradient, is that the ~2x speed up I’m observing? Just want to make sure I’m seeing the intended benefit and haven’t missed something somewhere.