I have a struct that is copied a lot in my program, so reducing its size has a very noticeable impact on performance. Some of its fields are integers (currently stored as Int8) bounded between -4 and 4 inclusive, so there are 9 possible values and each could in theory be represented with 4 bits. Pairs of these numbers are always accessed together. Is there a way to pack two of them into one byte that is both 1) performant and 2) still readable?
If you define something like
struct smallPair
    val::UInt8
end
function smallPair(part1::Integer, part2::Integer)
    smallPair((part1 + 8) + ((part2 + 8) << 4))
end
get_part1(x::smallPair) = (x.val & 0x0F) - 8
get_part2(x::smallPair) = (x.val >> 4) - 8
I think it should work.
If you actually wanted to use this, you would probably want to customize printing, and enforce bounds on the parts in the constructor, but this general idea hopefully helps.
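The bounds check and custom printing mentioned above could look roughly like the sketch below. The type name `CheckedPair`, the error message, and the printed format are illustrative choices, not something from this thread; the packing itself is the same offset-by-8 scheme.

```julia
# Hypothetical variant of the packed pair; `CheckedPair` is an illustrative name.
struct CheckedPair
    val::UInt8
    # Inner constructor: reject parts outside -4:4 before packing.
    function CheckedPair(part1::Integer, part2::Integer)
        (-4 <= part1 <= 4 && -4 <= part2 <= 4) ||
            throw(ArgumentError("both parts must lie in -4:4"))
        # Offset each part by 8 so it is non-negative, then pack into two nibbles.
        new(UInt8((part1 + 8) + ((part2 + 8) << 4)))
    end
end

get_part1(x::CheckedPair) = Int(x.val & 0x0F) - 8
get_part2(x::CheckedPair) = Int(x.val >> 4) - 8

# Print the decoded parts rather than the raw byte.
Base.show(io::IO, x::CheckedPair) =
    print(io, "CheckedPair(", get_part1(x), ", ", get_part2(x), ")")
```

The struct still occupies a single byte; the check only costs something at construction time.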
Note that doing - 8 here promotes the result to Int, which may be undesirable.
The theory behind that was that if you want -4 to 4 as your output range, you want some variety of signed Integer, and since you presumably only extract the parts in order to do further logic with them, Int is probably the best type to use.
I think in this specific case, writing it as - Int8(8) might be more performant, since the other numbers I’ll be doing arithmetic with are Int8s as well.
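To make the promotion point concrete, here is a small sketch of the two variants side by side. The variable names are illustrative only.

```julia
raw = 0x09                   # a UInt8 nibble value, i.e. the stored form of 1

a = raw - 8                  # UInt8 - Int promotes to Int
b = Int8(raw) - Int8(8)      # both operands are Int8, so the result stays Int8

@show typeof(a) typeof(b)
```

Whether staying in Int8 is actually faster depends on the surrounding code, so it is worth benchmarking rather than assuming.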
Out of curiosity, is there a more efficient way to add these together? The basic way that doesn’t take advantage of the built-in structure is shown below
Base.:+(x::smallPair, y::smallPair) = smallPair(get_part1(x) + get_part1(y), get_part2(x) + get_part2(y))
If you know that overflow won’t occur, you could just use
Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val)
This doesn’t seem to work as expected.
Here is my setup (which should avoid overflow):
struct smallPair
    val::UInt8
end
function smallPair(part1::Int8, part2::Int8)
    smallPair((clamp(part1, Int8(-4), Int8(4)) + Int16(8)) +
              (clamp(part2, Int8(-4), Int8(4)) + Int16(8))<<Int16(4))
end
get_part1(x::smallPair) = Int8(x.val & 0x0F) - Int8(8)
get_part2(x::smallPair) = Int8(x.val >> 4) - Int8(8)
This works as expected, returning (1, 1):
Base.:+(x::smallPair, y::smallPair) = smallPair(get_part1(x) + get_part1(y), get_part2(x) + get_part2(y))
x = smallPair(Int8(1), Int8(0))
y = smallPair(Int8(0), Int8(1))
z = x + y
println("$(get_part1(z)), $(get_part2(z))")
This does not; it returns (-7, -6):
Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val)
x = smallPair(Int8(1), Int8(0))
y = smallPair(Int8(0), Int8(1))
z = x + y
println("$(get_part1(z)), $(get_part2(z))")
Am I missing some overflow?
Oh yeah, good point. I forgot to compensate for the offsetting. It should be Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val+UInt8(136)) (136 == 8<<4 + 8).
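I think it should be Base.:+(x::smallPair, y::smallPair) = smallPair(x.val+y.val+UInt8(120)), since the extra offset has to be subtracted rather than added, and subtracting 136 is the same as adding 120 modulo 256. Otherwise it works great!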
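The 120 correction can be checked exhaustively. The sketch below uses hypothetical free functions (`pack`, `unpack`, `fast_add`) that mirror the offset-by-8 encoding instead of the struct, purely to keep the example short.

```julia
# Hypothetical helpers mirroring the offset-by-8 encoding used in this thread.
pack(a, b) = UInt8((a + 8) + ((b + 8) << 4))
unpack(v::UInt8) = (Int(v & 0x0F) - 8, Int(v >> 4) - 8)

# Adding two packed bytes counts the offset twice in each nibble, so 8 has to
# be removed from both nibbles: subtract 0x88 (136), which in wrapping UInt8
# arithmetic is the same as adding 0x78 (120).
fast_add(x::UInt8, y::UInt8) = x + y + 0x78

@assert 0x00 - 0x88 == 0x78

# Exhaustive check over every pair whose component sums still fit in -8:7.
ok = all(unpack(fast_add(pack(a, c), pack(b, d))) == (a + b, c + d)
         for a in -8:7, b in -8:7, c in -8:7, d in -8:7
         if -8 <= a + b <= 7 && -8 <= c + d <= 7)
@show ok
```

The check only covers sums that stay within the 4-bit range; outside it the nibbles wrap or carry into each other, which is the bounds problem raised next.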
Unfortunately it doesn’t preserve the bounds like the other version though.
Yeah, if you want to check for overflow, I don’t think you’ll be able to do notably better than your version.
You don’t even need to convert anything to UInt. A combination of signed and unsigned bit shifts allows this:
julia> struct IntPair
           val::Int8
       end
julia> IntPair(left::Int8, right::Int8) = IntPair((left << 4) | (right << 4 >>> 4))
IntPair
julia> IntPair(l,r) = IntPair(Int8(l),Int8(r))
IntPair
julia> decode(p::IntPair) = (p.val >> 4, p.val << 4 >> 4)
decode (generic function with 1 method)
check for correctness:
julia> all(p == decode(IntPair(p...)) for p in Iterators.product(-8:7, -8:7))
true
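For anyone unsure why the shifts work, here is a sketch that traces one value through the encoding step by step. The variable names are illustrative; the operations are the same `<<`, `>>>` and `>>` used in the definitions above.

```julia
x = Int8(-3)               # bit pattern 11111101

lo = x << 4 >>> 4          # logical shift back clears the top bits: 00001101
hi = Int8(2) << 4          # 00100000

packed = hi | lo           # 00101101

@show bitstring(packed)
@show packed >> 4          # arithmetic shift recovers the high nibble, 2
@show packed << 4 >> 4     # arithmetic shift sign-extends the low nibble to -3
```

The logical shift `>>>` is only needed when packing, so that a negative low part does not smear ones over the high nibble; decoding uses the arithmetic shift `>>` to restore the sign.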
and piecewise addition becomes:
julia> Base.:+(l::IntPair, r::IntPair) = IntPair(((l.val & Int8(-16))+(r.val & Int8(-16))) + ((l.val+r.val) & Int8(15)))
# UInt literals can't be used here because they lead to conversion errors; this version is also type-stable throughout
check:
julia> all( (l .+ r) == decode(IntPair(l...) + IntPair(r...))
           for l in Iterators.product(-8:7, -8:7) for r in Iterators.product(-8:7, -8:7)
           if (-8≤(first(l)+first(r))≤7)&&(-8≤(last(l)+last(r))≤7)) # check that this combination won't overflow our 4bit range
true
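The check above deliberately skips sums that leave the 4-bit range. As a rough sketch of what happens in that case, the snippet below uses a hypothetical stand-in type, `NibblePair`, that mirrors the IntPair definitions so the example runs on its own.

```julia
# `NibblePair` is a hypothetical stand-in mirroring the IntPair definitions above.
struct NibblePair
    val::Int8
end
NibblePair(l::Integer, r::Integer) =
    NibblePair((Int8(l) << 4) | (Int8(r) << 4 >>> 4))

decode(p::NibblePair) = (p.val >> 4, p.val << 4 >> 4)

# High nibbles are added on their own; the low-nibble sum is masked so that a
# carry out of it cannot reach the high nibble.
Base.:+(l::NibblePair, r::NibblePair) =
    NibblePair(((l.val & Int8(-16)) + (r.val & Int8(-16))) + ((l.val + r.val) & Int8(15)))

@show decode(NibblePair(3, -2) + NibblePair(-1, 4))   # in range
@show decode(NibblePair(7, 0) + NibblePair(1, 0))     # left part wraps around
@show decode(NibblePair(0, 7) + NibblePair(0, 1))     # right part wraps around
```

Each part wraps independently within its own nibble, so an out-of-range sum in one half does not corrupt the other half, but it is not detected either.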