diff --git a/base/bits.hpp b/base/bits.hpp index cbae1baa4f..6b40d2406d 100644 --- a/base/bits.hpp +++ b/base/bits.hpp @@ -63,9 +63,13 @@ namespace bits inline uint32_t PopCount(uint64_t x) { - uint32_t lower = static_cast(x); - uint32_t higher = static_cast(x >> 32); - return PopCount(lower) + PopCount(higher); + x = (x & 0x5555555555555555) + ((x & 0xAAAAAAAAAAAAAAAA) >> 1); + x = (x & 0x3333333333333333) + ((x & 0xCCCCCCCCCCCCCCCC) >> 2); + x = (x & 0x0F0F0F0F0F0F0F0F) + ((x & 0xF0F0F0F0F0F0F0F0) >> 4); + x = (x & 0x00FF00FF00FF00FF) + ((x & 0xFF00FF00FF00FF00) >> 8); + x = (x & 0x0000FFFF0000FFFF) + ((x & 0xFFFF0000FFFF0000) >> 16); + x = x + (x >> 32); + return static_cast(x); } // Will be implemented when needed. diff --git a/coding/coding_tests/compressed_bit_vector_test.cpp b/coding/coding_tests/compressed_bit_vector_test.cpp index eae105ef79..74d9f9bb55 100644 --- a/coding/coding_tests/compressed_bit_vector_test.cpp +++ b/coding/coding_tests/compressed_bit_vector_test.cpp @@ -5,6 +5,7 @@ #include "std/algorithm.hpp" #include "std/iterator.hpp" +#include "std/set.hpp" namespace { @@ -18,27 +19,28 @@ void CheckIntersection(vector & setBits1, vector & setBits2, set_intersection(setBits1.begin(), setBits1.end(), setBits2.begin(), setBits2.end(), back_inserter(expected)); TEST_EQUAL(expected.size(), cbv->PopCount(), ()); + vector expectedBitmap(expected.back() + 1); for (size_t i = 0; i < expected.size(); ++i) - TEST(cbv->GetBit(expected[i]), ()); + expectedBitmap[expected[i]] = true; + for (size_t i = 0; i < expectedBitmap.size(); ++i) + TEST_EQUAL(cbv->GetBit(i), expectedBitmap[i], ()); } } // namespace -UNIT_TEST(CompressedBitVector_Smoke) {} - UNIT_TEST(CompressedBitVector_Intersect1) { - size_t const n = 100; + size_t const kNumBits = 100; vector setBits1; vector setBits2; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) { if (i > 0) setBits1.push_back(i); - if (i + 1 < n) + if (i + 1 < kNumBits) setBits2.push_back(i); } - auto cbv1 = coding::CompressedBitVectorBuilder::Build(setBits1); - auto cbv2 = coding::CompressedBitVectorBuilder::Build(setBits2); + auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1); + auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2); TEST(cbv1.get(), ()); TEST(cbv2.get(), ()); auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2); @@ -48,18 +50,18 @@ UNIT_TEST(CompressedBitVector_Intersect1) UNIT_TEST(CompressedBitVector_Intersect2) { - size_t const n = 100; + size_t const kNumBits = 100; vector setBits1; vector setBits2; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) { - if (i <= n / 2) + if (i <= kNumBits / 2) setBits1.push_back(i); - if (i >= n / 2) + if (i >= kNumBits / 2) setBits2.push_back(i); } - auto cbv1 = coding::CompressedBitVectorBuilder::Build(setBits1); - auto cbv2 = coding::CompressedBitVectorBuilder::Build(setBits2); + auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1); + auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2); TEST(cbv1.get(), ()); TEST(cbv2.get(), ()); auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2); @@ -69,69 +71,61 @@ UNIT_TEST(CompressedBitVector_Intersect2) UNIT_TEST(CompressedBitVector_Intersect3) { - size_t const n = 100; + size_t const kNumBits = 100; vector setBits1; vector setBits2; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) { if (i % 2 == 0) setBits1.push_back(i); if (i % 3 == 0) setBits2.push_back(i); } - auto cbv1 = coding::CompressedBitVectorBuilder::Build(setBits1); - auto cbv2 = coding::CompressedBitVectorBuilder::Build(setBits2); + auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1); + auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2); TEST(cbv1.get(), ()); TEST(cbv2.get(), ()); auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2); TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ()); - for (size_t i = 0; i < n; ++i) - { - bool expected = i % 6 == 0; - TEST_EQUAL(expected, cbv3->GetBit(i), (i)); - } + CheckIntersection(setBits1, setBits2, cbv3); } UNIT_TEST(CompressedBitVector_Intersect4) { - size_t const n = 1000; + size_t const kNumBits = 1000; vector setBits1; vector setBits2; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) { if (i % 100 == 0) setBits1.push_back(i); if (i % 150 == 0) setBits2.push_back(i); } - auto cbv1 = coding::CompressedBitVectorBuilder::Build(setBits1); - auto cbv2 = coding::CompressedBitVectorBuilder::Build(setBits2); + auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1); + auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2); TEST(cbv1.get(), ()); TEST(cbv2.get(), ()); auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2); TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ()); - for (size_t i = 0; i < n; ++i) - { - bool expected = i % 300 == 0; - TEST_EQUAL(expected, cbv3->GetBit(i), (i)); - } + CheckIntersection(setBits1, setBits2, cbv3); } UNIT_TEST(CompressedBitVector_SerializationDense) { - int const n = 100; + int const kNumBits = 100; vector setBits; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) setBits.push_back(i); vector buf; { MemWriter> writer(buf); - auto cbv = coding::CompressedBitVectorBuilder::Build(setBits); + auto cbv = coding::CompressedBitVectorBuilder::FromBitPositions(setBits); + TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv->GetStorageStrategy(), ()); cbv->Serialize(writer); } MemReader reader(buf.data(), buf.size()); - ReaderSource src(reader); - auto cbv = coding::CompressedBitVectorBuilder::Deserialize(src); + auto cbv = coding::CompressedBitVectorBuilder::Deserialize(reader); TEST(cbv.get(), ()); TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv->GetStorageStrategy(), ()); TEST_EQUAL(setBits.size(), cbv->PopCount(), ()); @@ -141,9 +135,9 @@ UNIT_TEST(CompressedBitVector_SerializationDense) UNIT_TEST(CompressedBitVector_SerializationSparse) { - int const n = 100; + int const kNumBits = 100; vector setBits; - for (size_t i = 0; i < n; ++i) + for (size_t i = 0; i < kNumBits; ++i) { if (i % 10 == 0) setBits.push_back(i); @@ -151,15 +145,49 @@ UNIT_TEST(CompressedBitVector_SerializationSparse) vector buf; { MemWriter> writer(buf); - auto cbv = coding::CompressedBitVectorBuilder::Build(setBits); + auto cbv = coding::CompressedBitVectorBuilder::FromBitPositions(setBits); + TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv->GetStorageStrategy(), ()); cbv->Serialize(writer); } MemReader reader(buf.data(), buf.size()); - ReaderSource src(reader); - auto cbv = coding::CompressedBitVectorBuilder::Deserialize(src); + auto cbv = coding::CompressedBitVectorBuilder::Deserialize(reader); TEST(cbv.get(), ()); TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv->GetStorageStrategy(), ()); TEST_EQUAL(setBits.size(), cbv->PopCount(), ()); for (size_t i = 0; i < setBits.size(); ++i) TEST(cbv->GetBit(setBits[i]), ()); } + +UNIT_TEST(CompressedBitVector_ForEach) +{ + int const kNumBits = 150; + vector denseBits; + vector sparseBits; + for (size_t i = 0; i < kNumBits; ++i) + { + denseBits.push_back(i); + if (i % 15 == 0) + sparseBits.push_back(i); + } + auto denseCBV = coding::CompressedBitVectorBuilder::FromBitPositions(denseBits); + auto sparseCBV = coding::CompressedBitVectorBuilder::FromBitPositions(sparseBits); + TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, denseCBV->GetStorageStrategy(), + ()); + TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, sparseCBV->GetStorageStrategy(), + ()); + + set denseSet; + uint64_t maxPos = 0; + coding::CompressedBitVectorEnumerator::ForEach(*denseCBV, [&](uint64_t pos) + { + denseSet.insert(pos); + maxPos = max(maxPos, pos); + }); + TEST_EQUAL(denseSet.size(), kNumBits, ()); + TEST_EQUAL(maxPos, kNumBits - 1, ()); + + coding::CompressedBitVectorEnumerator::ForEach(*sparseCBV, [](uint64_t pos) + { + TEST_EQUAL(pos % 15, 0, ()); + }); +} diff --git a/coding/compressed_bit_vector.cpp b/coding/compressed_bit_vector.cpp index 5f4336b324..0b444b42ac 100644 --- a/coding/compressed_bit_vector.cpp +++ b/coding/compressed_bit_vector.cpp @@ -2,24 +2,23 @@ #include "coding/writer.hpp" #include "coding/write_to_sink.hpp" +#include "base/bits.hpp" + #include "std/algorithm.hpp" namespace { +uint64_t const kBlockSize = coding::DenseCBV::kBlockSize; + unique_ptr IntersectImpl(coding::DenseCBV const & a, coding::DenseCBV const & b) { size_t sizeA = a.NumBitGroups(); size_t sizeB = b.NumBitGroups(); - vector resBits; - for (size_t i = 0; i < min(sizeA, sizeB); ++i) - { - uint64_t bitGroup = a.GetBitGroup(i) & b.GetBitGroup(i); - for (size_t j = 0; j < 64; j++) - if (((bitGroup >> j) & 1) > 0) - resBits.push_back(64 * i + j); - } - return coding::CompressedBitVectorBuilder::Build(resBits); + vector resGroups(min(sizeA, sizeB)); + for (size_t i = 0; i < resGroups.size(); ++i) + resGroups[i] = a.GetBitGroup(i) & b.GetBitGroup(i); + return coding::CompressedBitVectorBuilder::FromBitGroups(move(resGroups)); } // The intersection of dense and sparse is always sparse. @@ -71,6 +70,17 @@ unique_ptr IntersectImpl(coding::SparseCBV const & } return make_unique(move(resPos)); } + +// Returns true if a bit vector with popCount bits set out of totalBits +// is fit to be represented as a DenseCBV. Note that we do not +// account for possible irregularities in the distribution of bits. +// In particular, we do not break the bit vector into blocks that are +// stored separately although this might turn out to be a good idea. +bool DenseEnough(uint64_t popCount, uint64_t totalBits) +{ + // Settle at 30% for now. + return popCount * 10 >= totalBits * 3; +} } // namespace namespace coding @@ -79,18 +89,39 @@ DenseCBV::DenseCBV(vector const & setBits) { if (setBits.empty()) { - m_bits.resize(0); + m_bitGroups.resize(0); m_popCount = 0; return; } uint64_t maxBit = setBits[0]; for (size_t i = 1; i < setBits.size(); ++i) maxBit = max(maxBit, setBits[i]); - size_t sz = (maxBit + 64 - 1) / 64; - m_bits.resize(sz); + size_t sz = (maxBit + kBlockSize - 1) / kBlockSize; + m_bitGroups.resize(sz); m_popCount = static_cast(setBits.size()); for (uint64_t pos : setBits) - m_bits[pos / 64] |= static_cast(1) << (pos % 64); + m_bitGroups[pos / kBlockSize] |= static_cast(1) << (pos % kBlockSize); +} + +// static +unique_ptr DenseCBV::BuildFromBitGroups(vector && bitGroups) +{ + unique_ptr cbv(new DenseCBV()); + cbv->m_popCount = 0; + for (size_t i = 0; i < bitGroups.size(); ++i) + cbv->m_popCount += bits::PopCount(bitGroups[i]); + cbv->m_bitGroups = move(bitGroups); + return cbv; +} + +SparseCBV::SparseCBV(vector const & setBits) : m_positions(setBits) +{ + ASSERT(is_sorted(m_positions.begin(), m_positions.end()), ()); +} + +SparseCBV::SparseCBV(vector && setBits) : m_positions(move(setBits)) +{ + ASSERT(is_sorted(m_positions.begin(), m_positions.end()), ()); } uint32_t DenseCBV::PopCount() const { return m_popCount; } @@ -99,16 +130,27 @@ uint32_t SparseCBV::PopCount() const { return m_positions.size(); } bool DenseCBV::GetBit(uint32_t pos) const { - uint64_t bitGroup = GetBitGroup(pos / 64); - return ((bitGroup >> (pos % 64)) & 1) > 0; + uint64_t bitGroup = GetBitGroup(pos / kBlockSize); + return ((bitGroup >> (pos % kBlockSize)) & 1) > 0; } bool SparseCBV::GetBit(uint32_t pos) const { - auto it = lower_bound(m_positions.begin(), m_positions.end(), pos); + auto const it = lower_bound(m_positions.begin(), m_positions.end(), pos); return it != m_positions.end() && *it == pos; } +uint64_t DenseCBV::GetBitGroup(size_t i) const +{ + return i < m_bitGroups.size() ? m_bitGroups[i] : 0; +} + +uint64_t SparseCBV::Select(size_t i) const +{ + ASSERT_LESS(i, m_positions.size(), ()); + return m_positions[i]; +} + CompressedBitVector::StorageStrategy DenseCBV::GetStorageStrategy() const { return CompressedBitVector::StorageStrategy::Dense; @@ -119,22 +161,6 @@ CompressedBitVector::StorageStrategy SparseCBV::GetStorageStrategy() const return CompressedBitVector::StorageStrategy::Sparse; } -template -void DenseCBV::ForEach(F && f) const -{ - for (size_t i = 0; i < m_bits.size(); ++i) - for (size_t j = 0; j < 64; ++j) - if (((m_bits[i] >> j) & 1) > 0) - f(64 * i + j); -} - -template -void SparseCBV::ForEach(F && f) const -{ - for (size_t i = 0; i < m_positions.size(); ++i) - f(m_positions[i]); -} - string DebugPrint(CompressedBitVector::StorageStrategy strat) { switch (strat) @@ -167,16 +193,43 @@ void SparseCBV::Serialize(Writer & writer) const } // static -unique_ptr CompressedBitVectorBuilder::Build(vector const & setBits) +unique_ptr CompressedBitVectorBuilder::FromBitPositions( + vector const & setBits) { if (setBits.empty()) return make_unique(setBits); uint64_t maxBit = setBits[0]; for (size_t i = 1; i < setBits.size(); ++i) maxBit = max(maxBit, setBits[i]); - // 30% occupied is dense enough - if (10 * setBits.size() >= 3 * maxBit) + + if (DenseEnough(setBits.size(), maxBit)) return make_unique(setBits); + + return make_unique(setBits); +} + +// static +unique_ptr CompressedBitVectorBuilder::FromBitGroups( + vector && bitGroups) +{ + while (!bitGroups.empty() && bitGroups.back() == 0) + bitGroups.pop_back(); + if (bitGroups.empty()) + return make_unique(bitGroups); + + uint64_t maxBit = kBlockSize * bitGroups.size() - 1; + uint64_t popCount = 0; + for (size_t i = 0; i < bitGroups.size(); ++i) + popCount += bits::PopCount(bitGroups[i]); + + if (DenseEnough(popCount, maxBit)) + return DenseCBV::BuildFromBitGroups(move(bitGroups)); + + vector setBits; + for (size_t i = 0; i < bitGroups.size(); ++i) + for (size_t j = 0; j < kBlockSize; ++j) + if (((bitGroups[i] >> j) & 1) > 0) + setBits.push_back(kBlockSize * i + j); return make_unique(setBits); } @@ -184,29 +237,28 @@ unique_ptr CompressedBitVectorBuilder::Build(vector CompressedBitVector::Intersect(CompressedBitVector const & lhs, CompressedBitVector const & rhs) { - auto stratA = lhs.GetStorageStrategy(); - auto stratB = rhs.GetStorageStrategy(); - auto stratDense = CompressedBitVector::StorageStrategy::Dense; - auto stratSparse = CompressedBitVector::StorageStrategy::Sparse; - if (stratA == stratDense && stratB == stratDense) + using strat = CompressedBitVector::StorageStrategy; + auto const stratA = lhs.GetStorageStrategy(); + auto const stratB = rhs.GetStorageStrategy(); + if (stratA == strat::Dense && stratB == strat::Dense) { DenseCBV const & a = static_cast(lhs); DenseCBV const & b = static_cast(rhs); return IntersectImpl(a, b); } - if (stratA == stratDense && stratB == stratSparse) + if (stratA == strat::Dense && stratB == strat::Sparse) { DenseCBV const & a = static_cast(lhs); SparseCBV const & b = static_cast(rhs); return IntersectImpl(a, b); } - if (stratA == stratSparse && stratB == stratDense) + if (stratA == strat::Sparse && stratB == strat::Dense) { SparseCBV const & a = static_cast(lhs); DenseCBV const & b = static_cast(rhs); return IntersectImpl(a, b); } - if (stratA == stratSparse && stratB == stratSparse) + if (stratA == strat::Sparse && stratB == strat::Sparse) { SparseCBV const & a = static_cast(lhs); SparseCBV const & b = static_cast(rhs); diff --git a/coding/compressed_bit_vector.hpp b/coding/compressed_bit_vector.hpp index b856f9fc72..26a2182f77 100644 --- a/coding/compressed_bit_vector.hpp +++ b/coding/compressed_bit_vector.hpp @@ -1,7 +1,6 @@ #include "std/vector.hpp" #include "base/assert.hpp" -#include "base/bits.hpp" #include "coding/reader.hpp" #include "coding/writer.hpp" @@ -24,14 +23,20 @@ public: virtual ~CompressedBitVector() = default; - // Executes f for each bit that is set to one using - // the bit's 0-based position as argument. - template - void ForEach(F && f) const; - // Intersects two bit vectors. - static unique_ptr Intersect(CompressedBitVector const &, - CompressedBitVector const &); + // todo(@pimenov) We expect the common use case to be as follows. + // A CBV is created in memory and several CBVs are read and intersected + // with it one by one. The in-memory CBV may initially contain a bit + // for every feature in an mwm and the intersected CBVs are read from + // the leaves of a search trie. + // Therefore an optimization of Intersect comes to mind: make a wrapper + // around TReader that will read a representation of a CBV from disk + // and intersect it bit by bit with the global in-memory CBV bypassing such + // routines as allocating memory and choosing strategy. They all can be called only + // once, namely in the end, when it is needed to pack the in-memory CBV into + // a suitable representation and pass it to the caller. + static unique_ptr Intersect(CompressedBitVector const & lhs, + CompressedBitVector const & rhs); // Returns the number of set bits (population count). virtual uint32_t PopCount() const = 0; @@ -60,80 +65,63 @@ string DebugPrint(CompressedBitVector::StorageStrategy strat); class DenseCBV : public CompressedBitVector { public: + DenseCBV() = default; + // Builds a dense CBV from a list of positions of set bits. DenseCBV(vector const & setBits); - // Builds a dense CBV from a packed bitmap of set bits. - // todo(@pimenov) This behaviour of & and && constructors is extremely error-prone. - DenseCBV(vector && bitMasks) : m_bits(move(bitMasks)) - { - m_popCount = 0; - for (size_t i = 0; i < m_bits.size(); ++i) - m_popCount += bits::PopCount(m_bits[i]); - } + // Not to be confused with the constructor: the semantics + // of the array of integers is completely different. + static unique_ptr BuildFromBitGroups(vector && bitGroups); - ~DenseCBV() = default; + size_t NumBitGroups() const { return m_bitGroups.size(); } - size_t NumBitGroups() const { return m_bits.size(); } + static uint32_t const kBlockSize = 64; template - void ForEach(F && f) const; - - uint64_t GetBitGroup(size_t i) const + void ForEach(F && f) const { - if (i < m_bits.size()) - return m_bits[i]; - return 0; + for (size_t i = 0; i < m_bitGroups.size(); ++i) + for (size_t j = 0; j < kBlockSize; ++j) + if (((m_bitGroups[i] >> j) & 1) > 0) + f(kBlockSize * i + j); } + // Returns 0 if the group number is too large to be contained in m_bits. + uint64_t GetBitGroup(size_t i) const; + // CompressedBitVector overrides: - uint32_t PopCount() const override; - bool GetBit(uint32_t pos) const override; - StorageStrategy GetStorageStrategy() const override; - void Serialize(Writer & writer) const override; private: - vector m_bits; - uint32_t m_popCount; + vector m_bitGroups; + uint32_t m_popCount = 0; }; class SparseCBV : public CompressedBitVector { public: - SparseCBV(vector const & setBits) : m_positions(setBits) - { - ASSERT(is_sorted(m_positions.begin(), m_positions.end()), ()); - } + SparseCBV(vector const & setBits); - SparseCBV(vector && setBits) : m_positions(move(setBits)) - { - ASSERT(is_sorted(m_positions.begin(), m_positions.end()), ()); - } - - ~SparseCBV() = default; + SparseCBV(vector && setBits); // Returns the position of the i'th set bit. - uint64_t Select(size_t i) const - { - ASSERT_LESS(i, m_positions.size(), ()); - return m_positions[i]; - } + uint64_t Select(size_t i) const; template - void ForEach(F && f) const; + void ForEach(F && f) const + { + for (auto const & position : m_positions) + f(position); + } // CompressedBitVector overrides: - uint32_t PopCount() const override; - bool GetBit(uint32_t pos) const override; - StorageStrategy GetStorageStrategy() const override; - void Serialize(Writer & writer) const override; private: @@ -146,7 +134,11 @@ class CompressedBitVectorBuilder public: // Chooses a strategy to store the bit vector with bits from setBits set to one // and returns a pointer to a class that fits best. - static unique_ptr Build(vector const & setBits); + static unique_ptr FromBitPositions(vector const & setBits); + + // Chooses a strategy to store the bit vector with bits from a bitmap obtained + // by concatenating the elements of bitGroups. + static unique_ptr FromBitGroups(vector && bitGroups); // Reads a bit vector from reader which must contain a valid // bit vector representation (see CompressedBitVector::Serialize for the format). @@ -154,29 +146,53 @@ public: static unique_ptr Deserialize(TReader & reader) { ReaderSource src(reader); - uint8_t header = ReadPrimitiveFromSource(reader); + uint8_t header = ReadPrimitiveFromSource(src); CompressedBitVector::StorageStrategy strat = static_cast(header); switch (strat) { case CompressedBitVector::StorageStrategy::Dense: { - uint32_t numBitGroups = ReadPrimitiveFromSource(reader); - vector bitGroups(numBitGroups); - for (size_t i = 0; i < numBitGroups; ++i) - bitGroups[i] = ReadPrimitiveFromSource(reader); - return make_unique(move(bitGroups)); + vector bitGroups; + ReadPrimitiveVectorFromSource(src, bitGroups); + return DenseCBV::BuildFromBitGroups(move(bitGroups)); } case CompressedBitVector::StorageStrategy::Sparse: { - uint32_t numBits = ReadPrimitiveFromSource(reader); - vector setBits(numBits); - for (size_t i = 0; i < numBits; ++i) - setBits[i] = ReadPrimitiveFromSource(reader); + vector setBits; + ReadPrimitiveVectorFromSource(src, setBits); return make_unique(setBits); } } return nullptr; } }; + +// ForEach is generic and therefore cannot be virtual: a helper class is needed. +class CompressedBitVectorEnumerator +{ +public: + // Executes f for each bit that is set to one using + // the bit's 0-based position as argument. + template + static void ForEach(CompressedBitVector const & cbv, F && f) + { + CompressedBitVector::StorageStrategy strat = cbv.GetStorageStrategy(); + switch (strat) + { + case CompressedBitVector::StorageStrategy::Dense: + { + DenseCBV const & denseCBV = static_cast(cbv); + denseCBV.ForEach(f); + return; + } + case CompressedBitVector::StorageStrategy::Sparse: + { + SparseCBV const & sparseCBV = static_cast(cbv); + sparseCBV.ForEach(f); + return; + } + } + } +}; } // namespace coding diff --git a/coding/reader.hpp b/coding/reader.hpp index 2f8c8cb658..e4f7980439 100644 --- a/coding/reader.hpp +++ b/coding/reader.hpp @@ -8,7 +8,7 @@ #include "std/shared_ptr.hpp" #include "std/string.hpp" #include "std/cstring.hpp" - +#include "std/vector.hpp" // Base class for random-access Reader. Not thread-safe. class Reader @@ -110,13 +110,14 @@ private: // Reader wrapper to hold the pointer to a polymorfic reader. // Common use: ReaderSource >. // Note! It takes the ownership of Reader. -template class ReaderPtr +template +class ReaderPtr { protected: - shared_ptr m_p; + shared_ptr m_p; public: - ReaderPtr(ReaderT * p = 0) : m_p(p) {} + ReaderPtr(TReader * p = 0) : m_p(p) {} uint64_t Size() const { @@ -133,7 +134,7 @@ public: m_p->ReadAsString(s); } - ReaderT * GetPtr() const { return m_p.get(); } + TReader * GetPtr() const { return m_p.get(); } }; // Model reader store file id as string. @@ -167,14 +168,13 @@ public: // Source that reads from a reader. -template class ReaderSource +template +class ReaderSource { public: - typedef ReaderT ReaderType; + typedef TReader ReaderType; - ReaderSource(ReaderT const & reader) : m_reader(reader), m_pos(0) - { - } + ReaderSource(TReader const & reader) : m_reader(reader), m_pos(0) {} void Read(void * p, size_t size) { @@ -199,17 +199,14 @@ public: return (m_reader.Size() - m_pos); } - ReaderT SubReader(uint64_t size) + TReader SubReader(uint64_t size) { uint64_t const pos = m_pos; Skip(size); return m_reader.SubReader(pos, size); } - ReaderT SubReader() - { - return SubReader(Size()); - } + TReader SubReader() { return SubReader(Size()); } private: bool AssertPosition() const @@ -219,28 +216,39 @@ private: return ret; } - ReaderT m_reader; + TReader m_reader; uint64_t m_pos; }; -template inline -void ReadFromPos(ReaderT const & reader, uint64_t pos, void * p, size_t size) +template +inline void ReadFromPos(TReader const & reader, uint64_t pos, void * p, size_t size) { reader.Read(pos, p, size); } -template inline -PrimitiveT ReadPrimitiveFromPos(ReaderT const & reader, uint64_t pos) +template +inline TPrimitive ReadPrimitiveFromPos(TReader const & reader, uint64_t pos) { - PrimitiveT primitive; + TPrimitive primitive; ReadFromPos(reader, pos, &primitive, sizeof(primitive)); return SwapIfBigEndian(primitive); } -template inline -PrimitiveT ReadPrimitiveFromSource(TSource & source) +template +inline typename enable_if::value, TPrimitive>::type +ReadPrimitiveFromSource(TSource & source) { - PrimitiveT primitive; + TPrimitive primitive; source.Read(&primitive, sizeof(primitive)); return SwapIfBigEndian(primitive); } + +template +void ReadPrimitiveVectorFromSource(TSource && source, vector & result) +{ + // Do not overspecify the size type: uint32_t is enough. + size_t size = static_cast(ReadPrimitiveFromSource(source)); + result.resize(size); + for (size_t i = 0; i < size; ++i) + result[i] = ReadPrimitiveFromSource(source); +}