sqrt(abs(x)) is faster than sqrt alone!

I accidentally found that computing both abs and sqrt together is faster than sqrt alone.

julia> using BenchmarkTools

julia> sqrt_abs(x) = sqrt(abs(x));

julia> const N = 1000;

julia> const positive_x = randn(N) .^ 2 .+ 1;

julia> const y = zeros(N);

julia> @btime $y .= sqrt.($positive_x);
  5.017 μs (0 allocations: 0 bytes)

julia> @btime $y .= sqrt_abs.($positive_x);
  2.523 μs (0 allocations: 0 bytes)

# check correctness
julia> sqrt.(positive_x) == sqrt_abs.(positive_x)
true

Julia Version 1.6.0
Any explanation? :grimacing:


On a scalar level they are similar:

julia> using Core.Intrinsics: sqrt_llvm

julia> @btime sqrt(1.23);
  1.607 ns (0 allocations: 0 bytes)

julia> @btime sqrt(abs(1.23));
  1.665 ns (0 allocations: 0 bytes)

julia> @btime sqrt_llvm(1.23); # sqrt but without x < 0.0 check
  1.608 ns (0 allocations: 0 bytes)

Interesting. The time of sqrt(abs) is close to that of sqrt_llvm.

julia> @btime $y .= sqrt_llvm.($positive_x);
  2.524 μs (0 allocations: 0 bytes)

Does this mean the compiler is smart enough to remove the x < 0.0 check if we additionally include the abs?

The error check prevents SIMD.
abs eliminates the error check (yes, the compiler is smart enough), so that version can be SIMD-vectorized.

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %38 = getelementptr inbounds double, double* %33, i64 %index
  %39 = bitcast double* %38 to <8 x double>*
  %wide.load = load <8 x double>, <8 x double>* %39, align 8
  %40 = getelementptr inbounds double, double* %38, i64 8
  %41 = bitcast double* %40 to <8 x double>*
  %wide.load31 = load <8 x double>, <8 x double>* %41, align 8
  %42 = getelementptr inbounds double, double* %38, i64 16
  %43 = bitcast double* %42 to <8 x double>*
  %wide.load32 = load <8 x double>, <8 x double>* %43, align 8
  %44 = getelementptr inbounds double, double* %38, i64 24
  %45 = bitcast double* %44 to <8 x double>*
  %wide.load33 = load <8 x double>, <8 x double>* %45, align 8
  %46 = call <8 x double> @llvm.fabs.v8f64(<8 x double> %wide.load)
  %47 = call <8 x double> @llvm.fabs.v8f64(<8 x double> %wide.load31)
  %48 = call <8 x double> @llvm.fabs.v8f64(<8 x double> %wide.load32)
  %49 = call <8 x double> @llvm.fabs.v8f64(<8 x double> %wide.load33)
  %50 = call <8 x double> @llvm.sqrt.v8f64(<8 x double> %46)
  %51 = call <8 x double> @llvm.sqrt.v8f64(<8 x double> %47)
  %52 = call <8 x double> @llvm.sqrt.v8f64(<8 x double> %48)
  %53 = call <8 x double> @llvm.sqrt.v8f64(<8 x double> %49)
  %54 = getelementptr inbounds double, double* %36, i64 %index
  %55 = bitcast double* %54 to <8 x double>*
  store <8 x double> %50, <8 x double>* %55, align 8
  %56 = getelementptr inbounds double, double* %54, i64 8
  %57 = bitcast double* %56 to <8 x double>*
  store <8 x double> %51, <8 x double>* %57, align 8
  %58 = getelementptr inbounds double, double* %54, i64 16
  %59 = bitcast double* %58 to <8 x double>*
  store <8 x double> %52, <8 x double>* %59, align 8
  %60 = getelementptr inbounds double, double* %54, i64 24
  %61 = bitcast double* %60 to <8 x double>*
  store <8 x double> %53, <8 x double>* %61, align 8
  %index.next = add i64 %index, 32
  %62 = icmp eq i64 %index.next, %n.vec
  br i1 %62, label %middle.block, label %vector.body

This is what we get with @. y = sqrt(abs(x)).
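If you want to reproduce this kind of IR yourself, one way is to wrap the operation in a plain loop and inspect it with code_llvm (a sketch; kernel! is an illustrative name, and the exact vector width and IR depend on your CPU and Julia/LLVM version):

```julia
using InteractiveUtils  # provides code_llvm

sqrt_abs(x) = sqrt(abs(x))

# A plain loop over the arrays makes the generated code easy to inspect.
function kernel!(y, x)
    @inbounds for i in eachindex(y, x)
        y[i] = sqrt_abs(x[i])
    end
    return y
end

# Vector types like <8 x double> and calls to @llvm.fabs.v8f64 /
# @llvm.sqrt.v8f64 in the output indicate the loop was vectorized.
code_llvm(kernel!, (Vector{Float64}, Vector{Float64}); debuginfo=:none)
```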


Is there another way to tell the compiler that it doesn’t need to error check?

I naively tried

julia> function possqrt(x)
           for i in x
               @assert i > 0
           end
           Core.Intrinsics.sqrt_llvm(x)
       end
possqrt (generic function with 1 method)

julia> @btime possqrt.($x);
  1.498 μs (1 allocation: 7.94 KiB)

julia> @btime sqrt.($x);
  1.495 μs (1 allocation: 7.94 KiB)

julia> @btime sqrt.(abs.($x));
  781.618 ns (1 allocation: 7.94 KiB)

Amazing. The internal implementation of Base.sqrt is

@inline function sqrt(x::Union{Float32,Float64})
    x < zero(x) && throw_complex_domainerror(:sqrt, x)
    sqrt_llvm(x)
end

Can I say that, if we know an input array is positive, we should use sqrt_llvm for the best performance? Is there a macro like @inbounds that eliminates the check explicitly?


@fastmath
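For reference, @fastmath rewrites sqrt into Base.FastMath.sqrt_fast, which skips the x < 0 domain check (a minimal sketch; note that @fastmath also relaxes other IEEE floating-point semantics, so apply it only where that is acceptable):

```julia
# @fastmath swaps sqrt for Base.FastMath.sqrt_fast, which has no
# domain check, so the broadcast below can SIMD-vectorize.
fast_sqrt(x) = @fastmath sqrt(x)

x = rand(1000)
y = similar(x)
y .= fast_sqrt.(x)

# Results match Base.sqrt for non-negative inputs.
fast_sqrt(4.0)  # 2.0
```

One caveat: under @fastmath the compiler is allowed to assume no NaNs are produced, so unlike NaNMath.sqrt you should not rely on any particular result for negative inputs.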

You could also use @llvm.assume.
E.g.:

julia> using VectorizationBase: assume

julia> @inline function possqrt(x)
           assume(!(x < zero(x)))
           sqrt(x)
       end
possqrt (generic function with 1 method)

julia> x = rand(1024); y = similar(x);

julia> @benchmark @. $y = possqrt($x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     669.350 ns (0.00% GC)
  median time:      669.478 ns (0.00% GC)
  mean time:        670.400 ns (0.00% GC)
  maximum time:     1.304 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     157

julia> @benchmark @. $y = sqrt($x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.310 μs (0.00% GC)
  median time:      1.312 μs (0.00% GC)
  mean time:        1.318 μs (0.00% GC)
  maximum time:     2.613 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark @. $y = Base.FastMath.sqrt_fast($x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     669.344 ns (0.00% GC)
  median time:      669.478 ns (0.00% GC)
  mean time:        670.775 ns (0.00% GC)
  maximum time:     974.127 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     157

julia> @benchmark @. $y = Base.sqrt_llvm($x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     669.350 ns (0.00% GC)
  median time:      669.478 ns (0.00% GC)
  mean time:        670.943 ns (0.00% GC)
  maximum time:     1.296 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     157

julia> @benchmark @. $y = sqrt(abs($x))
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     669.350 ns (0.00% GC)
  median time:      669.484 ns (0.00% GC)
  mean time:        670.493 ns (0.00% GC)
  maximum time:     1.232 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     157

Note that abs is basically free: it doesn’t compete for the same execution ports as sqrt, and sqrt is much slower anyway. This is why sqrt.(abs.(x)) is exactly as fast as the other SIMD versions.


I think NaNMath.jl with the NaNMath.sqrt should also work (and you get NaNs for negative values).


Works indeed, thanks.

julia> @btime NaNMath.sqrt.($x);
  818.852 ns (1 allocation: 7.94 KiB)

julia> NaNMath.sqrt(-1)
NaN

As for returning NaN from the square root of a negative number, sqrt_llvm already does that:

julia> using Core.Intrinsics: sqrt_llvm

julia> sqrt_llvm(-2.3)
NaN

Recall that the implementation of sqrt in NaNMath.jl (obtained via @edit NaNMath.sqrt(2.3)) is

sqrt(x::Real) = x < 0.0 ? NaN : Base.sqrt(x)

Note that there is still an x < 0.0 check. However, the speed is the same as with sqrt_llvm:

julia> using NaNMath

julia> @btime $y .= NaNMath.sqrt.($positive_x);
  2.523 μs (0 allocations: 0 bytes)

I thought the check would prohibit SIMD as mentioned by @Elrod. Why is it fast here?
For your information, the implementation of Base.sqrt is

@inline function sqrt(x::Union{Float32,Float64})
    x < zero(x) && throw_complex_domainerror(:sqrt, x)
    sqrt_llvm(x)
end

which is similar to NaNMath.sqrt, except that it throws.

The check itself doesn’t prevent SIMD.
The problem with branches is that to SIMD them, you basically compute both sides of the branch and then blend the results.

You can SIMD a comparison.
You can SIMD sqrt.
You can SIMD returning NaN.
So the NaNMath variant is fine.

But you can’t SIMD throwing an error, which is where the Base definition runs into problems.
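You can also write this blend explicitly with ifelse, which is essentially what the vectorizer does with the NaNMath branch (a sketch; nan_sqrt is an illustrative name):

```julia
# ifelse evaluates BOTH arguments (unlike `?:`), so the sqrt arm must not
# throw for negative x -- abs keeps it in-domain. The semantics match
# NaNMath.sqrt: NaN for x < 0, sqrt(x) otherwise, and the whole thing
# lowers to compare + sqrt + blend, all of which vectorize.
nan_sqrt(x) = ifelse(x < zero(x), oftype(x, NaN), sqrt(abs(x)))

nan_sqrt(4.0)   # 2.0
nan_sqrt(-4.0)  # NaN
```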
