SIMD struggles, seeking solutions (with KangarooTwelve.jl)

A bit ago I mentioned a new hashing library I have been working on:

After a huge amount of effort, I’ve just managed to make the “simple” single-threaded version (k12_singlethreaded) zero-alloc and only a factor of 2 slower than the C implementation (I’m seeing ~1.5 GB/s where I should be seeing 3-4 GB/s).

One feature that should help close the gap is SIMD. I’ve been trying to use SIMD.jl to accelerate the parallel branches of the hashing, but despite my best efforts so far, the SIMD version is:

  • Slower
  • Sometimes incorrect

To show this, here’s a quick example:

(KangarooTwelve) julia> bitpattern(num::Int) = Iterators.take(Iterators.cycle(0x00:0xfa), num) |> collect
bitpattern (generic function with 1 method)

(KangarooTwelve) julia> data = bitpattern(17^7);

(KangarooTwelve) julia> @time k12_singlethreaded(data, UInt8[])
  0.238118 seconds (2 allocations: 64 bytes)
0x05d354e539382fe58260c271776a39a4

(KangarooTwelve) julia> @time k12_singlethreaded_simd(data, UInt8[])
  0.977218 seconds (25.05 k allocations: 11.464 MiB)
0x05d354e539382fe58260c271776a39a4

We can make the SIMD version faster (new time: ~0.6s) by adding @inline to the rol64 inner function within keccak_p1600, but then the results become non-deterministic.

I’d hugely appreciate any help getting this closer to the performance of the C/Rust implementations floating around, and getting SIMD working.


I haven’t looked at the code in much detail, but I would note that SIMD.jl can generate pretty bad code, especially compared to established x86 C++ SIMD libraries. Also, I wonder whether your ntuple constructions are generating nice code; in some instances they certainly don’t (see “Specialize `setindex` for isbits `Tuple`s and `_totuple` for isbits `NTuple`s” by Zentrik, JuliaLang/julia Pull Request #51748).


One thing SIMD.jl does differently from libraries that use native assembly code is that SIMD.jl emits target-independent LLVM IR and relies on LLVM to optimize it and select the actual instructions that will run on the native machine. This means code written with SIMD.jl is portable, but if LLVM doesn’t do a great job of generating the assembly, it will end up slower than hand-picked, machine-specific intrinsics.
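To make the “target-independent LLVM IR” point concrete, here is a minimal sketch of the kind of IR a lane-wise operation on a SIMD.jl `Vec{4,UInt64}` boils down to, written directly with `Base.llvmcall` on `NTuple{4, VecElement{UInt64}}` (which Julia lowers to an LLVM `<4 x i64>` vector). The names `V4` and `xor4` are hypothetical, and this is not SIMD.jl’s actual source; it only illustrates that the IR names a generic vector op and leaves instruction selection to LLVM.

```julia
# An NTuple of VecElement is lowered by Julia to an LLVM vector type, here <4 x i64>.
const V4 = NTuple{4, VecElement{UInt64}}

# Lane-wise XOR expressed as target-independent LLVM IR.
# The arguments are %0 and %1; the implicit entry block takes %2, so the result is %3.
xor4(a::V4, b::V4) = Base.llvmcall(
    """
    %3 = xor <4 x i64> %0, %1
    ret <4 x i64> %3
    """,
    V4, Tuple{V4, V4}, a, b)

a = map(VecElement, (UInt64(1), UInt64(2), UInt64(3), UInt64(4)))
b = map(VecElement, (UInt64(5), UInt64(6), UInt64(7), UInt64(8)))
c = xor4(a, b)
println(map(x -> x.value, c))

# To see which machine instructions LLVM actually selected (REPL, InteractiveUtils):
# code_native(xor4, Tuple{V4, V4})
```

How well LLVM lowers this depends on the target CPU features, which is exactly where it can fall behind hand-written intrinsics.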


The thing that gets me is that this really feels like a situation where SIMD should be an easy win. The core permutation function that is SIMD’d is keccak_p1600.

The operations applied to the SIMD.jl Vec are UInt64 bitshifts, NOTs, ORs, and XORs. Adding @inline to the rol64 function improves the performance of the SIMD method a bit, but also makes the results non-deterministic!!

This leaves me at quite a loss, which is a pity because I suspect this is one of the major factors stopping us from nearing the performance of C/Rust implementations.
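The permutation code itself isn’t reproduced in the thread, but the textbook Keccak χ step shows why this looks like an easy SIMD win: every operation is lane-wise, so four independent states packed side by side all undergo exactly the same instructions. The sketch below is written generically over the element type `T`, so in principle `T` could be `UInt64` or a SIMD.jl `Vec{4,UInt64}`; `chi_row` is a hypothetical helper, not a function from the package.

```julia
# Keccak χ step on one row of 5 lanes: a[x] ⊻= (~a[x+1]) & a[x+2], indices taken mod 5.
# Generic over T: it only needs ⊻, ~ and &, all of which act lane-by-lane on vectors.
chi_row(r::NTuple{5,T}) where {T} =
    ntuple(x -> r[x] ⊻ (~r[mod1(x + 1, 5)] & r[mod1(x + 2, 5)]), Val(5))

row = UInt64.((1, 2, 4, 8, 16))
println(chi_row(row))
```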

If inlining changes the result, it’s either a bug in inlining or a bug in your code. The second is more likely.

ρs[1] = 0, so in rol64 you end up doing a >> 64. Is this legal in Julia? I feel like I remember reading that this isn’t well defined in C, but I may be wrong.

EDIT:
Microsoft’s documentation says it’s undefined behaviour (Warning C26452, Microsoft Learn).
Clang with UBSan flags it as well (Compiler Explorer).
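A common way to sidestep the out-of-range shift is to mask the shift amounts so they always stay in 0:63. This is a sketch assuming rol64 takes a word and a rotation amount; it is not necessarily what the eventual fix does, and `rol64_naive`/`rol64_masked` are hypothetical names. For scalars, Base’s `bitrotate` serves as a reference to check against.

```julia
# Naive rotate: for n == 0 this shifts right by 64 bits. Julia defines that for scalar
# integers (the result is 0), but an LLVM shl/lshr by >= 64 bits yields poison, which is
# presumably what bites once the shift is lowered straight to LLVM (e.g. on SIMD vectors).
rol64_naive(a::UInt64, n::Int) = (a << n) | (a >> (64 - n))

# Masked rotate: (64 - n) & 63 maps n == 0 to a shift of 0, so every shift amount
# stays in 0:63 and no out-of-range shift is ever emitted.
rol64_masked(a::UInt64, n::Int) = (a << (n & 63)) | (a >> ((64 - n) & 63))

x = 0x0123456789abcdef
for n in 0:63
    @assert rol64_masked(x, n) == bitrotate(x, n)  # bitrotate rotates left for n > 0
end
```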


Thanks for that @Zentrik! That is indeed the source of the correctness issue, and I’ve just pushed a fix: “Fix undefined behaviour in rol64 call” (tecosaur/KangarooTwelve.jl@d4d95f8 on GitHub).

Now I just need to work out what’s going on with the SIMD performance difference…

If the docs are right, it’s not UB in Julia:

help?>  1 >> 64
  >>(x, n)

  Right bit shift operator, x >> n. For n >= 0, the result is x shifted right by n bits, filling with 0s if x >= 0, 1s if x <
  0, preserving the sign of x. This is equivalent to fld(x, 2^n). For n < 0, this is equivalent to x << -n.

Though this may be a case of hardware semantics not exactly lining up with Julia semantics. If that’s the case, a MWE would be a good bug report.
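The documented scalar semantics are easy to confirm at the REPL. All of these hold for Julia’s built-in integer shifts, so a plain scalar shift by 64 is well defined in Julia even though the corresponding C expression is not.

```julia
# Julia defines shifts by >= the bit width instead of leaving them undefined.
@assert UInt64(1) >> 64 == 0        # logical shift of an unsigned value past the width
@assert UInt64(1) << 64 == 0
@assert 1 >> 64 == 0               # arithmetic shift of a non-negative Int
@assert -1 >> 64 == -1             # arithmetic shift keeps the sign
@assert 1 >> -1 == 2               # negative n: x >> n is x << -n
@assert -5 >> 1 == fld(-5, 2)      # x >> n is fld(x, 2^n)
println("all shift checks passed")
```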


I would profile to find the hotspots; presumably it’s keccak_p1600. Then compare the LLVM IR/assembly to see if there’s a difference there. perf could be useful for doing both steps at once.
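For the “compare LLVM IR/assembly” step, `code_llvm` and `code_native` from InteractiveUtils can write to an `IO`, so the output for two argument types can be captured and diffed. This sketch uses a hypothetical `xorfold` kernel as a stand-in for the real hot loop and compares a `Vector{UInt64}` against a reinterpreted byte vector; finding the hotspot in the first place would come from `@profview` (ProfileView.jl or the VS Code extension) or the Profile stdlib.

```julia
using InteractiveUtils  # code_llvm / code_native (already loaded in the REPL)

# Hypothetical stand-in for the hot loop: XOR-fold a vector of 64-bit words.
function xorfold(v::AbstractVector{UInt64})
    acc = zero(UInt64)
    @inbounds for i in eachindex(v)
        acc ⊻= v[i]
    end
    return acc
end

plain = rand(UInt64, 1024)
reint = reinterpret(UInt64, rand(UInt8, 8 * 1024))

# Capture the optimized IR for each argument type so the two can be diffed.
ir_plain = sprint(io -> code_llvm(io, xorfold, Tuple{typeof(plain)}; debuginfo=:none))
ir_reint = sprint(io -> code_llvm(io, xorfold, Tuple{typeof(reint)}; debuginfo=:none))

# Vectorized loops show up as operations on <N x i64> values in the IR.
println("Vector{UInt64} vectorized: ", occursin("x i64>", ir_plain))
println("ReinterpretArray vectorized: ", occursin("x i64>", ir_reint))
```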


It does seem like there’s a bug somewhere, as there have been previous issues/PRs about this, e.g. “fix #37880, overflow in shift amount range check in codegen” by JeffBezanson (JuliaLang/julia Pull Request #37891). Some code that seems to check for overflow also appears in “another fix to run-time ashr_int” by JeffBezanson (JuliaLang/julia Pull Request #24575).

EDIT: wrong function, we’re probably calling lshr_int, but it still shows that this should probably have been checked for.

I think I’ve made some progress on figuring out the performance difference.

It seems like when the SIMD path is used, it makes a big difference whether the input array is a genuine Vector{UInt64} or a Vector{UInt8} reinterpreted as UInt64s.

julia> longmsg_u8 = rand(UInt8, 100000000 * 8);

julia> longmsg_u8r64 = reinterpret(UInt64, longmsg_u8);

julia> longmsg_u64 = rand(UInt64, 100000000);

julia> @time KangarooTwelve.turboshake(UInt128, longmsg_u64) # non-SIMD
  0.432940 seconds (4 allocations: 352 bytes)
0x3b5ea3aa194f8f0be7c2986cea24b0b0

julia> @time KangarooTwelve.turboshake(UInt128, longmsg_u8r64) # non-SIMD
  0.440677 seconds (4 allocations: 352 bytes)
0xa8a43cb9fa252800f2d0c4a7a971dbcd

julia> @time KangarooTwelve.turboshake(UInt128, (longmsg_u64, longmsg_u64, longmsg_u64, longmsg_u64))
  1.125633 seconds (3 allocations: 944 bytes)
(0x3b5ea3aa194f8f0be7c2986cea24b0b0, 0x3b5ea3aa194f8f0be7c2986cea24b0b0, 0x3b5ea3aa194f8f0be7c2986cea24b0b0, 0x3b5ea3aa194f8f0be7c2986cea24b0b0)

julia> @time KangarooTwelve.turboshake(UInt128, (longmsg_u8r64, longmsg_u8r64, longmsg_u8r64, longmsg_u8r64))
  4.212588 seconds (3 allocations: 976 bytes)
(0xa8a43cb9fa252800f2d0c4a7a971dbcd, 0xa8a43cb9fa252800f2d0c4a7a971dbcd, 0xa8a43cb9fa252800f2d0c4a7a971dbcd, 0xa8a43cb9fa252800f2d0c4a7a971dbcd)

I use PtrArray to get around that.

From glancing at JuliaSIMD/StrideArraysCore.jl on GitHub (“The core AbstractStrideArray type, separated from StrideArrays.jl to avoid circular dependencies.”), PtrArray just seems like another take on ReinterpretArray?
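For comparison, Base alone can do a similar zero-copy pointer reinterpretation with `unsafe_wrap`, producing a genuine `Vector{UInt64}` rather than a `ReinterpretArray`. This is a swapped-in Base technique, not PtrArray itself; the wrapped array is only valid while the original bytes are kept alive, and it skips the checks `reinterpret` performs.

```julia
bytes = rand(UInt8, 8 * 1000)
n = length(bytes) ÷ 8

r = reinterpret(UInt64, bytes)   # a Base.ReinterpretArray, not a Vector

total = GC.@preserve bytes begin
    # Zero-copy Vector{UInt64} over the same memory. It must not outlive `bytes`,
    # so every use stays inside this GC.@preserve block.
    w = unsafe_wrap(Array, Ptr{UInt64}(pointer(bytes)), n)
    reduce(xor, w)
end
println(typeof(r), " vs Vector{UInt64}; xor-fold = ", total)
```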

Perhaps the compiler has less information on the alignment of reinterpreted data. Just don’t reinterpret, in any case.

I think that’s right. You allow it to make assumptions that might be wrong (??), so in principle this seems less safe. Experts can say how unsafe.

Yeah, that’s not really a good option here. I need to XOR with the UInt64-based state, which means I need UInt64s. It would be possible to instead reinterpret the NTuple{25, UInt64} state as whatever the input vector’s element type is (say UInt8), but I’m pretty sure that would reduce performance in another way.

If possible, reinterpret each object separately, instead of the whole array.
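One way to read “reinterpret each object separately” is to build each UInt64 lane from its 8 bytes as it is absorbed, instead of wrapping the whole array. This sketch assumes little-endian lanes (the byte order Keccak uses) and a hypothetical `load_lane` helper; whether LLVM folds the byte loop into a single 8-byte load is worth confirming with `code_native`.

```julia
# Assemble lane i from bytes 8(i-1)+1 : 8i, least-significant byte first.
@inline function load_lane(bytes::AbstractVector{UInt8}, i::Int)
    base = 8 * (i - 1)
    w = zero(UInt64)
    for j in 8:-1:1                     # start from the most significant byte
        w = (w << 8) | UInt64(bytes[base + j])
    end
    return w
end

b = UInt8.(1:24)
println(string(load_lane(b, 1), base=16))
```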

Maybe I’m just not seeing it, but at a glance that sounds arduous/tedious for little benefit?

I think I had performance problems with reinterpreted arrays before. At the time I “solved” it with some pointer hackery, although I’m not sure that’s appropriate for your issue.

In the profiles you can see that a lot of time is spent on indexing.
This is the profile for @profview KangarooTwelve.turboshake(UInt128, (longmsg_u64, longmsg_u64, longmsg_u64, longmsg_u64)):

And this is the reinterpreted one:

A single flame on the left looks like this when expanded:


I just played around with this a bit more, and it seems like with this rather hacky reinterpreter that just reads the bytes through a converted pointer access, it’s as fast as the native UInt64 version for me:

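# Read-only UInt64 view over a Vector{UInt8}, loading words through a raw pointer
struct UnsafeReinterpretedVector <: AbstractArray{UInt64,1}
  v::Vector{UInt8}   # underlying byte buffer (this field keeps it reachable)
  len::Int           # number of UInt64 words
  function UnsafeReinterpretedVector(v::Vector{UInt8})
    length(v) % 8 == 0 || error("Incompatible length")
    new(v, length(v) ÷ 8)
  end
end

Base.size(u::UnsafeReinterpretedVector) = (u.len,)

function Base.getindex(u::UnsafeReinterpretedVector, i::Int)
  0 < i <= u.len || throw(BoundsError(u, i))
  v = u.v
  # Keep `v` rooted while we hold a raw pointer into its data
  GC.@preserve v begin
    p = pointer(v)
    p64 = Ptr{UInt64}(p)
    unsafe_load(p64, i)  # i-th 8-byte word, in native byte order
  end
end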

longmsg_unsafe_u64 = UnsafeReinterpretedVector(longmsg_u8);

@time KangarooTwelve.turboshake(UInt128, (longmsg_unsafe_u64, longmsg_unsafe_u64, longmsg_unsafe_u64, longmsg_unsafe_u64))

1.905471 seconds (3 allocations: 976 bytes)
(0xcc093729c8f34a590fb7e34103068fbb, 0xcc093729c8f34a590fb7e34103068fbb, 0xcc093729c8f34a590fb7e34103068fbb, 0xcc093729c8f34a590fb7e34103068fbb)