Flux vs PyTorch CPU performance

@d1cker, could you perform the same benchmark on Flux master, now that NNlib's PR got merged?
For some reason, calling torch through PyCall is not working for me.

You can have the best of both worlds with Torch.jl
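
For reference, usage looks roughly like this. This is a sketch based on my reading of the Torch.jl README (so treat the exact torch/tensor calls as assumptions), and the current artifact links against CUDA, as comes up below:

using Flux
using Torch: torch, tensor

flux_nn = Chain(Dense(8, 64, tanh), Dense(64, 32, tanh), Dense(32, 2, tanh))

# Move the Flux model's parameters into libtorch tensors (dev 0 = first GPU),
# then call it on a libtorch tensor of the input.
torch_nn = flux_nn |> torch
x = tensor(rand(Float32, 8, 1000), dev = 0)
torch_nn(x)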

Batch size: 1
pytorch :  84.213 μs (6 allocations: 192 bytes)
flux    :  4.912 μs (80 allocations: 3.16 KiB)
Batch size: 10
pytorch :  94.982 μs (6 allocations: 192 bytes)
flux    :  18.803 μs (80 allocations: 10.13 KiB)
Batch size: 100
pytorch :  125.019 μs (6 allocations: 192 bytes)
flux    :  159.059 μs (82 allocations: 79.06 KiB)
Batch size: 1000
pytorch :  359.401 μs (6 allocations: 192 bytes)
flux    :  3.104 ms (84 allocations: 768.09 KiB)

It’s about the same.
Here are my packages:

(@v1.4) pkg> st
Status `~/.julia/environments/v1.4/Project.toml`
  [28f2ccd6] ApproxFun v0.11.14
  [e7028de2] AutoPreallocation v0.3.1
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [31a5f54b] Debugger v0.6.5
  [aae7a2af] DiffEqFlux v0.7.0
  [41bf760c] DiffEqSensitivity v4.4.0
  [31c24e10] Distributions v0.23.4
  [ced4e74d] DistributionsAD v0.1.1
  [587475ba] Flux v0.8.3
  [7073ff75] IJulia v1.21.2
  [2b0e0bc5] LanguageServer v3.1.0
  [bdcacae8] LoopVectorization v0.8.13
  [872c559c] NNlib v0.7.2
  [429524aa] Optim v0.20.1
  [1dea7af3] OrdinaryDiffEq v5.26.8
  [9b87118b] PackageCompiler v1.2.1
  [d96e819e] Parameters v0.12.1
  [91a5bcdd] Plots v0.28.4
  [c46f51b8] ProfileView v0.6.5
  [438e738f] PyCall v1.91.4
  [d330b81b] PyPlot v2.9.0
  [189a3867] Reexport v0.2.0
  [295af30f] Revise v2.7.3
  [90137ffa] StaticArrays v0.12.4
  [2913bbd2] StatsBase v0.32.2
  [cf896787] SymbolServer v4.5.0
  [37b6cedf] Traceur v0.3.1
  [e88e6eb3] Zygote v0.5.2
  [10745b16] Statistics

It looks like it doesn't work without CUDA (and I don't have a GPU on my laptop :cry: )

ERROR: LoadError: InitError: could not load library "/home/dicker/.julia/artifacts/d6ce2ca09ab00964151aaeae71179deb8f9800d1/lib/libdoeye_caml.so"
libcuda.so.1: cannot open shared object file: No such file or directory

I saw this issue, but I couldn't figure it out.
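
(A quick way to confirm it is only the CUDA userspace library that is missing, assuming ldd is available:)

ldd ~/.julia/artifacts/d6ce2ca09ab00964151aaeae71179deb8f9800d1/lib/libdoeye_caml.so | grep 'not found'
# should print something like: libcuda.so.1 => not found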

Mind sharing your versioninfo()? And perhaps also trying:

julia> @btime tanh.($xx);
  990.903 μs (2 allocations: 250.08 KiB)

julia> using LoopVectorization

julia> @btime @avx tanh.($xx); # vectorised
  148.581 μs (2 allocations: 250.08 KiB)

like in mcabbott's post? Please also benchmark @btime $ww * $xx.

If your CPU does not have AVX2, you should see a big speedup on Julia 1.5 vs 1.4.
AVX2 added SIMD integer operations, which the tanh implementation LoopVectorization uses on LLVM 8 and below relies on (Julia 1.4 uses LLVM 8).
With LLVM 9 and newer (as on Julia 1.5), it no longer uses SIMD integer operations, making it much faster on old CPUs like Sandy Bridge/Ivy Bridge. Performance on CPUs with AVX2 should be more or less the same.
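
(If you're not sure whether your CPU has AVX2, two quick checks: Sys.CPU_NAME just reports the LLVM target name, and the grep is Linux-only.)

julia> Sys.CPU_NAME    # "sandybridge"/"ivybridge" lack AVX2; "haswell" and newer (e.g. "skylake") have it

$ grep -m1 -o avx2 /proc/cpuinfo    # no output means no AVX2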

Here are results on Julia nightly, thanks to baggepinnen, on an Ivy Bridge CPU:

# Float64
237.160 μs (0 allocations: 0 bytes) # baseline
131.666 μs (0 allocations: 0 bytes) # avx on inner loop
50.324 μs  (0 allocations: 0 bytes) # avx + tanh_fast

# Float32
166.900 μs (0 allocations: 0 bytes) # baseline
50.082 μs  (0 allocations: 0 bytes) # avx on inner loop
20.320 μs  (0 allocations: 0 bytes) # avx + tanh_fast

versus Float64 on Julia 1.4:

238.923 μs (0 allocations: 0 bytes) # baseline
10.912 ms (0 allocations: 0 bytes) # avx on inner loop

That's 10.9 ms down to 132 μs.

If you're also on Sandy Bridge/Ivy Bridge, you can expect similar speedups for tanh on Julia 1.5 or nightly. But if your CPU does have AVX2, performance shouldn't change much, so I'm hoping it's the former case, although I suspect your results would have been even worse if it were…

julia> versioninfo()
Julia Version 1.4.1
Commit 381693d3df* (2020-04-14 17:20 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, skylake)

Benchmarks:

julia> @btime $ww * $xx;
  21.679 μs (2 allocations: 125.08 KiB)

julia> @btime tanh.($xx);
  1.256 ms (2 allocations: 250.08 KiB)

julia> @btime @avx tanh.($xx);
  142.913 μs (2 allocations: 250.08 KiB)

I don't think it is only the AVX stuff that makes PyTorch fast, given the ~10x difference (0.3 ms vs 3 ms). I wonder what other optimizations PyTorch has :thinking:

SLEEFPirates.tanh_fast can also sometimes be faster (it varies by CPU, OS, and Julia version). Julia 1.4:

julia> using VectorizationBase, SLEEFPirates, BenchmarkTools

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float64)) do _ Core.VecElement(10randn(Float64)) end)
SVec{8,Float64}<-11.39873050155192, -5.486798123978332, 3.412707732482607, -15.818801711259969, -27.826023273267584, -9.409541830489616, -6.034771142124557, 23.388739638777807>

julia> @btime tanh($sx)
  21.976 ns (0 allocations: 0 bytes)
SVec{8,Float64}<-0.9999999997486849, -0.9999657034645373, 0.997830706021826, -0.9999999999999636, -1.0, -0.9999999865721709, -0.9999885371688979, 1.0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  7.644 ns (0 allocations: 0 bytes)
SVec{8,Float64}<-0.999999999748685, -0.9999657034645373, 0.997830706021826, -0.9999999999999636, -1.0, -0.999999986572171, -0.999988537168898, 1.0>

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float32)) do _ Core.VecElement(10randn(Float32)) end)
SVec{16,Float32}<-7.3119183f0, 10.299262f0, -3.6201859f0, 7.25937f0, -7.0897646f0, 2.9893537f0, 0.13792089f0, -0.4937947f0, 0.47192842f0, 4.8886666f0, 13.347324f0, 13.720837f0, -12.469461f0, -0.8011465f0, 2.6552308f0, -16.101555f0>

julia> @btime tanh($sx)
  20.119 ns (0 allocations: 0 bytes)
SVec{16,Float32}<-0.9999991f0, 1.0f0, -0.9985669f0, 0.999999f0, -0.9999986f0, 0.9949486f0, 0.13705297f0, -0.45722306f0, 0.43975613f0, 0.9998866f0, 1.0f0, 1.0f0, -1.0f0, -0.66467726f0, 0.9901693f0, -1.0f0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  6.729 ns (0 allocations: 0 bytes)
SVec{16,Float32}<-0.9999991f0, 1.0f0, -0.998567f0, 0.999999f0, -0.99999857f0, 0.99494857f0, 0.13705297f0, -0.45722306f0, 0.43975613f0, 0.9998866f0, 1.0f0, 1.0f0, -1.0f0, -0.6646772f0, 0.9901693f0, -1.0f0>

Julia 1.6:

julia> using VectorizationBase, SLEEFPirates, BenchmarkTools

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float64)) do _ Core.VecElement(10randn(Float64)) end)
SVec{8,Float64}<4.660910780715415, -1.791911063291962, 0.4152394300299185, -14.03991723342163, -6.947193534140638, 2.738775662784096, -8.620789233157913, -0.7529700170278613>

julia> @btime tanh($sx)
  12.669 ns (0 allocations: 0 bytes)
SVec{8,Float64}<0.9998211143579812, -0.9459618892732886, 0.3929123454879236, -0.9999999999987232, -0.9999981516936264, 0.9916756888100318, -0.9999999349709223, -0.6369174810382887>

julia> @btime SLEEFPirates.tanh_fast($sx)
  7.776 ns (0 allocations: 0 bytes)
SVec{8,Float64}<0.9998211143579812, -0.9459618892732885, 0.39291234548792353, -0.9999999999987232, -0.9999981516936264, 0.9916756888100318, -0.9999999349709223, -0.6369174810382886>

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float32)) do _ Core.VecElement(10randn(Float32)) end)
SVec{16,Float32}<2.6490726f0, 0.6964538f0, 4.003585f0, 3.3185313f0, -0.42056453f0, 7.228591f0, -11.268135f0, 1.5071146f0, -20.739851f0, 1.3313888f0, 3.428663f0, 10.747046f0, 13.510487f0, 14.988632f0, 14.164627f0, 6.938663f0>

julia> @btime tanh($sx)
  8.034 ns (0 allocations: 0 bytes)
SVec{16,Float32}<0.99004805f0, 0.60211205f0, 0.9993341f0, 0.9973817f0, -0.39740592f0, 0.9999989f0, -1.0f0, 0.90642565f0, -1.0f0, 0.8695884f0, 0.99789876f0, 1.0f0, 1.0f0, 1.0f0, 1.0f0, 0.9999981f0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  6.904 ns (0 allocations: 0 bytes)
SVec{16,Float32}<0.9900481f0, 0.60211205f0, 0.9993341f0, 0.99738175f0, -0.3974059f0, 0.99999887f0, -1.0f0, 0.90642565f0, -1.0f0, 0.8695884f0, 0.99789876f0, 1.0f0, 0.99999994f0, 1.0f0, 1.0f0, 0.99999815f0>

But it is less accurate, and prone to returning +/- 1 more often.

torch’s 262 microseconds for the batch of 1000 is impressive.

One suggestion I'd have for improving Flux here is to break up the batch among threads and have each thread evaluate its chunk, rather than threading just the BLAS operations, though it would take some work to really optimize memory bandwidth.
In this example, though, simply doing the tanhs in parallel would help; see the sketch below.
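
Here is what that batch splitting could look like (a hypothetical threaded_forward helper, not anything Flux does today; it assumes the layers are safe to call from multiple threads, which holds for these pure Dense layers, and that BLAS threads are dialed down to avoid oversubscription):

using Flux
using Base.Threads: @threads, nthreads

# Hypothetical helper: split the batch's columns into one chunk per Julia thread
# and run the whole chain on each chunk independently.
function threaded_forward(model, x::AbstractMatrix)
    n = size(x, 2)
    nchunks = min(nthreads(), n)
    bounds = round.(Int, range(0, n, length = nchunks + 1))
    outs = Vector{Any}(undef, nchunks)
    @threads for i in 1:nchunks
        outs[i] = model(view(x, :, bounds[i]+1:bounds[i+1]))
    end
    reduce(hcat, outs)
end

chain = Chain(Dense(8, 64, tanh), Dense(64, 32, tanh), Dense(32, 2, tanh))
threaded_forward(chain, rand(Float32, 8, 1000))  # start Julia with JULIA_NUM_THREADS > 1 to see a benefit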


torch’s 262 microseconds for the batch of 1000 is impressive.

I really want to know how they do that.
And in the meantime my colleagues keep saying that Python >> Julia :cry:

Is it possible that torch uses your iGPU?

That would be really surprising, since my computer doesn't have one. (And the missing CUDA drivers prevent me from using Torch.jl.)


That is a strong argument indeed :wink:
You mention a laptop, though, so I guess you do have a GPU of some sort (an AMD discrete card?). I don't know anything about torch, but if it has an OpenCL backend it could use any GPU vendor (provided the drivers are properly installed).

PyTorch doesn't support running on (Intel) iGPUs or OpenCL devices anyhow (it does have ROCm support, but that's Linux-only).


00:02.0 VGA compatible controller: Intel Corporation UHD Graphics 620 (Whiskey Lake) (rev 02)

I don't think PyTorch uses that.

Could it be that PyTorch does some sort of lazy evaluation?

I guess you are right. Sorry for the bad clue. Does the CPU usage look the same (all cores at 100%)?

I suspect not, because tanh is much slower than the matrix multiplication here (quoting @d1cker's benchmarks):

julia> @btime $ww * $xx;
  21.679 μs (2 allocations: 125.08 KiB)

julia> @btime tanh.($xx);
  1.256 ms (2 allocations: 250.08 KiB)

julia> @btime @avx tanh.($xx);
  142.913 μs (2 allocations: 250.08 KiB)

Julia is mostly single-threaded here, because @avx tanh.(x) is single-threaded. I'll add a threaded vmap variant to LoopVectorization soon to make it easier to run this multithreaded.
There is currently a vmapntt, which is threaded, but it uses non-temporal stores, which are likely to hurt performance unless the batches are very large.
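
For example, a quick way to try the threaded version on this workload (a sketch; whether the non-temporal stores help or hurt will depend on the batch size):

using LoopVectorization, BenchmarkTools

xx = randn(Float32, 8, 1000)

@btime @avx tanh.($xx);     # single-threaded SIMD broadcast
@btime vmapntt(tanh, $xx);  # threaded map, but with non-temporal (cache-bypassing) stores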

TL;DR: ignoring tanh and setting Julia's BLAS thread count to the number of physical cores, Flux actually edges out PyTorch!

After testing on a processor of similar vintage, I'm assuming @d1cker's CPU may have SMT enabled:

julia> versioninfo()
Julia Version 1.4.0
Commit b8e9a9ecc6 (2020-03-21 16:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-7700K CPU @ 4.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, skylake)

PyTorch uses 4 threads by default. I assume this is calculated using the number of physical cores:

In [3]: xx = torch.randn(1000, 8, dtype=torch.float32)                                                             

In [4]: nn = torch.nn.Sequential( 
  ...:     torch.nn.Linear(8, 64), 
  ...:     torch.nn.Linear(64, 32), 
  ...:     torch.nn.Linear(32, 2) 
  ...: )                                                                                                          

In [5]: %timeit nn(xx)[0,0]                                                                                        
115 µs ± 474 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [6]: torch.get_num_threads()                                                                                    
Out[6]: 4

In [7]: nn2 = torch.nn.Sequential( 
  ...:     torch.nn.Linear(8, 64), 
  ...:     torch.nn.Tanh(), 
  ...:     torch.nn.Linear(64, 32), 
  ...:     torch.nn.Tanh(), 
  ...:     torch.nn.Linear(32, 2), 
  ...:     torch.nn.Tanh() 
  ...: )                                                                                                          

In [8]: %timeit nn2(xx)[0,0]                                                                                       
239 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Julia's OpenBLAS uses 8 threads by default, which I assume comes from the logical core count (2×4):

julia> xx = randn(Float32, 8, 1000)
8×1000 Array{Float32,2}:
  0.503908   1.18022   -1.06559     1.84436    …  -0.46637   -0.876909   0.82367    0.706861   0.440678
  1.44368   -1.49167   -0.0692981   0.174324      -1.43919    1.38345    0.667294   1.19638   -0.097154
 -0.151144  -0.264827  -1.69988     1.00654        0.402308   1.49821    0.123561   1.09967   -1.00765
 -1.14175   -0.231032   0.110278    0.0363431      0.101058   0.372235  -0.511466   0.526048  -1.4937
  0.614555  -0.713141  -2.32049     0.608342       0.376743   1.15385   -0.235271  -0.220544   1.01734
  1.6772    -0.538957  -1.28689    -2.1525     …  -0.679366   0.269263  -0.902019  -1.71012   -1.95944
  0.318371  -0.838747  -0.158875    0.407624       0.682344   0.293501  -0.964282  -0.795304  -0.719654
 -0.831405  -0.36298   -0.341576   -1.46556        0.927488  -0.485381   0.170069  -1.17026   -0.617091

julia> const nn = Chain(Dense(8,64),
                       Dense(64,32),
                       Dense(32,2))
Chain(Dense(8, 64), Dense(64, 32), Dense(32, 2))

julia> @benchmark nn($xx)
BenchmarkTools.Trial: 
  memory estimate:  766.19 KiB
  allocs estimate:  10
  --------------
  minimum time:     124.091 μs (0.00% GC)
  median time:      152.492 μs (0.00% GC)
  mean time:        170.617 μs (9.51% GC)
  maximum time:     1.460 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> import LinearAlgebra.BLAS

julia> BLAS.set_num_threads(4)

julia> @benchmark nn($xx)
BenchmarkTools.Trial: 
  memory estimate:  766.19 KiB
  allocs estimate:  10
  --------------
  minimum time:     90.263 μs (0.00% GC)
  median time:      100.746 μs (0.00% GC)
  mean time:        112.331 μs (9.46% GC)
  maximum time:     836.423 μs (75.47% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> const nn2 = Chain(Dense(8,64,tanh),
                       Dense(64,32,tanh),
                       Dense(32,2,tanh))
Chain(Dense(8, 64, tanh), Dense(64, 32, tanh), Dense(32, 2, tanh))

julia> @benchmark nn2($xx)
BenchmarkTools.Trial: 
  memory estimate:  766.19 KiB
  allocs estimate:  10
  --------------
  minimum time:     1.575 ms (0.00% GC)
  median time:      1.589 ms (0.00% GC)
  mean time:        1.602 ms (0.68% GC)
  maximum time:     2.335 ms (27.41% GC)
  --------------
  samples:          3119
  evals/sample:     1

Not sure why SMT has such a large performance impact…


You could also try MKL.jl. MKL won’t use more threads than the number of physical cores, and performs better for small matrices in general.
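
If you try it, a quick sanity check afterwards looks roughly like this (assuming the current MKL.jl, which swaps Julia's BLAS at install/build time, so it needs a restart):

(@v1.4) pkg> add MKL      # swaps OpenBLAS for MKL; restart Julia afterwards

julia> import LinearAlgebra.BLAS

julia> BLAS.vendor()      # :mkl if the swap worked, :openblas64 otherwise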


I set the number of threads that PyTorch uses to 1:

using PyCall
using Flux
using BenchmarkTools

torch = pyimport("torch")
torch.set_num_threads(1)

NN = torch.nn.Sequential(
    torch.nn.Linear(8, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 32),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 2),
    torch.nn.Tanh()
)

torch_nn(in) = NN(in)

Flux_nn = Chain(Dense(8,64,tanh),
                Dense(64,32,tanh),
                Dense(32,2,tanh))

for i in [1, 10, 100, 1000]
    println("Batch size: $i")
    torch_in = torch.rand(i,8)
    flux_in = rand(Float32,8,i)
    print("pytorch     :")
    @btime torch_nn($torch_in)
    print("flux        :")
    @btime Flux_nn($flux_in)    
end
Batch size: 1
pytorch     :  88.087 μs (6 allocations: 192 bytes)
flux        :  3.567 μs (6 allocations: 1.25 KiB)
Batch size: 10
pytorch     :  100.711 μs (6 allocations: 192 bytes)
flux        :  18.236 μs (6 allocations: 8.22 KiB)
Batch size: 100
pytorch     :  140.269 μs (6 allocations: 192 bytes)
flux        :  162.120 μs (8 allocations: 77.16 KiB)
Batch size: 1000
pytorch     :  465.119 μs (6 allocations: 192 bytes)
flux        :  4.485 ms (10 allocations: 766.19 KiB)
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.4.1 (2020-04-14)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> torch.get_num_threads()
1

The results are about the same, so I don't think it is only the threading. PyTorch probably has some other tricks. I will check MKL.jl later today; maybe that's it.

About the SMT, I think it is disabled on my laptop, but I'm not sure I checked it correctly:

dicker@dicker-X1:~$ grep -o '^flags\b.*: .*\bht\b' /proc/cpuinfo | tail -1
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht
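
(Side note: the ht flag only means the CPU supports hyperthreading; whether it is actually enabled shows up more directly with something like:)

lscpu | grep 'Thread(s) per core'    # 2 = SMT enabled, 1 = disabled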

I will try that, thanks!