Square roots of integers

I need a function that returns the square root of a positive integer less than Nmax but does the computation only once for a given integer.
In C, I would compute all values on the first call using a static variable like this:

#include <math.h>    /* sqrt */
#include <stdlib.h>  /* malloc */

/* Nmax assumed defined elsewhere, e.g. as a macro */
double sqrttable(unsigned int i)
{
     static double *cache = NULL;
     if (cache == NULL)
     {
             cache = (double*)malloc((Nmax+1)*sizeof(double));
             for (unsigned int k=0; k<=Nmax; ++k) cache[k] = sqrt(k);
     }
     // add a check for i<=Nmax
     return cache[i];
}

What is the best way to achieve this in Julia?

julia> const sqrttable=[sqrt(i) for i=1:1_000_000];
julia> sqrt_tab(x)=sqrttable[x];
julia> sqrt_tab_ib(x)=(@inbounds r= sqrttable[x]; r);
julia> vals=rand(1:1_000_000, 1_000_000);
julia> using BenchmarkTools
julia> @btime copy($vals);
  842.508 μs (2 allocations: 7.63 MiB)
julia> @btime map($sqrt, $vals);
  3.365 ms (3 allocations: 7.63 MiB)
julia> @btime map($sqrt_tab, $vals);
  10.001 ms (3 allocations: 7.63 MiB)
julia> @btime map($sqrt_tab_ib, $vals);
  9.612 ms (3 allocations: 7.63 MiB)

Unless your table is truly tiny or you have to use a CPU from the early 90s, you are better off not using a table. In terms of cost: cache latency > memory bandwidth > instruction latency > instruction throughput.

Note that the above is memory-bandwidth constrained (the CPU reads ahead out of order). If your CPU cannot predict the lookup (via out-of-order execution), your cache will blow up in your face even worse (see the sketch below).

Note that the timings almost fit: 8x read amplification (each random lookup pulls in a full 64-byte cache line to use a single Float64) means the lookups should take about 8 × 0.84 ms ≈ 6.8 ms if your CPU never stalls due to latency.
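To see the latency-bound worst case described above, here is a sketch (not from the thread; dependent_lookups is a made-up name) in which each index is derived from the previous result, so out-of-order execution cannot overlap the cache misses:

julia> function dependent_lookups(table, n)
           s = 0.0
           i = 1
           for _ in 1:n
               @inbounds s += table[i]
               # derive the next index from the running sum: a serial
               # dependency the prefetcher cannot see through
               i = (reinterpret(UInt64, s) % UInt64(length(table))) % Int + 1
           end
           s
       end;

julia> @btime dependent_lookups($sqrttable, 1_000_000);  # expect far worse per-lookup cost than the map above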


You can use a let block to create a new scope where the cache lives, define an anonymous function inside that scope, and then bind the resulting anonymous function to a const variable (since otherwise it would be a non-constant global, which is slow to access):

julia> const sqrttable = let cache = Float64[]
          for i in 1:100
             push!(cache, sqrt(i))
          end
          i -> cache[i] # this anonymous function is returned from the 
                        # let block and thus becomes bound to the
                        # const sqrttable
       end
(::#1) (generic function with 1 method)

julia> sqrttable(4)
2.0

julia> sqrttable(16)
4.0

although it’s worth noting that you don’t save much time by doing this:

julia> using BenchmarkTools

julia> @btime sqrttable(4)
  1.824 ns (0 allocations: 0 bytes)
2.0

julia> @btime sqrt(4)
  2.171 ns (0 allocations: 0 bytes)
2.0
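
Since the original question asked for the computation to happen only once per integer, there is also a lazy variant: a sketch (sqrt_cache and cached_sqrt are made-up names) using get! to fill a Dict on first use. The hashing overhead means this only pays off when the cached computation is much more expensive than sqrt:

julia> const sqrt_cache = Dict{Int,Float64}();

julia> cached_sqrt(i::Integer) = get!(() -> sqrt(i), sqrt_cache, i);

julia> cached_sqrt(4)  # computed and stored on first call, looked up afterwards
2.0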

So just to extend: my benchmarks were of course a bit off, because I did not account for the higher-level caches (which are always hot in the benchmark loop, and large enough to hold a sizable fraction of the table). Just to show you how fast your CPU is compared to its puny memory bus (repeat this on your target computer):

julia> const sqrttable=[sqrt(i) for i=1:10_000_000];
julia> sqrt_tab_ib(x)=(@inbounds r= sqrttable[x]; r);
julia> vals=rand(1:10_000_000, 10_000_000);

julia> function cpx(V)
           res = similar(V)
           @simd for i = 1:length(V)
               @inbounds res[i] = V[i]
           end
           res
       end


julia> @btime copy(vals);
  48.847 ms (2 allocations: 76.29 MiB)
julia> @btime cpx(vals);
  39.668 ms (2 allocations: 76.29 MiB)
julia> @btime sqrt.(vals);
  66.954 ms (26 allocations: 76.30 MiB)
julia> @btime sqrt_tab_ib.(vals);
  159.656 ms (26 allocations: 76.30 MiB)

Vectorized sqrt is not much slower than memcopy from/to main memory. I have no idea why the default copy appears to be slow on my system (maybe I borked my 0.6 sysimg?).

In the table-lookup case, each entry costs 1x write and 1+8x read (one cache line = 8 Float64, and a random lookup uses only one of them). In the memcopy case we need 1x write and 1x read, for an expected factor of (1+8+1)/(1+1) = 5; at 159.656 ms vs. 39.668 ms (a factor of about 4) I’d call that close enough.

Postscript: My system sucks. The problem is not julia’s slow copy; it is that memcpy is significantly slower than the simd-copy on my system.

julia> function memcpx(V)
           res = similar(V)
           ccall(:memcpy, Ptr{Void}, (Ptr{Void}, Ptr{Void}, Csize_t), pointer(res), pointer(V), sizeof(V))
           res
       end
julia> @btime memcpx(vals);
  47.814 ms (2 allocations: 76.29 MiB)

Postscript 2: Meh, I’m really bad at guessing what I’m benchmarking: I need to pre-allocate buffers. All the copying above does not measure memory speed; it probably measures how fast the kernel is at faulting in freshly zeroed pages.

julia> memcpy(dst,src)=ccall(:memcpy, Ptr{Void}, (Ptr{Void}, Ptr{Void},Csize_t), pointer(dst), pointer(src), sizeof(src));
julia> src=zeros(10_000_000); dst=similar(src);
julia> @btime memcpy(dst,src)
  8.518 ms (7 allocations: 112 bytes)
julia> @btime zeros(10_000_000);
  36.124 ms (2 allocations: 76.29 MiB)

Thank you both for showing me it is not a good idea. Even with a perfectly predictable access pattern, the gain is not significant:

julia> vals=collect(1:10_000_000);

julia> @btime sqrt.(vals);
  42.875 ms (26 allocations: 76.30 MiB)

julia> @btime sqrt_tab_ib.(vals);
  31.707 ms (26 allocations: 76.30 MiB)

I have something similar:

julia> vals=rand(1:10_000_000, 10_000_000);

julia> @btime copy(vals);
  35.318 ms (2 allocations: 76.29 MiB)

julia> @btime cpx(vals);
  27.363 ms (2 allocations: 76.29 MiB)

julia> @btime memcpx(vals);
  33.627 ms (2 allocations: 76.29 MiB)

memcpy is provided by the Linux distribution’s binaries, which are compiled for a generic x64 system.
Can’t Julia’s LLVM-based compilation achieve a faster copy by using the best SIMD instructions available on the actual CPU?

Edit: I guess that would not be consistent with the copy being limited only by memory bandwidth.
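
One way to check (a sketch, reusing the memcpy wrapper and vals from above; use copy! instead of copyto! on 0.6) is to compare against a preallocated destination, which takes allocation and page faults out of the measurement:

julia> dst = similar(vals);

julia> @btime copyto!($dst, $vals);  # Julia's built-in copy into an existing buffer

julia> @btime memcpy($dst, $vals);   # glibc memcpy into the same buffer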

On my machine, running version 0.7-alpha on Windows 10, I get no difference at all:

julia> const sqrttable = let cache = Float64[]
          for i in 1:100
             push!(cache, sqrt(i))
          end
          i -> cache[i] # this anonymous function is returned from the 
                        # let block and thus becomes bound to the
                        # const sqrttable
       end
(::#1) (generic function with 1 method)

julia> using BenchmarkTools

julia> @btime sqrttable(4)
  1.026 ns (0 allocations: 0 bytes)
2.0

julia> @btime sqrt(4)
  1.026 ns (0 allocations: 0 bytes)
2.0

I suspect this has to do with constant propagation: your functions get compiled specifically for the literal input 4, so what you’re measuring is the time it takes to return the constant 2.0.
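
A way to guard against this is the standard BenchmarkTools idiom of hiding the value behind a Ref so the compiler cannot see the literal (a sketch, timings omitted):

julia> @btime sqrt($(Ref(4))[]);

julia> @btime sqrttable($(Ref(4))[]);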


You could also do @generated s(x) = :(sqrt(x)), but that’s not much faster either. Caching would probably make more sense if you are computing more complicated expressions than sqrt.

Why would that help?

In my post I said that it doesn’t give you any benefit, since it’s on par with taking sqrt on its own. However, it is another possible way of caching the results.

No, it doesn’t cache anything; it’s just the same as s(x) = sqrt(x) (but slower to compile).

julia> @generated s(x) = :(sqrt(x))
s (generic function with 1 method)

julia> @code_llvm s(3)

define double @julia_s_62894(i64) #0 !dbg !5 {
top:
  %1 = icmp sgt i64 %0, -1
  br i1 %1, label %pass, label %fail

fail:                                             ; preds = %top
  call void @jl_throw(i8** inttoptr (i64 4482821056 to i8**))
  unreachable

pass:                                             ; preds = %top
  %2 = sitofp i64 %0 to double
  %3 = call double @llvm.sqrt.f64(double %2)
  ret double %3
}

julia> @code_llvm sqrt(3)

define double @julia_sqrt_62897(i64) #0 !dbg !5 {
top:
  %1 = icmp sgt i64 %0, -1
  br i1 %1, label %pass, label %fail

fail:                                             ; preds = %top
  call void @jl_throw(i8** inttoptr (i64 4482821056 to i8**))
  unreachable

pass:                                             ; preds = %top
  %2 = sitofp i64 %0 to double
  %3 = call double @llvm.sqrt.f64(double %2)
  ret double %3
}

You’re right, it doesn’t quite work the way I thought it did. It does cache the code, but not the results.
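
(A sketch illustrating the distinction, with a made-up function g: the generator body runs once per concrete input type, not once per value.)

julia> @generated function g(x)
           println("generating code for ", x)  # x is the argument's *type* here
           :(sqrt(x))
       end;

julia> g(4.0);  # prints "generating code for Float64"

julia> g(9.0);  # no print: the generated code for Float64 is cached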

One way to speed up sqrt itself is to avoid the domain check (the fail/jl_throw branch visible in the IR above). This can allow the sqrt to be vectorized (if the caller can be vectorized otherwise, of course…).

You can do this with the @fastmath version of the function if you convert the integer input to a float first, or you can make sure you use an unsigned integer type, in which case LLVM should be smart enough to tell that the check can be skipped.
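
For example (a sketch; s_fast and s_uint are made-up names, and it is worth confirming with @code_llvm on your version that the branch is gone):

julia> s_fast(x) = @fastmath sqrt(Float64(x));  # skips the domain check; returns NaN for negative inputs

julia> s_uint(x::Unsigned) = sqrt(x);           # LLVM can prove x >= 0 and drop the branch

julia> @code_llvm s_fast(3)  # the fail:/jl_throw block should no longer appear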