Hi all,
A student and I are trying to implement a simple, relatively small convolutional neural network. The catch is that it has to run as part of our existing real-time controller (implemented in Julia; see Case study: Real time hardware control for adaptive optics with Julia).
To make this work, we pretty much need to avoid all allocations in the inference loop.
We tried following an MNIST tutorial for Flux.jl but immediately hit allocations when evaluating the model.
The MNIST example for SimpleChains.jl, on the other hand, doesn't allocate at all, so it looks like the way to go, unless Lux.jl would be a better choice?
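One way to check whether a single inference step allocates is Base's `@allocated` macro, called inside a function (so the arguments are concretely typed) and after a warm-up call. The sketch below uses a hypothetical hand-written layer, `dense_relu!` (not part of Flux, Lux, or SimpleChains), that writes into preallocated buffers with `LinearAlgebra.mul!` and an in-place fused broadcast; a truly allocation-free step should report 0 bytes.

```julia
using LinearAlgebra

# Hypothetical stand-in for one dense + ReLU layer. `y` is overwritten in
# place, so no new arrays are created on each call.
function dense_relu!(y::AbstractVector{T}, W::AbstractMatrix{T},
                     b::AbstractVector{T}, x::AbstractVector{T}) where {T}
    mul!(y, W, x)              # y = W * x, written into y
    z = zero(T)
    @. y = max(y + b, z)       # fused in-place broadcast: add bias, then ReLU
    return y
end

# Measure inside a function so the call is type-stable; the warm-up call
# keeps any one-time first-call work out of the measurement.
function allocs_per_call(y, W, b, x)
    dense_relu!(y, W, b, x)                   # warm-up
    return @allocated dense_relu!(y, W, b, x) # bytes allocated by one call
end

W = rand(Float32, 84, 120)
b = rand(Float32, 84)
x = rand(Float32, 120)
y = similar(b)                 # preallocated output buffer

println(allocs_per_call(y, W, b, x))
```

The same pattern (wrap the call in a function, warm up, then measure) can be applied to a SimpleChains forward pass of the form `chain(x, params)`.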
But now we've hit a snag. We've coded up this network but keep hitting errors in what look like fairly deep internals (StrideArrays and the like). See below for an MWE.
Could we get some guidance on the best choice of package here? And if it is SimpleChains, some suggestions on how to implement it?
Thanks very much!
MWE
```julia
using SimpleChains
using Statistics

# Inputs: 20×20 images, 1 channel, 100_000 samples (batch on the last axis)
llowfs_images4 = rand(Float64, (20, 20, 1, 100000))
# Targets: 14 amplitudes per sample, laid out as (outputs, samples) to match
# the batch-last convention of the inputs
amplitudes = rand(Float64, (14, 100000))

my_chain = SimpleChain(
    (static(20), static(20), static(1)),
    SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
    SimpleChains.MaxPool(2, 2),
    SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
    SimpleChains.MaxPool(2, 2),
    Flatten(3),
    TurboDense(SimpleChains.relu, 120),
    TurboDense(SimpleChains.relu, 84),
    TurboDense(identity, 14),  # one output per amplitude
)

my_batch_size = 32
parameters = SimpleChains.init_params(my_chain);
grad_loss_of_params = SimpleChains.alloc_threaded_grad(my_chain);
model_loss = SimpleChains.add_loss(my_chain, SquaredLoss(amplitudes))
valgrad!(grad_loss_of_params, model_loss, llowfs_images4, parameters)
```
Running this gives the following error:
```
ArgumentError: Must unroll in vectorized direction for `Bit` loads with W < 8.
Stacktrace:
  [1] bitload(AU::Int64, W::Int64, AV::Int64, F::Int64, UN::Int64, RS::Int64, mask::Bool)
    @ VectorizationBase C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:511
  [2] #s43#329
    @ C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:560 [inlined]
  [3] var"#s43#329"(T::Any, N::Any, C::Any, B::Any, AU::Any, F::Any, UN::Any, AV::Any, W::Any, M::Any, UX::Any, I::Any, A::Any, RS::Any, X::Any, ::Any, sptr::Any, u::Any, ::Any, ::Any, ::Any)
    @ VectorizationBase .\none:0
  [4] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core .\boot.jl:707
  [5] _vload
    @ C:\Users\User\.julia\packages\VectorizationBase\wHnQd\src\vecunroll\memory.jl:771 [inlined]
  [6] macro expansion
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\reconstruct_loopset.jl:1107 [inlined]
  [7] _turbo_!
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\reconstruct_loopset.jl:1107 [inlined]
  [8] macro expansion
    @ C:\Users\User\.julia\packages\LoopVectorization\tIJUA\src\condense_loopset.jl:1179 [inlined]
  [9] update_C̄!(::typeof(relu), C̄::StrideArraysCore.PtrArray{Float32, 4, (1, 2, 3, 4), Tuple{Static.StaticInt{4}, Static.StaticInt{4}, Static.StaticInt{16}, Int64}, NTuple{4, Nothing}, NTuple{4, Static.StaticInt{1}}}, ∂C::StrideArraysCore.AbstractPtrArray{Bool, 4, (1, 2, 3, 4), Tuple{Static.StaticInt{4}, Static.StaticInt{4}, Static.StaticInt{16}, Int64}, NTuple{4, Nothing}, NTuple{4, Static.StaticInt{1}}, SIMDTypes.Bit})
    @ SimpleChains C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\dense.jl:960
 [10] pullback_common!
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\conv.jl:1033 [inlined]
 [11] pullback!
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:657 [inlined]
...
    @ C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\memory.jl:50 [inlined]
 [18] valgrad!(g::StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}}, c::SimpleChain{Tuple{Static.StaticInt{20}, Static.StaticInt{20}, Static.StaticInt{1}}, Tuple{Conv{typeof(relu), Tuple{Static.StaticInt{5}, Static.StaticInt{5}}, Static.StaticInt{6}}, MaxPool{(2, 2)}, Conv{typeof(relu), Tuple{Static.StaticInt{5}, Static.StaticInt{5}}, Static.StaticInt{16}}, MaxPool{(2, 2)}, Flatten{3}, TurboDense{true, Static.StaticInt{120}, typeof(relu)}, TurboDense{true, Static.StaticInt{84}, typeof(relu)}, TurboDense{true, Static.StaticInt{10}, typeof(identity)}, SquaredLoss{Matrix{Float64}}}}, arg::Array{Float64, 4}, params::StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}})
    @ SimpleChains C:\Users\User\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:571
 [19] top-level scope
    @ c:\Users\User\Desktop\ml_env\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:6
```
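Working out the activation sizes by hand helps place the failure. Assuming SimpleChains' `Conv` is unpadded with stride 1 and `MaxPool(2, 2)` halves each spatial side (our assumption, though it agrees with the array types in the trace), the helpers below (`conv_out`, `pool_out`, `layer_sizes` are illustrative names only) give 16×16×6 after the first convolution and 4×4×16 after the second. The latter matches the `StaticInt{4}, StaticInt{4}, StaticInt{16}` arrays in frame [9], so the failing `update_C̄!` call appears to be the ReLU pullback of the second `Conv` layer (called from `conv.jl` in frame [10]), operating on a mask stored as `SIMDTypes.Bit`.

```julia
# Assumption: unpadded ("valid"), stride-1 convolution, as we believe
# SimpleChains.Conv does. Output side length for input side n, kernel k:
conv_out(n, k) = n - k + 1
# Non-overlapping pooling with window p (floor division):
pool_out(n, p) = n ÷ p

# Per-sample activation sizes (width, height, channels) for the chain in the MWE
function layer_sizes(n)
    c1 = conv_out(n, 5)    # Conv((5, 5), 6)
    p1 = pool_out(c1, 2)   # MaxPool(2, 2)
    c2 = conv_out(p1, 5)   # Conv((5, 5), 16)
    p2 = pool_out(c2, 2)   # MaxPool(2, 2)
    return (conv1 = (c1, c1, 6), pool1 = (p1, p1, 6),
            conv2 = (c2, c2, 16), pool2 = (p2, p2, 16),
            flatten = p2 * p2 * 16)   # length fed into the first TurboDense
end

sizes = layer_sizes(20)
println(sizes)
```

Our unconfirmed guess is that the small leading dimension (4) of that bit-packed ReLU mask is related to the `W < 8` check in the error message.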