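I ran a few tests, which confirm that `BLAS.set_num_threads(1)` should be called when training several models in parallel with `Threads.@threads`.
On my system, the threaded code ran ~10 times faster with that setting (1.076 s vs. 10.864 s). The code and results for each configuration are below.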
# Baseline: train four small models one after another on a single Julia thread.
# Start Julia with: JULIA_NUM_THREADS=1 julia
using Flux
using BenchmarkTools
using LinearAlgebra

n = 100_000  # number of observations
p = 50       # number of input features
x = rand(Float32, p, n)
# Targets are stored as a 1×n matrix so that each minibatch of targets has the
# same 1×batchsize shape as the output of the final `Dense(100, 1)` layer.
# With a length-n vector, `ŷ .- y` inside `Flux.mse` pairs a 1×100 row with a
# 100-element vector; depending on the Flux version that either silently
# broadcasts to a 100×100 matrix (wrong loss) or is rejected as a size mismatch.
y = rand(Float32, 1, n)
trdata = Flux.Data.DataLoader(x, y, batchsize=100)
m = [Chain(Dense(p, 100), Dense(100, 100), Dense(100, 1)) for i in 1:4]

# `$` interpolation hands the globals to BenchmarkTools as concrete values, so
# the measurement does not include untyped-global lookups.
# NOTE: the models are updated in place, so every benchmark sample continues
# training from the state left by the previous sample.
@btime for i in 1:4
    model = $m[i]
    loss(x, y) = Flux.mse(model(x), y)
    Flux.@epochs 1 Flux.train!(loss, Flux.params(model), $trdata, Flux.ADAM())
end
# 6.286 s (1992500 allocations: 2.24 GiB)
# ^ timing recorded with the original script (vector-shaped `y`, no `$`
#   interpolation); re-run to refresh the number for this revision.
# Threaded run: train the four models concurrently, one per Julia thread, while
# leaving the BLAS thread count at its default.
# Start Julia with: JULIA_NUM_THREADS=4 julia
using Flux
using BenchmarkTools
using LinearAlgebra

n = 100_000  # number of observations
p = 50       # number of input features
x = rand(Float32, p, n)
# Targets are stored as a 1×n matrix so that each minibatch of targets has the
# same 1×batchsize shape as the output of the final `Dense(100, 1)` layer.
# With a length-n vector, `ŷ .- y` inside `Flux.mse` pairs a 1×100 row with a
# 100-element vector; depending on the Flux version that either silently
# broadcasts to a 100×100 matrix (wrong loss) or is rejected as a size mismatch.
y = rand(Float32, 1, n)
trdata = Flux.Data.DataLoader(x, y, batchsize=100)
m = [Chain(Dense(p, 100), Dense(100, 100), Dense(100, 1)) for i in 1:4]

# Each iteration touches only its own model `m[i]`; `trdata` is shared by all
# iterations and is only iterated over.
# `$` interpolation hands the globals to BenchmarkTools as concrete values.
# NOTE: the models are updated in place, so every benchmark sample continues
# training from the state left by the previous sample.
@btime Threads.@threads for i in 1:4
    model = $m[i]
    loss(x, y) = Flux.mse(model(x), y)
    Flux.@epochs 1 Flux.train!(loss, Flux.params(model), $trdata, Flux.ADAM())
end
# 10.864 s (1992523 allocations: 2.24 GiB)
# ^ timing recorded with the original script (vector-shaped `y`, no `$`
#   interpolation); re-run to refresh the number for this revision.
# Threaded run with BLAS restricted to one thread: train the four models
# concurrently, one per Julia thread.
# Start Julia with: JULIA_NUM_THREADS=4 julia
using Flux
using BenchmarkTools
using LinearAlgebra

# Limit BLAS to a single thread before any training happens.
# NOTE(review): presumably this stops the four Julia threads from competing
# with BLAS's own worker threads for cores; the timing below is consistent
# with that, but the cause was not profiled here.
BLAS.set_num_threads(1)

n = 100_000  # number of observations
p = 50       # number of input features
x = rand(Float32, p, n)
# Targets are stored as a 1×n matrix so that each minibatch of targets has the
# same 1×batchsize shape as the output of the final `Dense(100, 1)` layer.
# With a length-n vector, `ŷ .- y` inside `Flux.mse` pairs a 1×100 row with a
# 100-element vector; depending on the Flux version that either silently
# broadcasts to a 100×100 matrix (wrong loss) or is rejected as a size mismatch.
y = rand(Float32, 1, n)
trdata = Flux.Data.DataLoader(x, y, batchsize=100)
m = [Chain(Dense(p, 100), Dense(100, 100), Dense(100, 1)) for i in 1:4]

# Each iteration touches only its own model `m[i]`; `trdata` is shared by all
# iterations and is only iterated over.
# `$` interpolation hands the globals to BenchmarkTools as concrete values.
# NOTE: the models are updated in place, so every benchmark sample continues
# training from the state left by the previous sample.
@btime Threads.@threads for i in 1:4
    model = $m[i]
    loss(x, y) = Flux.mse(model(x), y)
    Flux.@epochs 1 Flux.train!(loss, Flux.params(model), $trdata, Flux.ADAM())
end
# 1.076 s (1992515 allocations: 2.24 GiB)
# ^ timing recorded with the original script (vector-shaped `y`, no `$`
#   interpolation); re-run to refresh the number for this revision.