I’m torn between supporting just binary or also ternary (in addition to the third format).
To simplify and have only two different formats, one dense and one sparse, I’m thinking:
- 64 binary values with a twist, call it pseudo-ternary: every other value is 1 or 0, and the values in between are -1 or 0.
- For the other format, float8s, or alternatively a mix of float8, then float7, then maybe float5… to squeeze a few more values into the sparse format; plus location information for the whole row.
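To make the pseudo-ternary idea concrete, here is a minimal sketch. It assumes the 64 values are stored as the 64 bits of a UInt64, with bit i (0-based) holding value i. A set bit means +1 at even positions and -1 at odd positions, and a clear bit means 0. The helper names are hypothetical.

```julia
# Hypothetical helpers: bit i (0-based) of a UInt64 holds pseudo-ternary value i.
# Even positions decode to 0 or +1, odd positions to 0 or -1.
function pseudo_ternary_value(word::UInt64, i::Integer)
    bit = Int((word >> i) & 0x1)       # 0 or 1
    return iseven(i) ? bit : -bit
end

# All 64 values as a tuple (Julia indices 1:64 correspond to positions 0:63).
decode_pseudo_ternary(word::UInt64) = ntuple(k -> pseudo_ternary_value(word, k - 1), 64)
```

For example, `decode_pseudo_ternary(UInt64(0b110))` starts with `(0, -1, 1, 0, …)`, and a word with all bits set sums to 0 (32 values of +1, 32 of -1).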
EDIT: Thinking about it more, storing location information may not be needed or wanted. It complicates encoding: I would have to try all the possible locations and see which fits best.
A simpler sparse format: store the 8 lowest bits as zero to signal it, followed by 7 float8s, which then have to be the first 7 values of the vector.
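A minimal sketch of this layout. It assumes the byte order that decode_64_ uses further down (low byte = zero tag, w6 in the next byte, …, w0 in the top byte) and that the 7 float8 bit patterns are already available as UInt8 values. `pack_sparse` and `is_sparse_tag` are hypothetical names.

```julia
# Hypothetical layout (mirrors decode_64_ in this post):
# bits 0-7 = 0x00 tag, bits 8-15 = w6, ..., bits 56-63 = w0.
function pack_sparse(bytes::NTuple{7,UInt8})
    word = zero(UInt64)
    for (k, b) in enumerate(bytes)           # k = 1 is w0, k = 7 is w6
        word |= UInt64(b) << (8 * (8 - k))
    end
    return word                              # low byte stays 0x00
end

# A zero low byte signals the sparse format.
is_sparse_tag(word::UInt64) = (word & 0xff) == 0
```

For example, packing the first seven bytes of the 'n' result shown below, `(0x00, 0xc8, 0x38, 0x44, 0x4a, 0xbc, 0xb6)`, gives `0x00c838444abcb600`.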
A block in JPEG is 8x8: one DC component, then 63 frequency (AC) components. I’m thinking this would capture the DC and a few of the most important components as float8.
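Suppose the 64 weights follow JPEG’s zigzag ordering (an assumption; the post doesn’t say which order is used). Then “the first 7 values” are the DC term plus the six lowest-frequency AC terms. Here is a short sketch that generates the zigzag order for an n×n block as 0-based (row, column) pairs:

```julia
# JPEG zigzag scan order for an n×n block, as 0-based (row, col) pairs.
# Cells are grouped by anti-diagonal r + c; the scan direction alternates:
# on even diagonals the row index decreases, on odd diagonals it increases.
function zigzag_order(n::Int = 8)
    cells = [(r, c) for r in 0:n-1 for c in 0:n-1]
    sort!(cells; by = p -> (p[1] + p[2], iseven(p[1] + p[2]) ? -p[1] : p[1]))
    return cells
end
```

`zigzag_order()[1:7]` is `[(0, 0), (0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3)]`. So under that assumption the 7 float8s would hold the DC coefficient and the 6 lowest-frequency AC coefficients.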
Do you know whether neural networks are fed fully decoded images, or possibly half-decoded JPEGs (which seems sensible, as I explain here)? I know Nvidia has JPEG decoding hardware on its GPUs to make training of neural nets faster. It seems most sensible to decode only halfway and keep the frequency components in the frequency domain. The DC components alone already give a self-similar image in the “time” (spatial) domain, just scaled down by 8 in both the X and Y directions, i.e. 64 times smaller. (It’s also quite possible that neural networks learn nothing from the frequency components, effectively discarding 63/64 of the training data… But otherwise, you want to bias computer vision toward the big picture and away from textures, and this might do that while still using some of the texture information.)
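The “64 times smaller” image can be checked with a little arithmetic. With JPEG’s DCT normalization, F(0,0) = (1/4)·C(0)·C(0)·Σf with C(0) = 1/√2, so F(0,0) = Σf/8. That is 8 times the block mean, and the grid of DC terms is therefore a box-filtered thumbnail (scaled by 8) that is 8× smaller in each direction. The sketch below ignores JPEG’s level shift and quantization and assumes the image dimensions are multiples of 8. `dc_thumbnail` is a hypothetical name.

```julia
# DC coefficient of every 8×8 block: sum of the block's 64 samples / 8,
# which equals 8 × the block mean (JPEG DCT scaling, no level shift).
function dc_thumbnail(img::AbstractMatrix{<:Real})
    h, w = size(img)
    (h % 8 == 0 && w % 8 == 0) || throw(ArgumentError("dimensions must be multiples of 8"))
    return [sum(@view img[8i-7:8i, 8j-7:8j]) / 8 for i in 1:h÷8, j in 1:w÷8]
end
```

For example, a 16×24 image filled with 2.0 gives a 2×3 grid filled with 16.0 (8 × 2.0).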
The point of the format is to force a sparse representation of the weights, and I think this format would induce that. To begin with, the weights need to be initialized somehow. I believe that is usually done randomly, so a random number in the dense format should be good enough.
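Here is a minimal sketch of random initialization in the dense format. It assumes the 64 pseudo-ternary values are the bits of a UInt64 and that a zero low byte marks the sparse format, as proposed above. Any other bit pattern is a valid dense word, so a uniform random UInt64 works. The exception is the roughly 1 in 256 draws whose low byte is zero; those would be read as sparse and are redrawn. `random_dense_word` is a hypothetical name.

```julia
using Random

# Uniform random dense word; redraw the ~1/256 cases whose low byte is zero,
# because that pattern is reserved as the sparse-format tag.
function random_dense_word(rng::AbstractRNG = Random.default_rng())
    while true
        w = rand(rng, UInt64)
        (w & 0xff) != 0 && return w
    end
end
```

The same consideration applies to trained weights: a dense word whose first 8 values are all zero would collide with the tag unless the encoder avoids it.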
This is basically done as a proof of concept; encode_64_full is the main function. I just need to finish the trivial corresponding decode function and validate correctness (I believe the functions it calls are OK).
These are some ugly (long) function definitions, the longest I’ve ever written. Is there a better way to do this?
using Statistics
function encode_64_full((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
codeb = encode_64b((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
code = encode_64((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
meanb = mean(abs.(decode_64b(codeb) .- (w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63)))
meana = mean(abs.(decode_64(code) .- (w0, w1, w2, w3, w4, w5, w6, w7))); # println(meanb, "\n", meana);
if meanb > meana
('b', codeb) # Binary, dense
else
('n', code .% UInt8) # Not binary, sparse
end
end
Running this function requires all the function definitions below.
The returned values may seem strange, but they are correctly clamped (or rounded) values, as intended:
julia> decode_64_full(encode_64_full((0, -0.5, 0.5, 1.5, 2.5, -1.5, -2.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)))
(0.0, -1.0, 1.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
julia> encode_64_full((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63))
('b', (false, true, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false))
julia> encode_64_full((0, -0.5, 0.5, 1.5, 2.5, -1.5, -2.5, 10, 20, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
('n', (0x00, 0xc8, 0x38, 0x44, 0x4a, 0xbc, 0xb6, 0x5a))
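On the long signatures: a tuple argument can be typed as NTuple{64,Real} and indexed, instead of being destructured into w0…w63. The sketch below also compares both reconstructions over all 64 inputs, with the sparse decode zero-padded to 64, and keeps whichever code has the lower mean absolute error. That selection rule is an assumption about the intent. As posted, meana only looks at w0…w7 while meanb covers all 64, and the `if meanb > meana` branch returns the binary code when its error is the larger of the two. The codecs are passed in as (encode, decode) pairs so the sketch runs without SoftPosit. `choose_code` and `pad64` are hypothetical names.

```julia
using Statistics

# Zero-pad (or truncate) a decoded tuple to 64 Float64 values.
pad64(t) = ntuple(i -> i <= length(t) ? Float64(t[i]) : 0.0, 64)

# Encode `w` with both codecs and keep the one with the lower mean absolute
# error over all 64 inputs; ties go to the dense code.
function choose_code(w::NTuple{64,Real}, dense, sparse)
    enc_d, dec_d = dense
    enc_s, dec_s = sparse
    cd, cs = enc_d(w), enc_s(w)
    err_d = mean(abs.(pad64(dec_d(cd)) .- w))
    err_s = mean(abs.(pad64(dec_s(cs)) .- w))
    return err_d <= err_s ? ('b', cd) : ('n', cs)
end
```

With the functions from this post, it would be called as `choose_code(w, (encode_64b, decode_64b), (encode_64, decode_64))`.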
function decode_64_full((is_b, tuple))
if is_b == 'b'
decode_64b(tuple)
else
decode_64(tuple)
end
end
using SoftPosit
PInt(n) = reinterpret(UInt8, Posit8(n))
function encode_64((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
(PInt(w0), PInt(w1), PInt(w2), PInt(w3), PInt(w4), PInt(w5), PInt(w6), PInt(w7)) # TODO: Add many extra, UInt8(0))
end
function encode_64b((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
0.5 .<= (w0, -w1, w2, -w3, w4, -w5, w6, -w7, w8, -w9, w10, -w11, w12, -w13, w14, -w15, w16, -w17, w18, -w19, w20, -w21, w22, -w23, w24, -w25, w26, -w27, w28, -w29, w30, -w31, w32, -w33, w34, -w35, w36, -w37, w38, -w39, w40, -w41, w42, -w43, w44, -w45, w46, -w47, w48, -w49, w50, -w51, w52, -w53, w54, -w55, w56, -w57, w58, -w59, w60, -w61, w62, -w63)
end
julia> encode_64b((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63))
(false, true, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false)
The clamping is done; these still need to be combined into one function that returns a UInt64 rather than (different) tuples.
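One way to fuse the clamp and the packing is sketched below. It sets bit i (0-based) when the pseudo-ternary value at position i would be nonzero, using the same 0.5 threshold as encode_64b: w[i] ≥ 0.5 at even positions, -w[i] ≥ 0.5 at odd ones. The name `encode_64b_word` is hypothetical, and “bit i holds value i” is an assumed layout.

```julia
# Clamp 64 weights to pseudo-ternary and pack them into one UInt64:
# bit i (0-based) is set iff the value at position i rounds to +1 (even i) or -1 (odd i).
function encode_64b_word(w::NTuple{64,Real})
    word = zero(UInt64)
    for i in 0:63
        v = iseven(i) ? w[i+1] : -w[i+1]    # flip the sign on odd positions
        if v >= 0.5
            word |= one(UInt64) << i
        end
    end
    return word
end
```

For the example input used in this post, `(0, -0.5, 0.5, 3, 4, …, 63)`, this gives `0x5555555555555556`. That is bits 1, 2, 4, 6, …, 62, matching the `true` entries of the encode_64b output.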
function decode_64b((w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62, w63))
Float64.((w0, -w1, w2, -w3, w4, -w5, w6, -w7, w8, -w9, w10, -w11, w12, -w13, w14, -w15, w16, -w17, w18, -w19, w20, -w21, w22, -w23, w24, -w25, w26, -w27, w28, -w29, w30, -w31, w32, -w33, w34, -w35, w36, -w37, w38, -w39, w40, -w41, w42, -w43, w44, -w45, w46, -w47, w48, -w49, w50, -w51, w52, -w53, w54, -w55, w56, -w57, w58, -w59, w60, -w61, w62, -w63))
end
This might look like a bug, but it is intended, because of the clamping:
julia> decode_64b(encode_64b((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63)))
(0.0, -1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0)
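Regarding the long signatures, encode_64b and decode_64b can be written with `ntuple` over an NTuple argument. The sketch below is not benchmarked, and the `_nt` suffix only avoids clashing with the originals. It keeps the post’s conventions: Julia index 1 is w0, so odd Julia indices are the even positions.

```julia
# Same clamp as encode_64b: w0, w2, ... keep their sign; w1, w3, ... are negated.
encode_64b_nt(w::NTuple{64,Real}) = ntuple(i -> (isodd(i) ? w[i] : -w[i]) >= 0.5, Val(64))

# Same as decode_64b: true -> +1.0 at w0, w2, ...; true -> -1.0 at w1, w3, ...; false -> 0.0.
decode_64b_nt(b::NTuple{64,Bool}) = ntuple(i -> Float64(isodd(i) ? b[i] : -b[i]), Val(64))
```

With `Val(64)` the length is a compile-time constant, which may also help with the type-stability concern mentioned at the end.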
PInt_to_float(n) = Float64(reinterpret(Posit8, n))
function decode_64((w0, w1, w2, w3, w4, w5, w6, dummy))
PInt_to_float.((w0, w1, w2, w3, w4, w5, w6, UInt8(0))) # , UInt8(0))) # Need to pad with more zeros
end
julia> decode_64(encode_64((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63)))
(0.0, -0.5, 0.5, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0)
julia> function decode_64_(n::UInt64) # (w0, w1, w2, w3, w4, w5, w6, dummy))
curr = n >> 8
w6 = PInt_to_float(curr & 255); curr >>= 8
w5 = PInt_to_float(curr & 255); curr >>= 8
w4 = PInt_to_float(curr & 255); curr >>= 8
w3 = PInt_to_float(curr & 255); curr >>= 8
w2 = PInt_to_float(curr & 255); curr >>= 8
w1 = PInt_to_float(curr & 255); curr >>= 8
w0 = PInt_to_float(curr)
(w0, w1, w2, w3, w4, w5, w6)
end
julia> PInt_to_float(n) = Float64(reinterpret(Posit8, n % UInt8))
julia> decode_64_(UInt64(232555454547376583))
(3.814697265625e-6, 0.625, 0.34375, -0.0009765625, 0.15625, -0.0625, 0.5625)
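decode_64_ can also be written as a loop over the seven bytes. In this sketch the byte-to-float conversion is passed in as a function, so it runs without SoftPosit; with the post’s definitions you would pass PInt_to_float. It keeps decode_64_’s layout: the low byte is skipped and w0 comes from the top byte. Each byte is extracted with `% UInt8`, so the converter always receives a UInt8, which is the size `reinterpret(Posit8, …)` needs. `decode_sparse_word` is a hypothetical name.

```julia
# Unpack the 7 payload bytes of a sparse word (bits 8-63), w0 from the top byte,
# converting each byte with `to_float`.
decode_sparse_word(n::UInt64, to_float) =
    ntuple(k -> to_float((n >> (8 * (8 - k))) % UInt8), 7)
```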
@code_native encode_64_full((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63))
is huge (less of a worry; I’ll look into optimizing it, and one reason might be that it’s still type-unstable), but
res = encode_64b((0, -0.5, 0.5, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63))
@code_native decode_64b(res)
is also huge (currently; I’ll look at it first), which is more of a worry.
Hi @elrod, could you look at the (latter) assembly? Does anything stick out, and is it normal to get such long assembly for vectorized code? I haven’t timed it yet; it might be OK. Is there a good way to count instructions, other than counting screenfuls…? It would be nice to be able to pipe it into less and wc.
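For counting, `code_native` takes an IO as its first argument, so its output can be captured as a string and then counted or written to a file for `less`/`wc -l`. In the sketch below, the instruction count is a rough heuristic rather than an exact measure: it skips blank lines, comments, assembler directives and labels. `native_asm` and `count_instructions` are hypothetical names.

```julia
using InteractiveUtils  # code_native (loaded automatically in the REPL)

# Capture native code as a String; debuginfo = :none drops source-line comments.
native_asm(f, types) = sprint(io -> code_native(io, f, types; debuginfo = :none))

# Rough instruction count: skip blank lines, comments (; # //),
# assembler directives (".text", ".cfi_...") and labels ending in ':'.
function count_instructions(asm::AbstractString)
    n = 0
    for line in split(asm, '\n')
        s = strip(line)
        isempty(s) && continue
        any(p -> startswith(s, p), (".", ";", "#", "//")) && continue
        endswith(s, ":") && continue
        n += 1
    end
    return n
end
```

For example, `asm = native_asm(decode_64b, (NTuple{64,Bool},))` followed by `count_instructions(asm)`. Alternatively, `write("decode_64b.s", asm)` and then `less decode_64b.s` or `wc -l decode_64b.s` from a shell.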