I set the number of threads that PyTorch uses to 1:
using PyCall
using Flux
using BenchmarkTools

torch = pyimport("torch")
torch.set_num_threads(1)

# PyTorch model: 8 -> 64 -> 32 -> 2, tanh activations throughout
NN = torch.nn.Sequential(
    torch.nn.Linear(8, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 32),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 2),
    torch.nn.Tanh()
)
# thin wrapper so @btime benchmarks a plain Julia call
torch_nn(in) = NN(in)

# equivalent Flux model
Flux_nn = Chain(Dense(8, 64, tanh),
                Dense(64, 32, tanh),
                Dense(32, 2, tanh))

for i in [1, 10, 100, 1000]
    println("Batch size: $i")
    torch_in = torch.rand(i, 8)      # PyTorch is batch-first: (batch, features)
    flux_in = rand(Float32, 8, i)    # Flux is feature-first: (features, batch)
    print("pytorch :")
    @btime torch_nn($torch_in)
    print("flux :")
    @btime Flux_nn($flux_in)
end
Batch size: 1
pytorch : 88.087 μs (6 allocations: 192 bytes)
flux : 3.567 μs (6 allocations: 1.25 KiB)
Batch size: 10
pytorch : 100.711 μs (6 allocations: 192 bytes)
flux : 18.236 μs (6 allocations: 8.22 KiB)
Batch size: 100
pytorch : 140.269 μs (6 allocations: 192 bytes)
flux : 162.120 μs (8 allocations: 77.16 KiB)
Batch size: 1000
pytorch : 465.119 μs (6 allocations: 192 bytes)
flux : 4.485 ms (10 allocations: 766.19 KiB)
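One thing to keep in mind when reading the small-batch numbers: every call to torch_nn goes through PyCall, so each invocation pays a fixed Julia-to-Python round-trip cost, and that alone could account for much of PyTorch's flat ~90 μs floor at batch sizes 1 and 10. A hypothetical micro-benchmark to estimate that overhead (the py_identity name is made up for illustration):

using PyCall, BenchmarkTools

# Time a bare Julia -> Python -> Julia round trip through a
# trivial Python callable that does no tensor work at all.
py_identity = py"lambda x: x"
@btime $py_identity(0)

If that alone comes out in the tens of microseconds, the per-call FFI overhead, not PyTorch itself, dominates the batch-size-1 result.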
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.4.1 (2020-04-14)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
julia> torch.get_num_threads()
1
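On the Julia side, the analogous knob is the BLAS thread count, since Flux's Dense layers lower to an ordinary BLAS matrix multiply. Pinning it as well keeps the comparison symmetric; a minimal sketch (set_num_threads has been available in LinearAlgebra.BLAS for a long time, including 1.4):

using LinearAlgebra

# Pin Julia's default OpenBLAS to a single thread,
# mirroring torch.set_num_threads(1) above.
BLAS.set_num_threads(1)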
The results are about the same, so I don't think it's only the threading. PyTorch probably has some other tricks up its sleeve. I'll check MKL.jl later today; maybe that's it.
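For reference, one way to test the MKL hypothesis; this is a sketch assuming Julia 1.7+, where loading MKL.jl swaps the BLAS backend at runtime (on 1.4, MKL.jl instead required rebuilding the system image):

using MKL            # forward BLAS/LAPACK calls to Intel MKL
using LinearAlgebra
BLAS.get_config()    # should now list an MKL library instead of OpenBLAS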
About SMT: I think it is disabled on my laptop, but I'm not sure I checked it correctly:
dicker@dicker-X1:~$ grep -o '^flags\b.*: .*\bht\b' /proc/cpuinfo | tail -1
flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht
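Note that the ht flag in /proc/cpuinfo only says the CPU supports Hyper-Threading; on many machines it stays set even when SMT is disabled in the firmware. lscpu reports whether it is actually active (a value of 1 for "Thread(s) per core" means SMT is off):

dicker@dicker-X1:~$ lscpu | grep 'Thread(s) per core'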