diff --git a/coding/coding.pro b/coding/coding.pro
index 38d4355fbf..ef20fd1c4e 100644
--- a/coding/coding.pro
+++ b/coding/coding.pro
@@ -34,6 +34,7 @@ SOURCES += \
     file_name_utils.cpp \
     varint_vector.cpp \
     arithmetic_codec.cpp \
+    compressed_bit_vector.cpp \
 
 HEADERS += \
     internal/xmlparser.hpp \
@@ -97,3 +98,4 @@ HEADERS += \
     matrix_traversal.hpp \
     varint_vector.hpp \
     arithmetic_codec.hpp \
+    compressed_bit_vector.hpp \
diff --git a/coding/coding_tests/coding_tests.pro b/coding/coding_tests/coding_tests.pro
index fbaf306832..978406b722 100644
--- a/coding/coding_tests/coding_tests.pro
+++ b/coding/coding_tests/coding_tests.pro
@@ -47,6 +47,7 @@ SOURCES += ../../testing/testingmain.cpp \
     file_utils_test.cpp \
     varint_vector_test.cpp \
     arithmetic_codec_test.cpp \
+    compressed_bit_vector_test.cpp \
 
 HEADERS += \
     reader_test.hpp \
diff --git a/coding/coding_tests/compressed_bit_vector_test.cpp b/coding/coding_tests/compressed_bit_vector_test.cpp
new file mode 100644
index 0000000000..5611cc90a3
--- /dev/null
+++ b/coding/coding_tests/compressed_bit_vector_test.cpp
@@ -0,0 +1,124 @@
+#include "../compressed_bit_vector.hpp"
+#include "../reader.hpp"
+#include "../writer.hpp"
+
+#include "../../testing/testing.hpp"
+#include "../../base/pseudo_random.hpp"
+
+uint32_t const c_nums_count = 12345;
+
+namespace {
+  uint64_t GetRand64() {  // Glues two 32-bit draws from one static generator into a 64-bit value.
+    static PseudoRNG32 g_rng;
+    uint64_t result = g_rng.Generate();
+    result ^= uint64_t(g_rng.Generate()) << 32;
+    return result;
+  }
+}
+
+UNIT_TEST(CompressedBitVector_Sparse) {  // Round-trips increasing positions with gaps of 1..2^14-1 through all 4 encodings.
+  vector<uint32_t> pos_ones;
+  uint32_t sum = 0;
+  for (uint32_t i = 0; i < c_nums_count; ++i) {
+    uint32_t byte_size = GetRand64() % 2 + 1;
+    uint64_t num = GetRand64() & ((uint64_t(1) << (byte_size * 7)) - 1);
+    if (num == 0) num = 1;
+    sum += num;
+    pos_ones.push_back(sum);
+  }
+  for (uint32_t j = 0; j < 5; ++j) {  // j selects an edge case: as-is, leading 0, empty, {1}, {10}.
+    if (j == 1) pos_ones.insert(pos_ones.begin(), 1, 0);
+    if (j == 2) pos_ones.clear();
+    if (j == 3) pos_ones.push_back(1);
+    if (j == 4) { pos_ones.clear(); pos_ones.push_back(10); }
+    for (uint32_t ienc = 0; ienc < 4; ++ienc) {
+      vector<uint8_t> serial_bit_vector;
+      MemWriter< vector<uint8_t> > writer(serial_bit_vector);
+      BuildCompressedBitVector(writer, pos_ones, ienc);
+      MemReader reader(serial_bit_vector.data(), serial_bit_vector.size());
+      vector<uint32_t> dec_pos_ones = DecodeCompressedBitVector(reader);
+      TEST_EQUAL(pos_ones, dec_pos_ones, ());
+    }
+  }
+}
+
+UNIT_TEST(CompressedBitVector_Dense) {  // Same round-trip, but the input is made of alternating zero runs and one runs.
+  vector<uint32_t> pos_ones;
+  uint32_t prev_pos = 0;  // NOTE(review): never read in this test.
+  uint32_t sum = 0;
+  for (uint32_t i = 0; i < c_nums_count; ++i) {
+    uint32_t zeroes_byte_size = GetRand64() % 2 + 1;
+    uint64_t zeroes_range_size = (GetRand64() & ((uint64_t(1) << (zeroes_byte_size * 7)) - 1)) + 1;
+    sum += zeroes_range_size;
+    uint32_t ones_byte_size = GetRand64() % 1 + 1;  // NOTE(review): "% 1" is always 0, so this is always 1 (one runs of 1..128) - confirm "% 2" was not intended.
+    uint64_t ones_range_size = (GetRand64() & ((uint64_t(1) << (ones_byte_size * 7)) - 1)) + 1;
+    for (uint32_t j = 0; j < ones_range_size; ++j) pos_ones.push_back(sum + j);
+    sum += ones_range_size;
+  }
+  for (uint32_t j = 0; j < 5; ++j) {  // j selects an edge case: as-is, leading 0, empty, {1}, {10}.
+    if (j == 1) pos_ones.insert(pos_ones.begin(), 1, 0);
+    if (j == 2) pos_ones.clear();
+    if (j == 3) pos_ones.push_back(1);
+    if (j == 4) { pos_ones.clear(); pos_ones.push_back(10); }
+    for (uint32_t ienc = 0; ienc < 4; ++ienc) {
+      vector<uint8_t> serial_bit_vector;
+      MemWriter< vector<uint8_t> > writer(serial_bit_vector);
+      BuildCompressedBitVector(writer, pos_ones, ienc);
+      MemReader reader(serial_bit_vector.data(), serial_bit_vector.size());
+      vector<uint32_t> dec_pos_ones = DecodeCompressedBitVector(reader);
+      TEST_EQUAL(pos_ones, dec_pos_ones, ());
+    }
+  }
+}
+
+UNIT_TEST(BitVectors_And) {  // Checks BitVectorsAnd against a per-bit AND of two random vector<bool>.
+  vector<bool> v1(c_nums_count * 2, false), v2(c_nums_count * 2, false);
+  for (uint32_t i = 0; i < c_nums_count; ++i) {
+    v1[GetRand64() % v1.size()] = true;
+    v2[GetRand64() % v2.size()] = true;
+  }
+  vector<uint32_t> pos_ones1, pos_ones2, and_pos;
+  for (uint32_t i = 0; i < v1.size(); ++i) {
+    if (v1[i]) pos_ones1.push_back(i);
+    if (v2[i]) pos_ones2.push_back(i);
+    if (v1[i] && v2[i]) and_pos.push_back(i);
+  }
+  vector<uint32_t> actual_and_pos = BitVectorsAnd(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+  TEST_EQUAL(and_pos, actual_and_pos, ());
+}
+
+UNIT_TEST(BitVectors_Or) {  // Checks BitVectorsOr against a per-bit OR of two random vector<bool>.
+  vector<bool> v1(c_nums_count * 2, false), v2(c_nums_count * 2, false);
+  for (uint32_t i = 0; i < c_nums_count; ++i) {
+    v1[GetRand64() % v1.size()] = true;
+    v2[GetRand64() % v2.size()] = true;
+  }
+  vector<uint32_t> pos_ones1, pos_ones2, or_pos;
+  for (uint32_t i = 0; i < v1.size(); ++i) {
+    if (v1[i]) pos_ones1.push_back(i);
+    if (v2[i]) pos_ones2.push_back(i);
+    if (v1[i] || v2[i]) or_pos.push_back(i);
+  }
+  vector<uint32_t> actual_or_pos = BitVectorsOr(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+  TEST_EQUAL(or_pos, actual_or_pos, ());
+}
+
+UNIT_TEST(BitVectors_SubAnd) {  // v2 has one bit per set bit of v1; the expected result keeps v1's ones whose v2 bit is set.
+  vector<bool> v1(c_nums_count * 2, false);
+  uint64_t num_v1_ones = 0;  // NOTE(review): never read in this test.
+  for (uint32_t i = 0; i < v1.size(); ++i) v1[i] = (GetRand64() % 2) == 0;
+  vector<uint32_t> pos_ones1;
+  for (uint32_t i = 0; i < v1.size(); ++i) if (v1[i]) pos_ones1.push_back(i);
+  vector<bool> v2(pos_ones1.size(), false);
+  for (uint32_t i = 0; i < v2.size(); ++i) v2[i] = (GetRand64() % 2) == 0;
+  vector<uint32_t> pos_ones2, suband_pos;
+  for (uint32_t i = 0; i < v2.size(); ++i) if (v2[i]) pos_ones2.push_back(i);
+  for (uint32_t i = 0, j = 0; i < v1.size(); ++i) {
+    if (v1[i]) {
+      if (v2[j]) suband_pos.push_back(i);
+      ++j;
+    }
+  }
+  vector<uint32_t> actual_suband_pos = BitVectorsSubAnd(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+  TEST_EQUAL(suband_pos, actual_suband_pos, ());
+}
diff --git a/coding/compressed_bit_vector.cpp b/coding/compressed_bit_vector.cpp
new file mode 100644
index 0000000000..fbb1d8a07b
--- /dev/null
+++ b/coding/compressed_bit_vector.cpp
@@ -0,0 +1,552 @@
+#include "compressed_bit_vector.hpp"
+
+#include "arithmetic_codec.hpp"
+#include "reader.hpp"
+#include "writer.hpp"
+
+#include "../base/assert.hpp"
+
+using std::vector;
+
+namespace {
+  void VarintEncode(vector<uint8_t> & dst, uint64_t n) {  // Appends n to dst as 7-bit groups, lowest group first; bit 0x80 means "more bytes follow".
+    if (n == 0) {
+      dst.push_back(0);
+    } else {
+      while (n != 0) {
+        uint8_t b = n & 0x7F;
+        n >>= 7;
+        b |= n == 0 ? 0 : 0x80;
+        dst.push_back(b);
+      }
+    }
+  }
+  void VarintEncode(Writer & writer, uint64_t n) {  // Same byte layout as the overload above, emitted through writer.
+    if (n == 0) {
+      writer.Write(&n, 1);  // n is all-zero here, so its first byte is 0 whatever the byte order.
+    } else {
+      while (n != 0) {
+        uint8_t b = n & 0x7F;
+        n >>= 7;
+        b |= n == 0 ? 0 : 0x80;
+        writer.Write(&b, 1);
+      }
+    }
+  }
+  uint64_t VarintDecode(void * src, uint64_t & offset) {  // Decodes one varint from src at offset and advances offset past it.
+    uint64_t n = 0;
+    int shift = 0;
+    while (1) {
+      uint8_t b = *(((uint8_t*)src) + offset);
+      ASSERT_LESS_OR_EQUAL(shift, 56, ());  // Shifts 0..56 only, i.e. at most 9 groups (63 bits) are accepted.
+      n |= uint64_t(b & 0x7F) << shift;
+      ++offset;
+      if ((b & 0x80) == 0) break;
+      shift += 7;
+    }
+    return n;
+  }
+  uint64_t VarintDecode(Reader & reader, uint64_t & offset) {  // Decodes one varint from reader at offset and advances offset past it.
+    uint64_t n = 0;
+    int shift = 0;
+    while (1) {
+      uint8_t b = 0;
+      reader.Read(offset, &b, 1);
+      ASSERT_LESS_OR_EQUAL(shift, 56, ());  // Shifts 0..56 only, i.e. at most 9 groups (63 bits) are accepted.
+      n |= uint64_t(b & 0x7F) << shift;
+      ++offset;
+      if ((b & 0x80) == 0) break;
+      shift += 7;
+    }
+    return n;
+  }
+
+  inline uint32_t NumUsedBits(uint64_t n) {  // Bit width of n: 0 for n == 0, otherwise index of the highest set bit plus one.
+    uint32_t result = 0;
+    while (n != 0) { ++result; n >>= 1; }
+    return result;
+  }
+  vector<uint32_t> SerialFreqsToDistrTable(Reader & reader, uint64_t & decode_offset, uint64_t cnt) {  // Reads cnt varint frequencies and feeds them to FreqsToDistrTable().
+    vector<uint32_t> freqs;
+    for (uint64_t i = 0; i < cnt; ++i) freqs.push_back(VarintDecode(reader, decode_offset));
+    return FreqsToDistrTable(freqs);
+  }
+}
+
+class BitWriter {  // Packs bit fields into a Writer, low bits first; the trailing partial byte is kept in last_byte_ until Finalize().
+public:
+  BitWriter(Writer & _writer)
+    : writer_(_writer), last_byte_(0), size_(0), total_bits_(0) {}  // total_bits_ was left uninitialised although Write() does "+=" on it.
+  ~BitWriter() { Finalize(); }
+  uint64_t NumBitsWritten() const { return size_; }
+  void Write(uint64_t bits, uint32_t write_size) {  // Appends the low write_size bits of bits; bits above write_size are ignored.
+    if (write_size == 0) return;
+    total_bits_ += write_size;
+    uint32_t rem_size = size_ % 8;
+    ASSERT_LESS_OR_EQUAL(write_size, 64 - rem_size, ());
+    if (rem_size > 0) {
+      bits <<= rem_size;
+      bits |= last_byte_;
+      write_size += rem_size;
+      size_ -= rem_size;
+    }
+    uint32_t write_bytes_size = write_size / 8;
+    writer_.Write(&bits, write_bytes_size);  // NOTE(review): writes the raw low bytes of a uint64_t - presumably assumes a little-endian host.
+    last_byte_ = write_bytes_size < 8 ? (bits >> (write_bytes_size * 8)) & ((1 << (write_size % 8)) - 1) : 0;  // Guard: shifting a uint64_t by 64 is undefined; nothing is left over in that case.
+    size_ += write_size;
+  }
+  void Finalize() { if (size_ % 8 > 0) writer_.Write(&last_byte_, 1); }
+private:
+  Writer & writer_;
+  uint8_t last_byte_;
+  uint64_t size_;
+  uint64_t total_bits_;
+};
+
+class BitReader {  // Reads back bit fields in the layout produced by BitWriter, buffering up to 8 bytes from the Reader at a time.
+public:
+  BitReader(Reader & reader)
+    : reader_(reader), serial_cur_(0), serial_end_(reader.Size()),
+      bits_(0), bits_size_(0), total_bits_read_(0) {}
+  uint64_t NumBitsWritten() const { return total_bits_read_; }  // Despite the name this returns the number of bits requested from Read() so far.
+  uint64_t Read(uint32_t read_size) {  // Returns the next read_size (0..64) bits, first-written bit in bit 0.
+    total_bits_read_ += read_size;
+    if (read_size == 0) return 0;
+    ASSERT_LESS_OR_EQUAL(read_size, 64, ());
+    // First read, sets bits that are in the bits_ buffer.
+    uint32_t first_read_size = read_size <= bits_size_ ? read_size : bits_size_;
+    uint64_t result = first_read_size == 0 ? 0 : bits_ & (~uint64_t(0) >> (64 - first_read_size));  // Guard: with an empty buffer the mask shift would be by 64, which is undefined.
+    bits_ = first_read_size < 64 ? bits_ >> first_read_size : 0;  // Same guard for consuming a full 64-bit buffer.
+    bits_size_ -= first_read_size;
+    read_size -= first_read_size;
+    // Second read, does an extra read using reader_.
+    if (read_size > 0) {
+      uint32_t read_byte_size = serial_cur_ + sizeof(bits_) <= serial_end_ ? sizeof(bits_) : serial_end_ - serial_cur_;
+      reader_.Read(serial_cur_, &bits_, read_byte_size);
+      serial_cur_ += read_byte_size;
+      bits_size_ += read_byte_size * 8;
+      if (read_size > bits_size_) ASSERT_LESS_OR_EQUAL(read_size, bits_size_, ());
+      result |= (bits_ & (~uint64_t(0) >> (64 - read_size))) << first_read_size;
+      bits_ = read_size < 64 ? bits_ >> read_size : 0;  // Guard against an undefined shift by 64.
+      bits_size_ -= read_size;
+      read_size = 0;
+    }
+    return result;
+  }
+private:
+  Reader & reader_;
+  uint64_t serial_cur_;
+  uint64_t serial_end_;
+  uint64_t bits_;
+  uint32_t bits_size_;
+  uint64_t total_bits_read_;
+};
+
+void BuildCompressedBitVector(Writer & writer, vector<uint32_t> const & pos_ones, int chosen_enc_type) {  // chosen_enc_type == -1 picks the encoding with the smallest size estimate; 0..3 forces one.
+  uint32_t const c_block_size = 7;
+  // First stage of compression is analysis run through data ones.
+  uint64_t num_bytes_diffs_enc_vint = 0, num_bytes_ranges_enc_vint = 0, num_bits_diffs_enc_arith = 0, num_bits_ranges_enc_arith = 0;
+  int64_t prev_one_pos = -1;
+  uint64_t ones_range_len = 0;
+  vector<uint32_t> diffs_sizes_freqs(65, 0), ranges0_sizes_freqs(65, 0), ranges1_sizes_freqs(65, 0);
+  for (uint32_t i = 0; i < pos_ones.size(); ++i) {
+    ASSERT_LESS(prev_one_pos, pos_ones[i], ());
+    // Accumulate size of diff encoding.
+    uint64_t diff = pos_ones[i] - prev_one_pos;
+    uint32_t diff_bitsize = NumUsedBits(diff - 1);
+    num_bytes_diffs_enc_vint += (diff_bitsize + c_block_size - 1) / c_block_size;
+    num_bits_diffs_enc_arith += diff_bitsize > 0 ? diff_bitsize - 1 : 0;
+    ++diffs_sizes_freqs[diff_bitsize];
+    // Accumulate sizes of ranges encoding.
+    if (pos_ones[i] - prev_one_pos > 1) {
+      if (ones_range_len > 0) {
+        // Accumulate size of ones-range encoding.
+        uint32_t ones_range_len_bitsize = NumUsedBits(ones_range_len - 1);
+        num_bytes_ranges_enc_vint += (ones_range_len_bitsize + c_block_size - 1) / c_block_size;
+        num_bits_ranges_enc_arith += ones_range_len_bitsize > 0 ? ones_range_len_bitsize - 1 : 0;
+        ++ranges1_sizes_freqs[ones_range_len_bitsize];
+        ones_range_len = 0;
+      }
+      // Accumulate size of zeros-range encoding.
+      uint32_t zeros_range_len_bitsize = NumUsedBits(pos_ones[i] - prev_one_pos - 2);
+      num_bytes_ranges_enc_vint += (zeros_range_len_bitsize + c_block_size - 1) / c_block_size;
+      num_bits_ranges_enc_arith += zeros_range_len_bitsize > 0 ? zeros_range_len_bitsize - 1 : 0;
+      ++ranges0_sizes_freqs[zeros_range_len_bitsize];
+    }
+    ++ones_range_len;
+    prev_one_pos = pos_ones[i];
+  }
+  // Accumulate size of remaining ones-range encoding.
+  if (ones_range_len > 0) {
+    uint32_t ones_range_len_bitsize = NumUsedBits(ones_range_len - 1);
+    num_bytes_ranges_enc_vint += (ones_range_len_bitsize + c_block_size - 1) / c_block_size;
+    num_bits_ranges_enc_arith += ones_range_len_bitsize > 0 ? ones_range_len_bitsize - 1 : 0;  // Was "=", which threw away everything accumulated in the loop above.
+    ++ranges1_sizes_freqs[ones_range_len_bitsize];
+    ones_range_len = 0;
+  }
+  // Compute arithmetic encoding size.
+  uint64_t diffs_sizes_total_freq = 0, ranges0_sizes_total_freq = 0, ranges1_sizes_total_freq = 0;
+  for (uint32_t i = 0; i < diffs_sizes_freqs.size(); ++i) diffs_sizes_total_freq += diffs_sizes_freqs[i];
+  for (uint32_t i = 0; i < ranges0_sizes_freqs.size(); ++i) ranges0_sizes_total_freq += ranges0_sizes_freqs[i];
+  for (uint32_t i = 0; i < ranges1_sizes_freqs.size(); ++i) ranges1_sizes_total_freq += ranges1_sizes_freqs[i];
+  // Compute number of bits for arith encoded diffs sizes.
+  double num_sizes_bits_diffs_enc_arith = 0;
+  uint32_t nonzero_diffs_sizes_freqs_end = 0;
+  for (uint32_t i = 0; i < diffs_sizes_freqs.size(); ++i) {
+    if (diffs_sizes_freqs[i] > 0) {
+      double prob = double(diffs_sizes_freqs[i]) / diffs_sizes_total_freq;
+      num_sizes_bits_diffs_enc_arith += - prob * log(prob) / log(2);
+      nonzero_diffs_sizes_freqs_end = i + 1;
+    }
+  }
+  vector<uint8_t> diffs_sizes_freqs_serial;
+  for (uint32_t i = 0; i < nonzero_diffs_sizes_freqs_end; ++i) VarintEncode(diffs_sizes_freqs_serial, diffs_sizes_freqs[i]);
+  uint64_t num_bytes_diffs_enc_arith = 4 + diffs_sizes_freqs_serial.size() + (uint64_t(num_sizes_bits_diffs_enc_arith * diffs_sizes_total_freq + 0.999) + 7) / 8 + (num_bits_diffs_enc_arith + 7) /8;
+  // Compute number of bits for arith encoded ranges sizes.
+  double num_sizes_bits_ranges0_enc_arith = 0;
+  uint32_t nonzero_ranges0_sizes_freqs_end = 0;
+  for (uint32_t i = 0; i < ranges0_sizes_freqs.size(); ++i) {
+    if (ranges0_sizes_freqs[i] > 0) {
+      double prob = double(ranges0_sizes_freqs[i]) / ranges0_sizes_total_freq;
+      num_sizes_bits_ranges0_enc_arith += - prob * log(prob) / log(2);
+      nonzero_ranges0_sizes_freqs_end = i + 1;
+    }
+  }
+  double num_sizes_bits_ranges1_enc_arith = 0;
+  uint32_t nonzero_ranges1_sizes_freqs_end = 0;
+  for (uint32_t i = 0; i < ranges1_sizes_freqs.size(); ++i) {
+    if (ranges1_sizes_freqs[i] > 0) {
+      double prob = double(ranges1_sizes_freqs[i]) / ranges1_sizes_total_freq;
+      num_sizes_bits_ranges1_enc_arith += - prob * log(prob) / log(2);
+      nonzero_ranges1_sizes_freqs_end = i + 1;
+    }
+  }
+  vector<uint8_t> ranges0_sizes_freqs_serial, ranges1_sizes_freqs_serial;
+  for (uint32_t i = 0; i < nonzero_ranges0_sizes_freqs_end; ++i) VarintEncode(ranges0_sizes_freqs_serial, ranges0_sizes_freqs[i]);
+  for (uint32_t i = 0; i < nonzero_ranges1_sizes_freqs_end; ++i) VarintEncode(ranges1_sizes_freqs_serial, ranges1_sizes_freqs[i]);
+  uint64_t num_bytes_ranges_enc_arith = 4 + ranges0_sizes_freqs_serial.size() + ranges1_sizes_freqs_serial.size() +
+    (uint64_t(num_sizes_bits_ranges0_enc_arith * ranges0_sizes_total_freq + 0.999) + 7) / 8 + (uint64_t(num_sizes_bits_ranges1_enc_arith * ranges1_sizes_total_freq + 0.999) + 7) / 8 +
+    (num_bits_ranges_enc_arith + 7) / 8;
+
+  // Find minimum among 4 types of encoding.
+  vector<uint64_t> num_bytes_per_enc = {num_bytes_diffs_enc_vint, num_bytes_ranges_enc_vint, num_bytes_diffs_enc_arith, num_bytes_ranges_enc_arith};
+  uint32_t enc_type = 0;
+  if (chosen_enc_type != -1) { ASSERT(0 <= chosen_enc_type && chosen_enc_type <= 3, ()); enc_type = chosen_enc_type; }
+  else if (num_bytes_per_enc[0] <= num_bytes_per_enc[1] && num_bytes_per_enc[0] <= num_bytes_per_enc[2] && num_bytes_per_enc[0] <= num_bytes_per_enc[3]) enc_type = 0;
+  else if (num_bytes_per_enc[1] <= num_bytes_per_enc[0] && num_bytes_per_enc[1] <= num_bytes_per_enc[2] && num_bytes_per_enc[1] <= num_bytes_per_enc[3]) enc_type = 1;
+  else if (num_bytes_per_enc[2] <= num_bytes_per_enc[0] && num_bytes_per_enc[2] <= num_bytes_per_enc[1] && num_bytes_per_enc[2] <= num_bytes_per_enc[3]) enc_type = 2;
+  else if (num_bytes_per_enc[3] <= num_bytes_per_enc[0] && num_bytes_per_enc[3] <= num_bytes_per_enc[1] && num_bytes_per_enc[3] <= num_bytes_per_enc[2]) enc_type = 3;
+
+  if (enc_type == 0) {
+    // Diffs-Varint encoding.
+
+    int64_t prev_one_pos = -1;
+    bool is_empty = pos_ones.empty();
+    // Encode encoding type and first diff.
+    if (is_empty) {
+      VarintEncode(writer, enc_type + (1 << 2));
+    } else {
+      VarintEncode(writer, enc_type + (0 << 2) + ((pos_ones[0] - prev_one_pos - 1) << 3));
+      prev_one_pos = pos_ones[0];
+    }
+    for (uint32_t i = 1; i < pos_ones.size(); ++i) {
+      ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+      // Encode one's pos (diff - 1).
+      VarintEncode(writer, pos_ones[i] - prev_one_pos - 1);
+      prev_one_pos = pos_ones[i];
+    }
+  } else if (enc_type == 2) {
+    // Diffs-Arith encoding.
+
+    // Encode encoding type plus number of freqs in the table.
+    VarintEncode(writer, enc_type + (nonzero_diffs_sizes_freqs_end << 2));
+    // Encode freqs table.
+    writer.Write(diffs_sizes_freqs_serial.data(), diffs_sizes_freqs_serial.size());
+    uint64_t tmp_offset = 0;
+    MemReader diffs_sizes_freqs_serial_reader(diffs_sizes_freqs_serial.data(), diffs_sizes_freqs_serial.size());
+    vector<uint32_t> distr_table = SerialFreqsToDistrTable(
+      diffs_sizes_freqs_serial_reader, tmp_offset, nonzero_diffs_sizes_freqs_end
+    );
+
+    {
+      // First stage. Encode all bits sizes of all diffs using ArithmeticEncoder.
+      ArithmeticEncoder arith_enc(distr_table);
+      int64_t prev_one_pos = -1;
+      uint64_t cnt_elements = 0;
+      for (uint64_t i = 0; i < pos_ones.size(); ++i) {
+        ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+        uint32_t bits_used = NumUsedBits(pos_ones[i] - prev_one_pos - 1);
+        arith_enc.Encode(bits_used);
+        ++cnt_elements;
+        prev_one_pos = pos_ones[i];
+      }
+      vector<uint8_t> serial_sizes_enc = arith_enc.Finalize();
+      // Store number of compressed elements.
+      VarintEncode(writer, cnt_elements);
+      // Store compressed size of encoded sizes.
+      VarintEncode(writer, serial_sizes_enc.size());
+      // Store serial sizes.
+      writer.Write(serial_sizes_enc.data(), serial_sizes_enc.size());
+    }
+    {
+      // Second Stage. Encode all bits of all diffs using BitWriter.
+      BitWriter bit_writer(writer);
+      int64_t prev_one_pos = -1;
+      uint64_t total_read_bits = 0;
+      uint64_t total_read_cnts = 0;
+      for (uint64_t i = 0; i < pos_ones.size(); ++i) {
+        ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+        // Encode one's pos (diff - 1).
+        uint64_t diff = pos_ones[i] - prev_one_pos - 1;
+        uint32_t bits_used = NumUsedBits(diff);
+        if (bits_used > 1) {
+          // Most significant bit is always 1 for non-zero diffs, so don't store it.
+          --bits_used;
+          bit_writer.Write(diff, bits_used);
+          total_read_bits += bits_used;
+          ++total_read_cnts;
+        }
+        prev_one_pos = pos_ones[i];
+      }
+    }
+  } else if (enc_type == 1) {
+    // Ranges-Varint encoding.
+
+    // If bit vector starts with 1.
+    bool is_first_one = pos_ones.size() > 0 && pos_ones.front() == 0;
+    // Encode encoding type plus flag if first is 1.
+    VarintEncode(writer, enc_type + ((is_first_one ? 1 : 0) << 2));
+    int64_t prev_one_pos = -1;
+    uint64_t ones_range_len = 0;
+    for (uint32_t i = 0; i < pos_ones.size(); ++i) {
+      ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+      if (pos_ones[i] - prev_one_pos > 1) {
+        if (ones_range_len > 0) {
+          // Encode ones range size - 1.
+          VarintEncode(writer, ones_range_len - 1);
+          ones_range_len = 0;
+        }
+        // Encode zeros range size - 1.
+        VarintEncode(writer, pos_ones[i] - prev_one_pos - 2);
+      }
+      ++ones_range_len;
+      prev_one_pos = pos_ones[i];
+    }
+    if (ones_range_len > 0) {
+      // Encode last ones range size.
+      VarintEncode(writer, ones_range_len - 1);
+      ones_range_len = 0;
+    }
+  } else if (enc_type == 3) {
+    // Ranges-Arith encoding.
+
+    // If bit vector starts with 1.
+    bool is_first_one = pos_ones.size() > 0 && pos_ones.front() == 0;
+    // Encode encoding type plus flag if first is 1 plus count of sizes freqs.
+    VarintEncode(writer, enc_type + ((is_first_one ? 1 : 0) << 2) + (nonzero_ranges0_sizes_freqs_end << 3));
+    VarintEncode(writer, nonzero_ranges1_sizes_freqs_end);
+    // Encode freqs table.
+    writer.Write(ranges0_sizes_freqs_serial.data(), ranges0_sizes_freqs_serial.size());
+    writer.Write(ranges1_sizes_freqs_serial.data(), ranges1_sizes_freqs_serial.size());
+    // Create distr tables.
+    uint64_t tmp_offset = 0;
+    MemReader ranges0_sizes_freqs_serial_reader(ranges0_sizes_freqs_serial.data(), ranges0_sizes_freqs_serial.size());
+    vector<uint32_t> distr_table0 = SerialFreqsToDistrTable(
+      ranges0_sizes_freqs_serial_reader, tmp_offset, nonzero_ranges0_sizes_freqs_end
+    );
+    tmp_offset = 0;
+    MemReader ranges1_sizes_freqs_serial_reader(ranges1_sizes_freqs_serial.data(), ranges1_sizes_freqs_serial.size());
+    vector<uint32_t> distr_table1 = SerialFreqsToDistrTable(
+      ranges1_sizes_freqs_serial_reader, tmp_offset, nonzero_ranges1_sizes_freqs_end
+    );
+
+    {
+      // First stage, encode all ranges bits sizes using ArithmeticEncoder.
+
+      // Encode number of compressed elements.
+      ArithmeticEncoder arith_enc0(distr_table0), arith_enc1(distr_table1);
+      int64_t prev_one_pos = -1;
+      uint64_t ones_range_len = 0;
+      // Total number of compressed elements (ranges sizes).
+      uint64_t cnt_elements0 = 0, cnt_elements1 = 0;
+      for (uint32_t i = 0; i < pos_ones.size(); ++i) {
+        ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+        if (pos_ones[i] - prev_one_pos > 1) {
+          if (ones_range_len > 0) {
+            // Encode ones range bits size.
+            uint32_t bits_used = NumUsedBits(ones_range_len - 1);
+            arith_enc1.Encode(bits_used);
+            ++cnt_elements1;
+            ones_range_len = 0;
+          }
+          // Encode zeros range bits size - 1.
+          uint32_t bits_used = NumUsedBits(pos_ones[i] - prev_one_pos - 2);
+          arith_enc0.Encode(bits_used);
+          ++cnt_elements0;
+        }
+        ++ones_range_len;
+        prev_one_pos = pos_ones[i];
+      }
+      if (ones_range_len > 0) {
+        // Encode last ones range size - 1.
+        uint32_t bits_used = NumUsedBits(ones_range_len - 1);
+        arith_enc1.Encode(bits_used);
+        ++cnt_elements1;
+        ones_range_len = 0;
+      }
+      vector<uint8_t> serial0_sizes_enc = arith_enc0.Finalize(), serial1_sizes_enc = arith_enc1.Finalize();
+      // Store number of compressed elements.
+      VarintEncode(writer, cnt_elements0);
+      VarintEncode(writer, cnt_elements1);
+      // Store size of encoded bits sizes.
+      VarintEncode(writer, serial0_sizes_enc.size());
+      VarintEncode(writer, serial1_sizes_enc.size());
+      // Store serial sizes.
+      writer.Write(serial0_sizes_enc.data(), serial0_sizes_enc.size());
+      writer.Write(serial1_sizes_enc.data(), serial1_sizes_enc.size());
+    }
+
+    {
+      // Second stage, encode all ranges bits using BitWriter.
+      BitWriter bit_writer(writer);
+      int64_t prev_one_pos = -1;
+      uint64_t ones_range_len = 0;
+      for (uint32_t i = 0; i < pos_ones.size(); ++i) {
+        ASSERT_GREATER(pos_ones[i], prev_one_pos, ());
+        if (pos_ones[i] - prev_one_pos > 1) {
+          if (ones_range_len > 0) {
+            // Encode ones range bits size.
+            uint32_t bits_used = NumUsedBits(ones_range_len - 1);
+            if (bits_used > 1) {
+              // Most significant bit for non-zero values is always 1, don't encode it.
+              --bits_used;
+              bit_writer.Write(ones_range_len - 1, bits_used);
+            }
+            ones_range_len = 0;
+          }
+          // Encode zeros range bits size - 1.
+          uint32_t bits_used = NumUsedBits(pos_ones[i] - prev_one_pos - 2);
+          if (bits_used > 1) {
+            // Most significant bit for non-zero values is always 1, don't encode it.
+            --bits_used;
+            bit_writer.Write(pos_ones[i] - prev_one_pos - 2, bits_used);
+          }
+        }
+        ++ones_range_len;
+        prev_one_pos = pos_ones[i];
+      }
+      if (ones_range_len > 0) {
+        // Encode last ones range size - 1.
+        uint32_t bits_used = NumUsedBits(ones_range_len - 1);
+        if (bits_used > 1) {
+          // Most significant bit for non-zero values is always 1, don't encode it.
+          --bits_used;
+          bit_writer.Write(ones_range_len - 1, bits_used);
+        }
+        ones_range_len = 0;
+      }
+    }
+  }
+}
+
+vector<uint32_t> DecodeCompressedBitVector(Reader & reader) {  // Inverse of BuildCompressedBitVector: the low 2 bits of the first varint select the encoding.
+  uint64_t serial_size = reader.Size();
+  vector<uint32_t> pos_ones;
+  uint64_t decode_offset = 0;
+  uint64_t header = VarintDecode(reader, decode_offset);
+  uint32_t enc_type = header & 3;
+  ASSERT_LESS(enc_type, 4, ());
+  if (enc_type == 0) {
+    // Diffs-Varint encoded.
+    int64_t prev_one_pos = -1;
+    // For non-empty vectors first diff is taken from header number.
+    bool is_empty = (header & 4) != 0;
+    if (!is_empty) {
+      pos_ones.push_back(header >> 3);
+      prev_one_pos = pos_ones.back();
+    }
+    while (decode_offset < serial_size) {
+      pos_ones.push_back(prev_one_pos + VarintDecode(reader, decode_offset) + 1);
+      prev_one_pos = pos_ones.back();
+    }
+  } else if (enc_type == 2) {
+    // Diffs-Arith encoded.
+    uint64_t freqs_cnt = header >> 2;
+    vector<uint32_t> distr_table = SerialFreqsToDistrTable(reader, decode_offset, freqs_cnt);
+    uint64_t cnt_elements = VarintDecode(reader, decode_offset);
+    uint64_t enc_sizes_bytesize = VarintDecode(reader, decode_offset);
+    vector<uint32_t> bits_used_vec;
+    Reader * arith_dec_reader = reader.CreateSubReader(decode_offset, enc_sizes_bytesize);  // NOTE(review): never released in this function - looks like a leak if the caller owns the result; confirm.
+    ArithmeticDecoder arith_dec(*arith_dec_reader, distr_table);
+    for (uint64_t i = 0; i < cnt_elements; ++i) bits_used_vec.push_back(arith_dec.Decode());
+    decode_offset += enc_sizes_bytesize;
+    Reader * bit_reader_reader = reader.CreateSubReader(decode_offset, serial_size - decode_offset);  // NOTE(review): same ownership question as above.
+    BitReader bit_reader(*bit_reader_reader);
+    int64_t prev_one_pos = -1;
+    for (uint64_t i = 0; i < cnt_elements; ++i) {
+      uint32_t bits_used = bits_used_vec[i];
+      uint64_t diff = 0;
+      if (bits_used > 0) diff = ((uint64_t(1) << (bits_used - 1)) | bit_reader.Read(bits_used - 1)) + 1; else diff = 1;
+      pos_ones.push_back(prev_one_pos + diff);
+      prev_one_pos += diff;
+    }
+    decode_offset = serial_size;
+  } else if (enc_type == 1) {
+    // Ranges-Varint encoding.
+
+    // If bit vector starts with 1.
+    bool is_first_one = ((header >> 2) & 1) == 1;
+    uint64_t sum = 0;
+    while (decode_offset < serial_size) {
+      uint64_t zeros_range_size = 0;
+      // Don't read zero range size for the first time if first bit is 1.
+      if (!is_first_one) zeros_range_size = VarintDecode(reader, decode_offset) + 1; else is_first_one = false;
+      uint64_t ones_range_size = VarintDecode(reader, decode_offset) + 1;
+      sum += zeros_range_size;
+      for (uint64_t i = sum; i < sum + ones_range_size; ++i) pos_ones.push_back(i);
+      sum += ones_range_size;
+    }
+  } else if (enc_type == 3) {
+    // Ranges-Arith encoding.
+
+    // If bit vector starts with 1.
+    bool is_first_one = ((header >> 2) & 1) == 1;
+    uint64_t freqs0_cnt = header >> 3, freqs1_cnt = VarintDecode(reader, decode_offset);
+    vector<uint32_t> distr_table0 = SerialFreqsToDistrTable(reader, decode_offset, freqs0_cnt);
+    vector<uint32_t> distr_table1 = SerialFreqsToDistrTable(reader, decode_offset, freqs1_cnt);
+    uint64_t cnt_elements0 = VarintDecode(reader, decode_offset), cnt_elements1 = VarintDecode(reader, decode_offset);
+    uint64_t enc0_sizes_bytesize = VarintDecode(reader, decode_offset), enc1_sizes_bytesize = VarintDecode(reader, decode_offset);
+    Reader * arith_dec0_reader = reader.CreateSubReader(decode_offset, enc0_sizes_bytesize);  // NOTE(review): never released in this function - looks like a leak if the caller owns the result; confirm.
+    ArithmeticDecoder arith_dec0(*arith_dec0_reader, distr_table0);
+    vector<uint32_t> bits_sizes0;
+    for (uint64_t i = 0; i < cnt_elements0; ++i) bits_sizes0.push_back(arith_dec0.Decode());
+    decode_offset += enc0_sizes_bytesize;
+    Reader * arith_dec1_reader = reader.CreateSubReader(decode_offset, enc1_sizes_bytesize);  // NOTE(review): same ownership question as above.
+    ArithmeticDecoder arith_dec1(*arith_dec1_reader, distr_table1);
+    vector<uint32_t> bits_sizes1;
+    for (uint64_t i = 0; i < cnt_elements1; ++i) bits_sizes1.push_back(arith_dec1.Decode());
+    decode_offset += enc1_sizes_bytesize;
+    Reader * bit_reader_reader = reader.CreateSubReader(decode_offset, serial_size - decode_offset);  // NOTE(review): same ownership question as above.
+    BitReader bit_reader(*bit_reader_reader);
+    uint64_t sum = 0, i0 = 0, i1 = 0;
+    while (i1 < cnt_elements1) {  // Driven by ones-ranges only: a single run starting at bit 0 has no zeros-range, and "i0 < cnt_elements0 &&" made it decode to nothing.
+      uint64_t zeros_range_size = 0;
+      // Don't read zero range size for the first time if first bit is 1.
+      if (!is_first_one) {
+        uint32_t bits_used = bits_sizes0[i0];
+        if (bits_used > 0) zeros_range_size = ((uint64_t(1) << (bits_used - 1)) | bit_reader.Read(bits_used - 1)) + 1; else zeros_range_size = 1;
+        ++i0;
+      } else is_first_one = false;
+      uint64_t ones_range_size = 0;
+      uint32_t bits_used = bits_sizes1[i1];
+      if (bits_used > 0) ones_range_size = ((uint64_t(1) << (bits_used - 1)) | bit_reader.Read(bits_used - 1)) + 1; else ones_range_size = 1;
+      ++i1;
+      sum += zeros_range_size;
+      for (uint64_t j = sum; j < sum + ones_range_size; ++j) pos_ones.push_back(j);
+      sum += ones_range_size;
+    }
+    ASSERT(i0 == cnt_elements0 && i1 == cnt_elements1, ());
+    decode_offset = serial_size;
+  }
+  return pos_ones;
+}
diff --git a/coding/compressed_bit_vector.hpp b/coding/compressed_bit_vector.hpp
new file mode 100644
index 0000000000..049b87c4ac
--- /dev/null
+++ b/coding/compressed_bit_vector.hpp
@@ -0,0 +1,120 @@
+// Author: Artyom.
+// Module for compressing/decompressing bit vectors.
+// Usage:
+//   vector<uint8_t> compr_bits1;
+//   MemWriter< vector<uint8_t> > writer(compr_bits1);
+//   // Create a bit vector by storing increasing positions of ones.
+//   vector<uint32_t> pos_ones1 = {12, 34, 75}, pos_ones2 = {10, 34, 95};
+//   // Compress some vectors.
+//   BuildCompressedBitVector(writer, pos_ones1);
+//   // Read the compressed bytes back through a MemReader.
+//   // Decompress compressed vectors before operations.
+//   MemReader reader(compr_bits1.data(), compr_bits1.size());
+//   pos_ones1 = DecodeCompressedBitVector(reader);
+//   // Intersect two vectors.
+//   vector<uint32_t> and_res = BitVectorsAnd(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+//   // Unite two vectors.
+//   vector<uint32_t> or_res = BitVectorsOr(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+//   // Sub-and two vectors (second vector-set is a subset of first vector-set as bit vectors,
+//   // so that second vector size should be equal to number of ones of the first vector).
+//   vector<uint32_t> suband_res = BitVectorsSubAnd(pos_ones1.begin(), pos_ones1.end(), pos_ones2.begin(), pos_ones2.end());
+
+#pragma once
+
+#include "../base/assert.hpp"
+#include "../std/stdint.hpp"
+#include "../std/vector.hpp"
+
+// Forward declare used Reader/Writer.
+class Reader;
+class Writer;
+
+// Build compressed bit vector from vector of ones bits positions, you may provide chosen_enc_type - encoding
+// type of the result, otherwise encoding type is chosen to achieve maximum compression.
+// Encoding types are: 0 - Diffs/Varint, 1 - Ranges/Varint, 2 - Diffs/Arith, 3 - Ranges/Arith.
+// ("Diffs" creates a compressed array of pos diffs between ones inside source bit vector,
+//  "Ranges" creates a compressed array of lengths of zeros and ones ranges,
+//  "Varint" encodes resulting sizes using varint encoding,
+//  "Arith" encodes resulting sizes using arithmetic encoding).
+void BuildCompressedBitVector(Writer & writer, std::vector<uint32_t> const & pos_ones, int chosen_enc_type = -1);
+// Decodes compressed bit vector to uncompressed array of ones positions.
+std::vector<uint32_t> DecodeCompressedBitVector(Reader & reader);
+
+// Intersects two bit vectors given by begin/end iterators over their positions of ones.
+// Both ranges must be in increasing order (the merge walk relies on it). Returns the common positions.
+template <typename It1T, typename It2T>
+std::vector<uint32_t> BitVectorsAnd(It1T begin1, It1T end1, It2T begin2, It2T end2) {
+  std::vector<uint32_t> common;
+  It1T lhs = begin1;
+  It2T rhs = begin2;
+  // Step whichever side is behind; stop as soon as either range runs out.
+  while (lhs != end1 && rhs != end2) {
+    uint32_t left = *lhs, right = *rhs;
+    if (left < right) { ++lhs; continue; }
+    if (left > right) { ++rhs; continue; }
+    // Equal positions: the bit is set in both vectors.
+    common.push_back(left);
+    ++lhs;
+    ++rhs;
+  }
+  return common;
+}
+
+// Unites two bit vectors based on their begin and end iterators.
+// Returns resulting positions of ones.
+template <typename It1T, typename It2T>
+std::vector<uint32_t> BitVectorsOr(It1T begin1, It1T end1, It2T begin2, It2T end2) {
+  std::vector<uint32_t> result;
+  // Both ranges hold positions of ones in increasing order (the merge below relies on it).
+  It1T it1 = begin1;
+  It2T it2 = begin2;
+  while (it1 != end1 && it2 != end2) {
+    uint32_t pos1 = *it1, pos2 = *it2;
+    if (pos1 == pos2) {
+      // Present in both inputs: emit once and advance both sides.
+      result.push_back(pos1);
+      ++it1;
+      ++it2;
+    } else if (pos1 < pos2) {
+      result.push_back(pos1);
+      ++it1;
+    } else {
+      result.push_back(pos2);
+      ++it2;
+    }
+  }
+  // The loop above ends only when at least one range is exhausted,
+  // so at most one of the two loops below does any work:
+  // it appends the unmerged tail of the other range.
+  for (; it1 != end1; ++it1) {
+    uint32_t pos1 = *it1;
+    result.push_back(pos1);
+  }
+  // Tail of the second range.
+  for (; it2 != end2; ++it2) {
+    uint32_t pos2 = *it2;
+    result.push_back(pos2);
+  }
+  return result;
+}
+
+// Intersects first vector with second vector, when second vector is a subset of the first vector,
+// second bit vector should have size equal to first vector's number of ones.
+// Returns resulting positions of ones.
+template <typename It1T, typename It2T>
+std::vector<uint32_t> BitVectorsSubAnd(It1T begin1, It1T end1, It2T begin2, It2T end2) {
+  std::vector<uint32_t> result;
+  // index2 is the rank of *it1 among the first vector's ones; the second range is compared against that rank.
+  It1T it1 = begin1;
+  It2T it2 = begin2;
+  uint64_t index2 = 0;
+  for (; it1 != end1 && it2 != end2; ++it1, ++index2) {
+    uint32_t pos1 = *it1, pos2 = *it2;
+    if (pos2 == index2) {
+      result.push_back(pos1);
+      ++it2;
+    }
+  }
+  ASSERT((it2 == end2), ());
+  return result;
+}