`SLEEFPirates.tanh_fast` can also sometimes be faster, varying by CPU, OS, and Julia version. On Julia 1.4:
```julia
julia> using VectorizationBase, SLEEFPirates, BenchmarkTools

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float64)) do _ Core.VecElement(10randn(Float64)) end)
SVec{8,Float64}<-11.39873050155192, -5.486798123978332, 3.412707732482607, -15.818801711259969, -27.826023273267584, -9.409541830489616, -6.034771142124557, 23.388739638777807>

julia> @btime tanh($sx)
  21.976 ns (0 allocations: 0 bytes)
SVec{8,Float64}<-0.9999999997486849, -0.9999657034645373, 0.997830706021826, -0.9999999999999636, -1.0, -0.9999999865721709, -0.9999885371688979, 1.0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  7.644 ns (0 allocations: 0 bytes)
SVec{8,Float64}<-0.999999999748685, -0.9999657034645373, 0.997830706021826, -0.9999999999999636, -1.0, -0.999999986572171, -0.999988537168898, 1.0>

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float32)) do _ Core.VecElement(10randn(Float32)) end)
SVec{16,Float32}<-7.3119183f0, 10.299262f0, -3.6201859f0, 7.25937f0, -7.0897646f0, 2.9893537f0, 0.13792089f0, -0.4937947f0, 0.47192842f0, 4.8886666f0, 13.347324f0, 13.720837f0, -12.469461f0, -0.8011465f0, 2.6552308f0, -16.101555f0>

julia> @btime tanh($sx)
  20.119 ns (0 allocations: 0 bytes)
SVec{16,Float32}<-0.9999991f0, 1.0f0, -0.9985669f0, 0.999999f0, -0.9999986f0, 0.9949486f0, 0.13705297f0, -0.45722306f0, 0.43975613f0, 0.9998866f0, 1.0f0, 1.0f0, -1.0f0, -0.66467726f0, 0.9901693f0, -1.0f0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  6.729 ns (0 allocations: 0 bytes)
SVec{16,Float32}<-0.9999991f0, 1.0f0, -0.998567f0, 0.999999f0, -0.99999857f0, 0.99494857f0, 0.13705297f0, -0.45722306f0, 0.43975613f0, 0.9998866f0, 1.0f0, 1.0f0, -1.0f0, -0.6646772f0, 0.9901693f0, -1.0f0>
```
On Julia 1.6:
```julia
julia> using VectorizationBase, SLEEFPirates, BenchmarkTools

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float64)) do _ Core.VecElement(10randn(Float64)) end)
SVec{8,Float64}<4.660910780715415, -1.791911063291962, 0.4152394300299185, -14.03991723342163, -6.947193534140638, 2.738775662784096, -8.620789233157913, -0.7529700170278613>

julia> @btime tanh($sx)
  12.669 ns (0 allocations: 0 bytes)
SVec{8,Float64}<0.9998211143579812, -0.9459618892732886, 0.3929123454879236, -0.9999999999987232, -0.9999981516936264, 0.9916756888100318, -0.9999999349709223, -0.6369174810382887>

julia> @btime SLEEFPirates.tanh_fast($sx)
  7.776 ns (0 allocations: 0 bytes)
SVec{8,Float64}<0.9998211143579812, -0.9459618892732885, 0.39291234548792353, -0.9999999999987232, -0.9999981516936264, 0.9916756888100318, -0.9999999349709223, -0.6369174810382886>

julia> sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float32)) do _ Core.VecElement(10randn(Float32)) end)
SVec{16,Float32}<2.6490726f0, 0.6964538f0, 4.003585f0, 3.3185313f0, -0.42056453f0, 7.228591f0, -11.268135f0, 1.5071146f0, -20.739851f0, 1.3313888f0, 3.428663f0, 10.747046f0, 13.510487f0, 14.988632f0, 14.164627f0, 6.938663f0>

julia> @btime tanh($sx)
  8.034 ns (0 allocations: 0 bytes)
SVec{16,Float32}<0.99004805f0, 0.60211205f0, 0.9993341f0, 0.9973817f0, -0.39740592f0, 0.9999989f0, -1.0f0, 0.90642565f0, -1.0f0, 0.8695884f0, 0.99789876f0, 1.0f0, 1.0f0, 1.0f0, 1.0f0, 0.9999981f0>

julia> @btime SLEEFPirates.tanh_fast($sx)
  6.904 ns (0 allocations: 0 bytes)
SVec{16,Float32}<0.9900481f0, 0.60211205f0, 0.9993341f0, 0.99738175f0, -0.3974059f0, 0.99999887f0, -1.0f0, 0.90642565f0, -1.0f0, 0.8695884f0, 0.99789876f0, 1.0f0, 0.99999994f0, 1.0f0, 1.0f0, 0.99999815f0>
```
But `tanh_fast` is less accurate, and it saturates to exactly ±1 more readily than `tanh` does, as the outputs above show.
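To put a number on the accuracy gap, you can compare `tanh_fast` against `Base.tanh` (which is accurate to within about 1 ulp) over a range of scalar inputs. This is just a sketch, assuming `SLEEFPirates` is installed and that its scalar `tanh_fast` method is available:

```julia
using SLEEFPirates

# Maximum absolute deviation from Base.tanh over a grid of inputs.
xs = range(-20.0, 20.0; length = 10_001)
maxerr = maximum(abs(SLEEFPirates.tanh_fast(x) - tanh(x)) for x in xs)

# Smallest positive x on the grid where tanh_fast already returns exactly 1.0,
# illustrating the earlier saturation to ±1.
sat = findfirst(x -> SLEEFPirates.tanh_fast(x) == 1.0, collect(xs[xs .> 0]))
```

The deviation is tiny in absolute terms, but for workloads where gradients near saturation matter, the earlier clamping to ±1 can be the more important difference.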
torch’s 262 microseconds for the batch of 1000 is impressive.
One suggestion I’d have for improving Flux here is to split the batch among threads and have each thread evaluate its own chunk, rather than threading only the BLAS operations. It would take some work to really optimize memory bandwidth, though.
In this example, though, simply evaluating the tanhs in parallel will help.
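A minimal sketch of that idea, using only `Base.Threads` (the function name here is made up, and a real implementation would want to tune chunk sizes for memory bandwidth):

```julia
using Base.Threads

# Apply tanh column-by-column across threads, one chunk of the batch per task.
function threaded_tanh!(out::AbstractMatrix, x::AbstractMatrix)
    @threads for j in axes(x, 2)
        @inbounds for i in axes(x, 1)
            out[i, j] = tanh(x[i, j])
        end
    end
    return out
end

x = randn(Float32, 128, 1000)  # features × batch of 1000
y = similar(x)
threaded_tanh!(y, x)
```

Combining this with `tanh_fast` on SIMD vectors per thread would stack both speedups, which is essentially what LoopVectorization-based kernels do.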