Splitting Int8s?

I have a struct that is copied a lot in my program, so reducing its size has a very noticeable impact on performance. Some of its fields are integers (currently represented as Int8) that are bounded between -4 and 4 inclusive (9 possible values, so each could in theory be stored in 4 bits). Pairs of these numbers are always accessed together. Is there a way to combine two of these numbers into one byte that is both 1) performant and 2) still readable?

If you define something like

struct smallPair
    val::UInt8
end
function smallPair(part1::Integer, part2::Integer)
    smallPair((part1+8) + (part2+8)<<4)
end
get_part1(x::smallPair) = (x.val & 0x0F) - 8
get_part2(x::smallPair) = (x.val >> 4) - 8

I think it should work.
If you actually wanted to use this, you would probably want to customize printing, and enforce bounds on the parts in the constructor, but this general idea hopefully helps.
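
A minimal sketch of what bounds enforcement and custom printing could look like. The names `CheckedPair`, `first_part` and `second_part` are hypothetical and chosen only for this illustration; the -4:4 range comes from the question, and throwing an `ArgumentError` (rather than clamping) is an assumption.

```julia
# Hypothetical variant of the packed pair with an inner constructor that
# rejects out-of-range parts (assumption: error out instead of clamping).
struct CheckedPair
    val::UInt8
    function CheckedPair(a::Integer, b::Integer)
        (-4 <= a <= 4 && -4 <= b <= 4) ||
            throw(ArgumentError("parts must lie in -4:4, got ($a, $b)"))
        # Offset each part by 8 so it is non-negative, then pack:
        # low nibble holds the first part, high nibble the second.
        new(UInt8((a + 8) + ((b + 8) << 4)))
    end
end

first_part(x::CheckedPair) = Int(x.val & 0x0f) - 8
second_part(x::CheckedPair) = Int(x.val >> 4) - 8

# Print the decoded parts instead of the raw byte.
Base.show(io::IO, x::CheckedPair) =
    print(io, "CheckedPair(", first_part(x), ", ", second_part(x), ")")

p = CheckedPair(-4, 4)
println(p)                    # prints the decoded parts
println(sizeof(CheckedPair))  # the struct occupies a single byte
```

Because the only constructor is the inner one, there is no way to build a `CheckedPair` from an unchecked raw byte, which keeps the invariant but also rules out raw-byte tricks unless a separate unchecked path is added.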

Note that doing - 8 here promotes the result to Int, which may be undesirable.
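
To make the promotion concrete, here is a small sketch of the result types for a few ways of removing the offset. The variable names are illustrative only; the third case is included for comparison, since mixing UInt8 and Int8 of the same width promotes to UInt8 and wraps around instead of going negative.

```julia
raw = 0x03                  # a 4-bit field value as stored, typed UInt8

a = raw - 8                 # UInt8 - Int  promotes to Int
b = Int8(raw) - Int8(8)     # Int8  - Int8 stays Int8
c = raw - Int8(8)           # UInt8 - Int8 promotes to UInt8 and wraps

println(typeof(a), " ", a)
println(typeof(b), " ", b)
println(typeof(c), " ", c)
```

So if an Int8 result is wanted, the stored bits have to be converted to Int8 before subtracting, not just the literal.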

The reasoning behind that was that if you want -4 to 4 as your output range, you want some kind of signed integer, and since you presumably only extract the parts right before doing logic with them, Int is probably the best type to use.

I think in this specific case, writing it as - Int8(8) might be more performant, since the other numbers I’ll be doing arithmetic with are Int8s as well.

Out of curiosity, is there a more efficient way to add these together? The basic way that doesn’t take advantage of the built-in structure is shown below

Base.:+(x::smallPair, y::smallPair) = smallPair(get_part1(x) + get_part1(y), get_part2(x) + get_part2(y))

If you know that overflow won’t occur, you could just use
Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val)

This doesn’t seem to work as expected.

Here is my setup (which should avoid overflow):

struct smallPair
    val::UInt8
end

function smallPair(part1::Int8, part2::Int8)
    smallPair((clamp(part1, Int8(-4), Int8(4)) + Int16(8)) + 
        (clamp(part2, Int8(-4), Int8(4)) + Int16(8))<<Int16(4))
end

get_part1(x::smallPair) = Int8(x.val & 0x0F) - Int8(8)
get_part2(x::smallPair) = Int8(x.val >> 4) - Int8(8)

This works as expected, returning (1, 1):

Base.:+(x::smallPair, y::smallPair) = smallPair(get_part1(x) + get_part1(y), get_part2(x) + get_part2(y))
x = smallPair(Int8(1), Int8(0))
y = smallPair(Int8(0), Int8(1))
z = x + y
println("$(get_part1(z)), $(get_part2(z))")

This does not; it returns (-7, -6):

Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val)
x = smallPair(Int8(1), Int8(0))
y = smallPair(Int8(0), Int8(1))
z = x + y
println("$(get_part1(z)), $(get_part2(z))")

Am I missing some overflow?

Oh yeah, good point. I forgot to compensate the offsetting. It should be Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val+UInt8(136)).
(136==8<<4 +8)

I think it should be Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val+UInt8(120)) but otherwise works great!
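
The arithmetic behind the constant can be checked directly. Each packed byte carries an offset of 8 + (8 << 4) = 136, so the sum of two bytes carries that offset twice and one copy has to be subtracted; in wrapping UInt8 arithmetic, subtracting 136 is the same as adding 256 - 136 = 120. A minimal sketch, with `pack_offset` and `unpack_offset` as hypothetical helper names that mirror the encoding used in this thread:

```julia
# Hypothetical helpers mirroring the thread's offset-by-8 encoding.
pack_offset(a, b) = UInt8((a + 8) + ((b + 8) << 4))
unpack_offset(v::UInt8) = (Int(v & 0x0f) - 8, Int(v >> 4) - 8)

x = pack_offset(1, 0)    # 0x89
y = pack_offset(0, 1)    # 0x98

raw   = x + y            # wraps modulo 256, still carries the offset twice
fixed = raw + 0x78       # 0x78 == 120 == 256 - 136 removes one offset

println(unpack_offset(raw))
println(unpack_offset(fixed))
```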

Unfortunately it doesn’t preserve the bounds like the other version though.

Yeah, if you want to check for overflow, I don’t think you’ll be able to do notably better than your version.
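
A short sketch of what losing the bounds looks like in practice, again using hypothetical helpers (`pack_offset`, `unpack_offset`, `raw_add`, `clamped_add`) that follow the encoding from this thread. When a part leaves the 4-bit range, the carry spills into the neighbouring nibble, whereas the decode, clamp, re-encode route saturates at 4.

```julia
# Hypothetical helpers mirroring the thread's offset-by-8 encoding.
pack_offset(a, b) = UInt8((a + 8) + ((b + 8) << 4))
unpack_offset(v::UInt8) = (Int(v & 0x0f) - 8, Int(v >> 4) - 8)

# Raw byte addition with the offset correction (no bounds handling).
raw_add(x::UInt8, y::UInt8) = x + y + 0x78

# Decode, add element-wise, clamp to -4:4, and re-encode.
clamped_add(x::UInt8, y::UInt8) =
    pack_offset(clamp.(unpack_offset(x) .+ unpack_offset(y), -4, 4)...)

p = pack_offset(4, 0)

println(unpack_offset(raw_add(p, p)))      # carry leaks into the second part
println(unpack_offset(clamped_add(p, p)))  # first part saturates at 4
```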

You don’t even need to convert anything to UInt. A combination of arithmetic (>>) and logical (>>>) bitshifts allows this:

julia> struct IntPair
           val::Int8
       end

julia> IntPair(left::Int8, right::Int8) = IntPair((left << 4) | (right << 4 >>> 4))
IntPair

julia> IntPair(l,r) = IntPair(Int8(l),Int8(r))
IntPair

julia> decode(p::IntPair) = (p.val >> 4, p.val << 4 >> 4)
decode (generic function with 1 method)

check for correctness:

julia> all(p == decode(IntPair(p...)) for p in Iterators.product(-8:7, -8:7))
true

and piecewise addition becomes:

julia> Base.:+(l::IntPair, r::IntPair) = IntPair(((l.val & Int8(-16))+(r.val & Int8(-16))) + ((l.val+r.val) & Int8(15)))
# UInt literals can't be used for the masks here because they lead to conversion errors; with Int8 masks this is also type-stable throughout

check:

julia> all( (l .+ r) == decode(IntPair(l...) + IntPair(r...))
       for l in Iterators.product(-8:7, -8:7) for r in Iterators.product(-8:7, -8:7)
       if (-8≤(first(l)+first(r))≤7)&&(-8≤(last(l)+last(r))≤7)) # check that this combination won't overflow our 4bit range
true
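
For readers unfamiliar with the masks: `Int8(-16)` has the bit pattern `0xf0` (high nibble) and `Int8(15)` is `0x0f` (low nibble). The following self-contained sketch repeats the same idea under a hypothetical name, `NibblePair`, so that it does not clash with the definitions above, and works through one addition by hand.

```julia
# Hypothetical standalone copy of the two's-complement nibble packing.
struct NibblePair
    val::Int8
end

# High nibble holds the left part; the logical shift (>>>) clears the
# sign-extension bits of the right part so it fits the low nibble.
function NibblePair(l::Integer, r::Integer)
    left, right = Int8(l), Int8(r)
    NibblePair((left << 4) | (right << 4 >>> 4))
end

# Arithmetic shifts (>>) sign-extend each nibble back to an Int8.
unpack(p::NibblePair) = (p.val >> 4, p.val << 4 >> 4)

# Add the high nibbles on their own, and keep only the low nibble of the
# full sum so that a carry out of the low nibble is discarded.
Base.:+(l::NibblePair, r::NibblePair) =
    NibblePair(((l.val & Int8(-16)) + (r.val & Int8(-16))) +
               ((l.val + r.val) & Int8(15)))

println(reinterpret(UInt8, Int8(-16)))   # bit pattern of the high-nibble mask
println(unpack(NibblePair(3, -2) + NibblePair(-4, 1)))
```

Unlike the offset encoding, this version needs no correction constant on addition, and it covers the full 4-bit signed range -8:7 rather than only -4:4.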