Implementing a ConvNet that doesn't allocate during inference (SimpleChains.jl?)

Hi all,
A student and I are trying to implement a simple, relatively small convolutional neural network. The trick is, we need it to run as part of our existing real-time controller (implemented in Julia; see Case study: Real time hardware control for adaptive optics with Julia).
To make this work, we pretty much need to avoid all allocations in the inference loop.

We tried following an MNIST tutorial for Flux.jl, but immediately hit allocations when evaluating the model.
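
For reference, this is roughly the kind of check we ran (a sketch rather than our exact code; the Flux layer sizes are just illustrative, using our 20x20 input):

using Flux, BenchmarkTools

flux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(384 => 10),   # 6 channels * 8 * 8 after the conv/pool layers
)
x = rand(Float32, 20, 20, 1, 1)
@ballocated $flux_model($x)   # non-zero for us, i.e. the forward pass allocates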

The MNIST example for SimpleChains.jl, on the other hand, doesn't have any allocations, so it looks like the way to go (unless Lux.jl would be a better choice?).
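
This is the kind of allocation check we are using (a sketch; my_chain and parameters are as in the MWE below, and the input here is a single random image):

x1 = rand(Float32, 20, 20, 1, 1)
my_chain(x1, parameters)              # warm up; the first call may allocate internal buffers
@allocated my_chain(x1, parameters)   # was 0 for the analogous check on the MNIST example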

But now we've hit a snag. We've coded up this network, but we keep hitting errors in what look to be fairly deep internals involving StrideArrays and the like. See below for an MWE.

Could we get some guidance on the best choice of package here? And if it is SimpleChains, some suggestions on how to implement it?

Thanks very much!

MWE

using SimpleChains
using Statistics

# Synthetic stand-ins for our data: 100,000 images of 20x20 pixels and 14 target amplitudes each
llowfs_images4 = rand(Float64, (20, 20, 1, 100000))
amplitudes = rand(Float64, (100000, 14))


my_chain = SimpleChain(
    (static(20), static(20), static(1)), 
    SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
    SimpleChains.MaxPool(2, 2),
    SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
    SimpleChains.MaxPool(2, 2),
    Flatten(3),
    TurboDense(SimpleChains.relu, 120),
    TurboDense(SimpleChains.relu, 84),
    TurboDense(identity, 10),
  )

my_batch_size = 32

parameters = SimpleChains.init_params(my_chain);
grad_loss_of_params = SimpleChains.alloc_threaded_grad(my_chain);

model_loss = SimpleChains.add_loss(my_chain, SquaredLoss(amplitudes))

valgrad!(grad_loss_of_params, model_loss, llowfs_images4, parameters)

Gives an error message:

ArgumentError: Must unroll in vectorized direction for `Bit` loads with W < 8.
Stacktrace:
  [1] bitload(AU::Int64, W::Int64, AV::Int64, F::Int64, UN::Int64, RS::Int64, mask::Bool)
    @ VectorizationBase C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:511
  [2] #s43#329
    @ C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:560 [inlined]
  [3] var"#s43#329"(T::Any, N::Any, C::Any, B::Any, AU::Any, F::Any, UN::Any, AV::Any, W::Any, M::Any, UX::Any, I::Any, A::Any, RS::Any, X::Any, ::Any, sptr::Any, u::Any, ::Any, ::Any, ::Any)
    @ VectorizationBase .\none:0
  [4] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core .\boot.jl:707
  [5] _vload
    @ C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:771 [inlined]
  [6] macro expansion
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\reconstruct_loopset.jl:1107 [inlined]
  [7] _turbo_!
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\reconstruct_loopset.jl:1107 [inlined]
  [8] macro expansion
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\condense_loopset.jl:1179 [inlined]
  [9] update_C̄!(::typeof(relu), C̄::StrideArraysCore.PtrArray{Float32, 4, (1, 2, 3, 4), Tuple{Static.StaticInt{4}, Static.StaticInt{4}, Static.StaticInt{16}, Int64}, NTuple{4, Nothing}, NTuple{4, Static.StaticInt{1}}}, ∂C::StrideArraysCore.AbstractPtrArray{Bool, 4, (1, 2, 3, 4), Tuple{Static.StaticInt{4}, Static.StaticInt{4}, Static.StaticInt{16}, Int64}, NTuple{4, Nothing}, NTuple{4, Static.StaticInt{1}}, SIMDTypes.Bit})
    @ SimpleChains C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\dense.jl:960
 [10] pullback_common!
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\conv.jl:1033 [inlined]
 [11] pullback!
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:657 [inlined]
...
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\memory.jl:50 [inlined]
 [18] valgrad!(g::StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}}, c::SimpleChain{Tuple{Static.StaticInt{20}, Static.StaticInt{20}, Static.StaticInt{1}}, Tuple{Conv{typeof(relu), Tuple{Static.StaticInt{5}, Static.StaticInt{5}}, Static.StaticInt{6}}, MaxPool{(2, 2)}, Conv{typeof(relu), Tuple{Static.StaticInt{5}, Static.StaticInt{5}}, Static.StaticInt{16}}, MaxPool{(2, 2)}, Flatten{3}, TurboDense{true, Static.StaticInt{120}, typeof(relu)}, TurboDense{true, Static.StaticInt{84}, typeof(relu)}, TurboDense{true, Static.StaticInt{10}, typeof(identity)}, SquaredLoss{Matrix{Float64}}}}, arg::Array{Float64, 4}, params::StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}})
    @ SimpleChains C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:571
 [19] top-level scope
    @ c:\Users\User\Desktop\ml_env\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:6

Make sure it fits in your RAM: arrays are created for all the weights and biases (twice, to account for the gradient, or even more with threading). Also try it in serial; that may get rid of the threading and give a better error message.
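
Something like this, for instance (an untested sketch using the names from the MWE above; similar(parameters) gives a plain gradient vector as in the SimpleChains README, instead of the threaded buffer):

grad_serial = similar(parameters)                 # plain, non-threaded gradient buffer
valgrad!(grad_serial, model_loss, llowfs_images4, parameters)

# Rough check that the parameters plus per-thread gradients fit in RAM:
param_bytes = length(parameters) * sizeof(eltype(parameters))
@show param_bytes * (1 + Threads.nthreads())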

It’s a LoopVectorization bug.
That error might magically go away with Float32 instead of Float64 if you have AVX*, or if you have AVX512 (either of which would result in W >= 8, which will avoid the W < 8 bug).

I don’t maintain LoopVectorization anymore, so it’s unlikely bugs there will get fixed.

*Which I assume you have because you're on Windows, but you could still be on ARM, in which case that assumption doesn't hold.
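
Concretely, the Float32 workaround would look something like this (untested, reusing the MWE names):

llowfs_images4_f32 = Float32.(llowfs_images4)
amplitudes_f32 = Float32.(amplitudes)

model_loss_f32 = SimpleChains.add_loss(my_chain, SquaredLoss(amplitudes_f32))
valgrad!(grad_loss_of_params, model_loss_f32, llowfs_images4_f32, parameters)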


Thanks @yolhan_mannes and @Elrod for the info. Given that LoopVectorization is no longer maintained, is there a current recommendation for this type of work?

Is it possible to use Lux without allocations during inference? Or is there some way these days to deploy via XLA/Reactant without allocations?
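
For the Lux question, I suppose we could at least measure it with something like this (a sketch only; the layer sizes are illustrative and we haven't verified what it reports):

using Lux, Random, BenchmarkTools

lux_model = Chain(Dense(400 => 120, relu), Dense(120 => 14))
ps, st = Lux.setup(Random.default_rng(), lux_model)
x = rand(Float32, 400, 1)
@ballocated first($lux_model($x, $ps, $st))   # first(...) extracts the output from (y, st)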
