Lux.jl: Chain of Dense layers sees no speedup going from 32-bit to 16-bit floats

using Lux, BenchmarkTools, Random
const rng  = Random.default_rng()
julia> let model = Chain(
                  Dense(23, 30, Lux.relu),
                  Dense(30, 25, Lux.relu),
                  Dense(25, 20, Lux.relu)
              )
              ps, st = Lux.setup(rng, model)
              @btime $model(inputs, $ps, $st)[1] setup=(inputs=rand(Float32, 23))
           end;
  364.164 ns (12 allocations: 1.17 KiB)

julia> let model = Chain(
                  Dense(23, 30, Lux.relu),
                  Dense(30, 25, Lux.relu),
                  Dense(25, 20, Lux.relu)
              )
              ps, st = Lux.setup(rng, model)
              @btime $model(inputs, $ps, $st)[1] setup=(inputs=rand(Float16, 23))
           end;
  400.045 ns (13 allocations: 1.33 KiB)

For reference, Float64 input is about 2x slower: 772.561 ns (12 allocations: 1.77 KiB).

(Julia 1.8-rc3)
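One thing worth checking in the benchmark above: as far as I know, Lux.setup initializes Dense parameters in Float32 by default, so passing a Float16 input most likely does not make the forward pass run in Float16. A Float32 weight times a Float16 vector promotes back to Float32, and the extra conversion may explain the extra allocation. The sketch below uses only Base Julia. The NamedTuple ps is hypothetical and only roughly mirrors the layout Lux.setup returns for this Chain (layer_1 to layer_3, each with weight and bias). The helper to_f16 is also hypothetical, not a Lux API.

```julia
# Hypothetical parameters, roughly mirroring Lux.setup's output for
# Chain(Dense(23, 30), Dense(30, 25), Dense(25, 20)).
ps = (
    layer_1 = (weight = rand(Float32, 30, 23), bias = rand(Float32, 30)),
    layer_2 = (weight = rand(Float32, 25, 30), bias = rand(Float32, 25)),
    layer_3 = (weight = rand(Float32, 20, 25), bias = rand(Float32, 20)),
)

x16 = rand(Float16, 23)

# Mixed precision: Float32 weight * Float16 input promotes to Float32,
# so this is still Float32 arithmetic plus a conversion of the input.
y = ps.layer_1.weight * x16
println(eltype(y))  # Float32

# Hypothetical helper: recursively convert every float array in a nested NamedTuple.
to_f16(x::AbstractArray{<:AbstractFloat}) = Float16.(x)
to_f16(x::NamedTuple) = map(to_f16, x)  # map over a NamedTuple keeps the field names
to_f16(x) = x                           # leave anything else unchanged

ps16 = to_f16(ps)
y16 = ps16.layer_1.weight * x16
println(eltype(y16))  # Float16
```

With Lux, benchmarking model(x16, to_f16(ps), st) would then measure an actual Float16 forward pass. There is no guarantee it is faster: Float16 matrix-vector products are not handled by BLAS and fall back to Julia's generic implementation.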

Julia currently doesn't use native Float16 operations, even when the hardware has them. Float16 arithmetic is done by converting to Float32 and back, which adds overhead. If you were memory-bound you might see some improvement, but I doubt it. In addition, I believe very few x86 chips have Float16 arithmetic, so your mileage may vary. See "Hardware Float16 on A64fx", Issue #40216 · JuliaLang/julia · GitHub.
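A rough way to see this on your own machine is to inspect the LLVM IR for a single Float16 addition. The sketch below is an assumption-laden check, not a definitive test: add16 is a hypothetical helper, and the exact IR depends on the Julia version and CPU. If the add is emulated, it is typically wrapped in fpext (half to float) and fptrunc (float to half) instructions. A build that keeps native half arithmetic would instead show an fadd on half values.

```julia
using InteractiveUtils  # provides code_llvm; loaded automatically in the REPL

# Hypothetical helper: a single Float16 addition to inspect.
add16(a, b) = a + b

# Capture the optimized LLVM IR for add16(::Float16, ::Float16) as a string.
ir = sprint(code_llvm, add16, Tuple{Float16, Float16})
print(ir)

# Emulated Float16: operands widened with fpext, result narrowed with fptrunc.
emulated = occursin("fpext", ir) && occursin("fptrunc", ir)
println(emulated ? "Float16 add is widened to Float32 in the IR" :
                   "no fpext/fptrunc around the add in this IR")
```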


That's a bit unfortunate. If we could get another ~2x from here, we could compete with FPGAs for some real-time ML applications…

It's still possible; we just need to find a CPU with better single-core performance.

If your CPU has native Float16 operations, you could build Julia from source with PM->add(createDemoteFloat16Pass()); commented out in aotcompile.cpp to check whether that gives an improvement. If it does, it might be as much as 2x.

If those network sizes are representative of what you want to do, then SimpleChains.jl (PumasAI/SimpleChains.jl on GitHub) should give you a significant speedup.

Jesus, that's fast:

julia> using SimpleChains  # needed for SimpleChain, TurboDense and static

julia> let model = SimpleChain(
                static(23), # input dimension (optional)
                TurboDense{true}(relu, 30), # dense layer with bias: 30 outputs, relu activation
                TurboDense{true}(relu, 25),
                TurboDense{true}(relu, 20)
              )
              p = SimpleChains.init_params(model);  # was init_params(a); `a` is undefined
              @btime $model(inputs, $p) setup=(inputs=rand(Float32, 23))
           end;
  77.845 ns (0 allocations: 0 bytes)