diff --git a/base/base_tests/bits_test.cpp b/base/base_tests/bits_test.cpp
index 310e02ba2e..ef436f6a1e 100644
--- a/base/base_tests/bits_test.cpp
+++ b/base/base_tests/bits_test.cpp
@@ -108,3 +108,25 @@ UNIT_TEST(NumUsedBits)
   TEST_EQUAL(bits::NumUsedBits(0x0FABCDEF0FABCDEFULL), 60, ());
   TEST_EQUAL(bits::NumUsedBits(0x000000000000FDEFULL), 16, ());
 }
+
+UNIT_TEST(PopCount64)
+{
+  TEST_EQUAL(0, bits::PopCount(static_cast<uint64_t>(0x0)), ());
+  TEST_EQUAL(1, bits::PopCount(static_cast<uint64_t>(0x1)), ());
+  TEST_EQUAL(32, bits::PopCount(0xAAAAAAAA55555555), ());
+  TEST_EQUAL(64, bits::PopCount(0xFFFFFFFFFFFFFFFF), ());
+}
+
+UNIT_TEST(CeilLog)
+{
+  TEST_EQUAL(0, bits::CeilLog(0x0), ());
+  TEST_EQUAL(0, bits::CeilLog(0x1), ());
+  TEST_EQUAL(1, bits::CeilLog(0x2), ());
+  TEST_EQUAL(1, bits::CeilLog(0x3), ());
+  TEST_EQUAL(2, bits::CeilLog(0x4), ());
+
+  TEST_EQUAL(6, bits::CeilLog(0x7f), ());
+  TEST_EQUAL(7, bits::CeilLog(0x80), ());
+  TEST_EQUAL(31, bits::CeilLog(0xFFFFFFFF), ());
+  TEST_EQUAL(63, bits::CeilLog(0xFFFFFFFFFFFFFFFF), ());
+}
diff --git a/base/bits.hpp b/base/bits.hpp
index 1971e41c08..4c561f3422 100644
--- a/base/bits.hpp
+++ b/base/bits.hpp
@@ -72,6 +72,27 @@ namespace bits
     return static_cast<uint8_t>(x);
   }
 
+  inline uint8_t CeilLog(uint64_t x) noexcept
+  {
+#define CHECK_RSH(x, msb, offset) \
+  if (x >> offset)                \
+  {                               \
+    x >>= offset;                 \
+    msb += offset;                \
+  }
+
+    uint8_t msb = 0;
+    CHECK_RSH(x, msb, 32);
+    CHECK_RSH(x, msb, 16);
+    CHECK_RSH(x, msb, 8);
+    CHECK_RSH(x, msb, 4);
+    CHECK_RSH(x, msb, 2);
+    CHECK_RSH(x, msb, 1);
+#undef CHECK_RSH
+
+    return msb;
+  }
+
   // Will be implemented when needed.
   uint64_t PopCount(uint64_t const * p, uint64_t n);
diff --git a/coding/coding_tests/compressed_bit_vector_test.cpp b/coding/coding_tests/compressed_bit_vector_test.cpp
index e74596bd0d..a09075bcfe 100644
--- a/coding/coding_tests/compressed_bit_vector_test.cpp
+++ b/coding/coding_tests/compressed_bit_vector_test.cpp
@@ -9,21 +9,50 @@
 namespace
 {
-void CheckIntersection(vector<uint64_t> & setBits1, vector<uint64_t> & setBits2,
-                       unique_ptr<coding::CompressedBitVector> const & cbv)
+void Intersect(vector<uint64_t> & setBits1, vector<uint64_t> & setBits2, vector<uint64_t> & result)
 {
-  TEST(cbv.get(), ());
-  vector<uint64_t> expected;
   sort(setBits1.begin(), setBits1.end());
   sort(setBits2.begin(), setBits2.end());
   set_intersection(setBits1.begin(), setBits1.end(), setBits2.begin(), setBits2.end(),
-                   back_inserter(expected));
-  TEST_EQUAL(expected.size(), cbv->PopCount(), ());
-  vector<bool> expectedBitmap(expected.back() + 1);
+                   back_inserter(result));
+}
+
+void Subtract(vector<uint64_t> & setBits1, vector<uint64_t> & setBits2, vector<uint64_t> & result)
+{
+  sort(setBits1.begin(), setBits1.end());
+  sort(setBits2.begin(), setBits2.end());
+  set_difference(setBits1.begin(), setBits1.end(), setBits2.begin(), setBits2.end(),
+                 back_inserter(result));
+}
+
+template <typename TBinaryOp>
+void CheckBinaryOp(TBinaryOp op, vector<uint64_t> & setBits1, vector<uint64_t> & setBits2,
+                   coding::CompressedBitVector const & cbv)
+{
+  vector<uint64_t> expected;
+  op(setBits1, setBits2, expected);
+  TEST_EQUAL(expected.size(), cbv.PopCount(), ());
+
+  vector<bool> expectedBitmap;
+  if (!expected.empty())
+    expectedBitmap.resize(expected.back() + 1);
+
   for (size_t i = 0; i < expected.size(); ++i)
     expectedBitmap[expected[i]] = true;
   for (size_t i = 0; i < expectedBitmap.size(); ++i)
-    TEST_EQUAL(cbv->GetBit(i), expectedBitmap[i], ());
+    TEST_EQUAL(cbv.GetBit(i), expectedBitmap[i], ());
+}
+
+void CheckIntersection(vector<uint64_t> & setBits1, vector<uint64_t> & setBits2,
+                       coding::CompressedBitVector const & cbv)
+{
+  CheckBinaryOp(&Intersect, setBits1, setBits2, cbv);
+}
+
+void CheckSubtraction(vector<uint64_t> & setBits1, vector<uint64_t> & setBits2,
+                      coding::CompressedBitVector const & cbv)
+{
+  CheckBinaryOp(&Subtract, setBits1, setBits2, cbv);
 }
 }  // namespace
@@ -43,9 +72,10 @@ UNIT_TEST(CompressedBitVector_Intersect1)
   auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
   TEST(cbv1.get(), ());
   TEST(cbv2.get(), ());
+
   auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2);
   TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv3->GetStorageStrategy(), ());
-  CheckIntersection(setBits1, setBits2, cbv3);
+  CheckIntersection(setBits1, setBits2, *cbv3);
 }
 
 UNIT_TEST(CompressedBitVector_Intersect2)
@@ -64,9 +94,10 @@ UNIT_TEST(CompressedBitVector_Intersect2)
   auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
   TEST(cbv1.get(), ());
   TEST(cbv2.get(), ());
+
   auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2);
   TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ());
-  CheckIntersection(setBits1, setBits2, cbv3);
+  CheckIntersection(setBits1, setBits2, *cbv3);
 }
 
 UNIT_TEST(CompressedBitVector_Intersect3)
@@ -87,7 +118,7 @@ UNIT_TEST(CompressedBitVector_Intersect3)
   TEST(cbv2.get(), ());
   auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2);
   TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ());
-  CheckIntersection(setBits1, setBits2, cbv3);
+  CheckIntersection(setBits1, setBits2, *cbv3);
 }
 
 UNIT_TEST(CompressedBitVector_Intersect4)
@@ -108,7 +139,79 @@ UNIT_TEST(CompressedBitVector_Intersect4)
   TEST(cbv2.get(), ());
   auto cbv3 = coding::CompressedBitVector::Intersect(*cbv1, *cbv2);
   TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ());
-  CheckIntersection(setBits1, setBits2, cbv3);
+  CheckIntersection(setBits1, setBits2, *cbv3);
+}
+
+UNIT_TEST(CompressedBitVector_Subtract1)
+{
+  vector<uint64_t> setBits1 = {0, 1, 2, 3, 4, 5, 6};
+  vector<uint64_t> setBits2 = {1, 2, 3, 4, 5, 6, 7};
+
+  auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1);
+  auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
+  TEST(cbv1.get(), ());
+  TEST(cbv2.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv1->GetStorageStrategy(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv2->GetStorageStrategy(), ());
+
+  auto cbv3 = coding::CompressedBitVector::Subtract(*cbv1, *cbv2);
+  TEST(cbv3.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv3->GetStorageStrategy(), ());
+  CheckSubtraction(setBits1, setBits2, *cbv3);
+}
+
+UNIT_TEST(CompressedBitVector_Subtract2)
+{
+  vector<uint64_t> setBits1;
+  for (size_t i = 0; i < 100; ++i)
+    setBits1.push_back(i);
+
+  vector<uint64_t> setBits2 = {9, 14};
+  auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1);
+  auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
+  TEST(cbv1.get(), ());
+  TEST(cbv2.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv1->GetStorageStrategy(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv2->GetStorageStrategy(), ());
+
+  auto cbv3 = coding::CompressedBitVector::Subtract(*cbv1, *cbv2);
+  TEST(cbv3.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv3->GetStorageStrategy(), ());
+  CheckSubtraction(setBits1, setBits2, *cbv3);
+}
+
+UNIT_TEST(CompressedBitVector_Subtract3)
+{
+  vector<uint64_t> setBits1 = {0, 9};
+  vector<uint64_t> setBits2 = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+  auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1);
+  auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
+  TEST(cbv1.get(), ());
+  TEST(cbv2.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv1->GetStorageStrategy(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv2->GetStorageStrategy(), ());
+
+  auto cbv3 = coding::CompressedBitVector::Subtract(*cbv1, *cbv2);
+  TEST(cbv3.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ());
+  CheckSubtraction(setBits1, setBits2, *cbv3);
+}
+
+UNIT_TEST(CompressedBitVector_Subtract4)
+{
+  vector<uint64_t> setBits1 = {0, 5, 15};
+  vector<uint64_t> setBits2 = {0, 10};
+  auto cbv1 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits1);
+  auto cbv2 = coding::CompressedBitVectorBuilder::FromBitPositions(setBits2);
+  TEST(cbv1.get(), ());
+  TEST(cbv2.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv1->GetStorageStrategy(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv2->GetStorageStrategy(), ());
+
+  auto cbv3 = coding::CompressedBitVector::Subtract(*cbv1, *cbv2);
+  TEST(cbv3.get(), ());
+  TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Sparse, cbv3->GetStorageStrategy(), ());
+  CheckSubtraction(setBits1, setBits2, *cbv3);
 }
 
 UNIT_TEST(CompressedBitVector_SerializationDense)
@@ -121,6 +224,7 @@ UNIT_TEST(CompressedBitVector_SerializationDense)
   {
     MemWriter<vector<uint8_t>> writer(buf);
     auto cbv = coding::CompressedBitVectorBuilder::FromBitPositions(setBits);
+    TEST_EQUAL(setBits.size(), cbv->PopCount(), ());
    TEST_EQUAL(coding::CompressedBitVector::StorageStrategy::Dense, cbv->GetStorageStrategy(), ());
    cbv->Serialize(writer);
  }
diff --git a/coding/compressed_bit_vector.cpp b/coding/compressed_bit_vector.cpp
index 876b0135f4..a5fc3ebee7 100644
--- a/coding/compressed_bit_vector.cpp
+++ b/coding/compressed_bit_vector.cpp
@@ -1,80 +1,156 @@
 #include "coding/compressed_bit_vector.hpp"
-#include "coding/writer.hpp"
+
 #include "coding/write_to_sink.hpp"
 
+#include "base/assert.hpp"
 #include "base/bits.hpp"
 
 #include "std/algorithm.hpp"
 
 namespace coding
 {
-// static
-uint32_t const DenseCBV::kBlockSize;
-}  // namespace coding
-
 namespace
 {
-uint64_t const kBlockSize = coding::DenseCBV::kBlockSize;
-
-unique_ptr<coding::CompressedBitVector> IntersectImpl(coding::DenseCBV const & a,
-                                                      coding::DenseCBV const & b)
+struct IntersectOp
 {
-  size_t sizeA = a.NumBitGroups();
-  size_t sizeB = b.NumBitGroups();
-  vector<uint64_t> resGroups(min(sizeA, sizeB));
-  for (size_t i = 0; i < resGroups.size(); ++i)
-    resGroups[i] = a.GetBitGroup(i) & b.GetBitGroup(i);
-  return coding::CompressedBitVectorBuilder::FromBitGroups(move(resGroups));
-}
+  IntersectOp() {}
 
-// The intersection of dense and sparse is always sparse.
-unique_ptr<coding::CompressedBitVector> IntersectImpl(coding::DenseCBV const & a,
-                                                      coding::SparseCBV const & b)
-{
-  vector<uint64_t> resPos;
-  for (size_t i = 0; i < b.PopCount(); ++i)
+  unique_ptr<coding::CompressedBitVector> operator()(coding::DenseCBV const & a,
+                                                     coding::DenseCBV const & b) const
   {
-    auto pos = b.Select(i);
-    if (a.GetBit(pos))
-      resPos.push_back(pos);
+    size_t sizeA = a.NumBitGroups();
+    size_t sizeB = b.NumBitGroups();
+    vector<uint64_t> resGroups(min(sizeA, sizeB));
+    for (size_t i = 0; i < resGroups.size(); ++i)
+      resGroups[i] = a.GetBitGroup(i) & b.GetBitGroup(i);
+    return coding::CompressedBitVectorBuilder::FromBitGroups(move(resGroups));
   }
-  return make_unique<coding::SparseCBV>(move(resPos));
-}
 
-unique_ptr<coding::CompressedBitVector> IntersectImpl(coding::SparseCBV const & a,
-                                                      coding::DenseCBV const & b)
-{
-  return IntersectImpl(b, a);
-}
-
-unique_ptr<coding::CompressedBitVector> IntersectImpl(coding::SparseCBV const & a,
-                                                      coding::SparseCBV const & b)
-{
-  size_t sizeA = a.PopCount();
-  size_t sizeB = b.PopCount();
-  vector<uint64_t> resPos;
-  size_t i = 0;
-  size_t j = 0;
-  while (i < sizeA && j < sizeB)
+  // The intersection of dense and sparse is always sparse.
+  unique_ptr<coding::CompressedBitVector> operator()(coding::DenseCBV const & a,
+                                                     coding::SparseCBV const & b) const
   {
-    auto posA = a.Select(i);
-    auto posB = b.Select(j);
-    if (posA == posB)
+    vector<uint64_t> resPos;
+    for (size_t i = 0; i < b.PopCount(); ++i)
     {
-      resPos.push_back(posA);
-      ++i;
-      ++j;
-    }
-    else if (posA < posB)
-    {
-      ++i;
-    }
-    else
-    {
-      ++j;
+      auto pos = b.Select(i);
+      if (a.GetBit(pos))
+        resPos.push_back(pos);
     }
+    return make_unique<coding::SparseCBV>(move(resPos));
   }
-  return make_unique<coding::SparseCBV>(move(resPos));
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::SparseCBV const & a,
+                                                     coding::DenseCBV const & b) const
+  {
+    return operator()(b, a);
+  }
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::SparseCBV const & a,
+                                                     coding::SparseCBV const & b) const
+  {
+    vector<uint64_t> resPos;
+    set_intersection(a.Begin(), a.End(), b.Begin(), b.End(), back_inserter(resPos));
+    return make_unique<coding::SparseCBV>(move(resPos));
+  }
+};
+
+struct SubtractOp
+{
+  SubtractOp() {}
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::DenseCBV const & a,
+                                                     coding::DenseCBV const & b) const
+  {
+    size_t sizeA = a.NumBitGroups();
+    size_t sizeB = b.NumBitGroups();
+    vector<uint64_t> resGroups(min(sizeA, sizeB));
+    for (size_t i = 0; i < resGroups.size(); ++i)
+      resGroups[i] = a.GetBitGroup(i) & ~b.GetBitGroup(i);
+    return CompressedBitVectorBuilder::FromBitGroups(move(resGroups));
+  }
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::DenseCBV const & a,
+                                                     coding::SparseCBV const & b) const
+  {
+    vector<uint64_t> resGroups(a.NumBitGroups());
+
+    size_t i = 0;
+    auto j = b.Begin();
+    for (; i < resGroups.size() && j < b.End(); ++i)
+    {
+      uint64_t const kBitsBegin = i * DenseCBV::kBlockSize;
+      uint64_t const kBitsEnd = (i + 1) * DenseCBV::kBlockSize;
+
+      uint64_t mask = 0;
+      for (; j < b.End() && *j < kBitsEnd; ++j)
+      {
+        ASSERT_GREATER_OR_EQUAL(*j, kBitsBegin, ());
+        mask |= static_cast<uint64_t>(1) << (*j - kBitsBegin);
+      }
+
+      resGroups[i] = a.GetBitGroup(i) & ~mask;
+    }
+
+    for (; i < resGroups.size(); ++i)
+      resGroups[i] = a.GetBitGroup(i);
+
+    return CompressedBitVectorBuilder::FromBitGroups(move(resGroups));
+  }
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::SparseCBV const & a,
+                                                     coding::DenseCBV const & b) const
+  {
+    vector<uint64_t> resPos;
+    copy_if(a.Begin(), a.End(), back_inserter(resPos), [&](uint64_t bit)
+    {
+      return !b.GetBit(bit);
+    });
+    return CompressedBitVectorBuilder::FromBitPositions(move(resPos));
+  }
+
+  unique_ptr<coding::CompressedBitVector> operator()(coding::SparseCBV const & a,
+                                                     coding::SparseCBV const & b) const
+  {
+    vector<uint64_t> resPos;
+    set_difference(a.Begin(), a.End(), b.Begin(), b.End(), back_inserter(resPos));
+    return CompressedBitVectorBuilder::FromBitPositions(move(resPos));
+  }
+};
+
+template <typename TBinaryOp>
+unique_ptr<CompressedBitVector> Apply(TBinaryOp const & op, CompressedBitVector const & lhs,
+                                      CompressedBitVector const & rhs)
+{
+  using strat = CompressedBitVector::StorageStrategy;
+  auto const stratA = lhs.GetStorageStrategy();
+  auto const stratB = rhs.GetStorageStrategy();
+  if (stratA == strat::Dense && stratB == strat::Dense)
+  {
+    DenseCBV const & a = static_cast<DenseCBV const &>(lhs);
+    DenseCBV const & b = static_cast<DenseCBV const &>(rhs);
+    return op(a, b);
+  }
+  if (stratA == strat::Dense && stratB == strat::Sparse)
+  {
+    DenseCBV const & a = static_cast<DenseCBV const &>(lhs);
+    SparseCBV const & b = static_cast<SparseCBV const &>(rhs);
+    return op(a, b);
+  }
+  if (stratA == strat::Sparse && stratB == strat::Dense)
+  {
+    SparseCBV const & a = static_cast<SparseCBV const &>(lhs);
+    DenseCBV const & b = static_cast<DenseCBV const &>(rhs);
+    return op(a, b);
+  }
+  if (stratA == strat::Sparse && stratB == strat::Sparse)
+  {
+    SparseCBV const & a = static_cast<SparseCBV const &>(lhs);
+    SparseCBV const & b = static_cast<SparseCBV const &>(rhs);
+    return op(a, b);
+  }
+
+  return nullptr;
 }
 
 // Returns true if a bit vector with popCount bits set out of totalBits
@@ -87,22 +163,34 @@ bool DenseEnough(uint64_t popCount, uint64_t totalBits)
   // Settle at 30% for now.
   return popCount * 10 >= totalBits * 3;
 }
+
+template <typename TBitPositions>
+unique_ptr<CompressedBitVector> BuildFromBitPositions(TBitPositions && setBits)
+{
+  if (setBits.empty())
+    return make_unique<SparseCBV>(forward<TBitPositions>(setBits));
+  uint64_t const maxBit = *max_element(setBits.begin(), setBits.end());
+
+  if (DenseEnough(setBits.size(), maxBit))
+    return make_unique<DenseCBV>(forward<TBitPositions>(setBits));
+
+  return make_unique<SparseCBV>(forward<TBitPositions>(setBits));
+}
 }  // namespace
 
-namespace coding
-{
+// static
+uint64_t const DenseCBV::kBlockSize;
+
 DenseCBV::DenseCBV(vector<uint64_t> const & setBits)
 {
   if (setBits.empty())
   {
     return;
   }
-  uint64_t maxBit = setBits[0];
-  for (size_t i = 1; i < setBits.size(); ++i)
-    maxBit = max(maxBit, setBits[i]);
+  uint64_t const maxBit = *max_element(setBits.begin(), setBits.end());
   size_t const sz = 1 + maxBit / kBlockSize;
   m_bitGroups.resize(sz);
-  m_popCount = static_cast<uint32_t>(setBits.size());
+  m_popCount = static_cast<uint64_t>(setBits.size());
   for (uint64_t pos : setBits)
     m_bitGroups[pos / kBlockSize] |= static_cast<uint64_t>(1) << (pos % kBlockSize);
 }
@@ -123,9 +211,9 @@ uint64_t DenseCBV::GetBitGroup(size_t i) const
   return i < m_bitGroups.size() ? m_bitGroups[i] : 0;
 }
 
-uint32_t DenseCBV::PopCount() const { return m_popCount; }
+uint64_t DenseCBV::PopCount() const { return m_popCount; }
 
-bool DenseCBV::GetBit(uint32_t pos) const
+bool DenseCBV::GetBit(uint64_t pos) const
 {
   uint64_t bitGroup = GetBitGroup(pos / kBlockSize);
   return ((bitGroup >> (pos % kBlockSize)) & 1) > 0;
@@ -140,9 +228,7 @@ void DenseCBV::Serialize(Writer & writer) const
 {
   uint8_t header = static_cast<uint8_t>(GetStorageStrategy());
   WriteToSink(writer, header);
-  WriteToSink(writer, static_cast<uint32_t>(NumBitGroups()));
-  for (size_t i = 0; i < NumBitGroups(); ++i)
-    WriteToSink(writer, GetBitGroup(i));
+  rw::WriteVectorOfPOD(writer, m_bitGroups);
 }
 
 SparseCBV::SparseCBV(vector<uint64_t> const & setBits) : m_positions(setBits)
@@ -161,9 +247,9 @@
   return m_positions[i];
 }
 
-uint32_t SparseCBV::PopCount() const { return m_positions.size(); }
+uint64_t SparseCBV::PopCount() const { return m_positions.size(); }
 
-bool SparseCBV::GetBit(uint32_t pos) const
+bool SparseCBV::GetBit(uint64_t pos) const
 {
   auto const it = lower_bound(m_positions.begin(), m_positions.end(), pos);
   return it != m_positions.end() && *it == pos;
@@ -178,39 +264,35 @@ void SparseCBV::Serialize(Writer & writer) const
 {
   uint8_t header = static_cast<uint8_t>(GetStorageStrategy());
   WriteToSink(writer, header);
-  WriteToSink(writer, PopCount());
-  ForEach([&](uint64_t bitPos)
-          {
-            WriteToSink(writer, bitPos);
-          });
+  rw::WriteVectorOfPOD(writer, m_positions);
 }
 
 // static
 unique_ptr<CompressedBitVector> CompressedBitVectorBuilder::FromBitPositions(
     vector<uint64_t> const & setBits)
 {
-  if (setBits.empty())
-    return make_unique<SparseCBV>(setBits);
-  uint64_t maxBit = setBits[0];
-  for (size_t i = 1; i < setBits.size(); ++i)
-    maxBit = max(maxBit, setBits[i]);
+  return BuildFromBitPositions(setBits);
+}
 
-  if (DenseEnough(setBits.size(), maxBit))
-    return make_unique<DenseCBV>(setBits);
-
-  return make_unique<SparseCBV>(setBits);
+// static
+unique_ptr<CompressedBitVector> CompressedBitVectorBuilder::FromBitPositions(
+    vector<uint64_t> && setBits)
+{
+  return BuildFromBitPositions(move(setBits));
 }
 
 // static
 unique_ptr<CompressedBitVector> CompressedBitVectorBuilder::FromBitGroups(
     vector<uint64_t> && bitGroups)
 {
+  static uint64_t const kBlockSize = DenseCBV::kBlockSize;
+
   while (!bitGroups.empty() && bitGroups.back() == 0)
     bitGroups.pop_back();
   if (bitGroups.empty())
    return make_unique<SparseCBV>(bitGroups);
 
-  uint64_t const maxBit = kBlockSize * bitGroups.size() - 1;
+  uint64_t const maxBit = kBlockSize * (bitGroups.size() - 1) + bits::CeilLog(bitGroups.back());
   uint64_t popCount = 0;
   for (size_t i = 0; i < bitGroups.size(); ++i)
     popCount += bits::PopCount(bitGroups[i]);
@@ -245,34 +327,15 @@ string DebugPrint(CompressedBitVector::StorageStrategy strat)
 unique_ptr<CompressedBitVector> CompressedBitVector::Intersect(CompressedBitVector const & lhs,
                                                                CompressedBitVector const & rhs)
 {
-  using strat = CompressedBitVector::StorageStrategy;
-  auto const stratA = lhs.GetStorageStrategy();
-  auto const stratB = rhs.GetStorageStrategy();
-  if (stratA == strat::Dense && stratB == strat::Dense)
-  {
-    DenseCBV const & a = static_cast<DenseCBV const &>(lhs);
-    DenseCBV const & b = static_cast<DenseCBV const &>(rhs);
-    return IntersectImpl(a, b);
-  }
-  if (stratA == strat::Dense && stratB == strat::Sparse)
-  {
-    DenseCBV const & a = static_cast<DenseCBV const &>(lhs);
-    SparseCBV const & b = static_cast<SparseCBV const &>(rhs);
-    return IntersectImpl(a, b);
-  }
-  if (stratA == strat::Sparse && stratB == strat::Dense)
-  {
-    SparseCBV const & a = static_cast<SparseCBV const &>(lhs);
-    DenseCBV const & b = static_cast<DenseCBV const &>(rhs);
-    return IntersectImpl(a, b);
-  }
-  if (stratA == strat::Sparse && stratB == strat::Sparse)
-  {
-    SparseCBV const & a = static_cast<SparseCBV const &>(lhs);
-    SparseCBV const & b = static_cast<SparseCBV const &>(rhs);
-    return IntersectImpl(a, b);
-  }
+  static IntersectOp const intersectOp;
+  return Apply(intersectOp, lhs, rhs);
+}
 
-  return nullptr;
+// static
+unique_ptr<CompressedBitVector> CompressedBitVector::Subtract(CompressedBitVector const & lhs,
+                                                              CompressedBitVector const & rhs)
+{
+  static SubtractOp const subtractOp;
+  return Apply(subtractOp, lhs, rhs);
 }
 }  // namespace coding
diff --git a/coding/compressed_bit_vector.hpp b/coding/compressed_bit_vector.hpp
index 1fca665dd4..d048c86767 100644
--- a/coding/compressed_bit_vector.hpp
+++ b/coding/compressed_bit_vector.hpp
@@ -1,14 +1,10 @@
-#include "std/vector.hpp"
-
-#include "base/assert.hpp"
-
+#include "coding/read_write_utils.hpp"
 #include "coding/reader.hpp"
 #include "coding/writer.hpp"
 
 #include "std/algorithm.hpp"
 #include "std/unique_ptr.hpp"
-
-#include "base/assert.hpp"
+#include "std/vector.hpp"
 
 namespace coding
 {
@@ -38,12 +34,16 @@ public:
   static unique_ptr<CompressedBitVector> Intersect(CompressedBitVector const & lhs,
                                                    CompressedBitVector const & rhs);
 
+  // Subtracts two bit vectors.
+  static unique_ptr<CompressedBitVector> Subtract(CompressedBitVector const & lhs,
+                                                  CompressedBitVector const & rhs);
+
   // Returns the number of set bits (population count).
-  virtual uint32_t PopCount() const = 0;
+  virtual uint64_t PopCount() const = 0;
 
   // todo(@pimenov) How long will 32 bits be enough here?
   // Would operator[] look better?
-  virtual bool GetBit(uint32_t pos) const = 0;
+  virtual bool GetBit(uint64_t pos) const = 0;
 
   // Returns the strategy used when storing this bit vector.
   virtual StorageStrategy GetStorageStrategy() const = 0;
@@ -65,7 +65,7 @@ string DebugPrint(CompressedBitVector::StorageStrategy strat);
 class DenseCBV : public CompressedBitVector
 {
 public:
-  static uint32_t const kBlockSize = 64;
+  static uint64_t const kBlockSize = 64;
 
   DenseCBV() = default;
 
@@ -95,19 +95,21 @@ public:
   uint64_t GetBitGroup(size_t i) const;
 
   // CompressedBitVector overrides:
-  uint32_t PopCount() const override;
-  bool GetBit(uint32_t pos) const override;
+  uint64_t PopCount() const override;
+  bool GetBit(uint64_t pos) const override;
   StorageStrategy GetStorageStrategy() const override;
   void Serialize(Writer & writer) const override;
 
 private:
   vector<uint64_t> m_bitGroups;
-  uint32_t m_popCount = 0;
+  uint64_t m_popCount = 0;
 };
 
 class SparseCBV : public CompressedBitVector
 {
 public:
+  using TIterator = vector<uint64_t>::const_iterator;
+
   SparseCBV(vector<uint64_t> const & setBits);
   SparseCBV(vector<uint64_t> && setBits);
 
@@ -123,11 +125,14 @@ public:
   }
 
   // CompressedBitVector overrides:
-  uint32_t PopCount() const override;
-  bool GetBit(uint32_t pos) const override;
+  uint64_t PopCount() const override;
+  bool GetBit(uint64_t pos) const override;
   StorageStrategy GetStorageStrategy() const override;
   void Serialize(Writer & writer) const override;
 
+  inline TIterator Begin() const { return m_positions.cbegin(); }
+  inline TIterator End() const { return m_positions.cend(); }
+
 private:
   // 0-based positions of the set bits.
   vector<uint64_t> m_positions;
@@ -139,6 +144,7 @@ public:
   // Chooses a strategy to store the bit vector with bits from setBits set to one
   // and returns a pointer to a class that fits best.
   static unique_ptr<CompressedBitVector> FromBitPositions(vector<uint64_t> const & setBits);
+  static unique_ptr<CompressedBitVector> FromBitPositions(vector<uint64_t> && setBits);
 
   // Chooses a strategy to store the bit vector with bits from a bitmap obtained
   // by concatenating the elements of bitGroups.
@@ -158,14 +164,14 @@ public:
   case CompressedBitVector::StorageStrategy::Dense:
   {
     vector<uint64_t> bitGroups;
-    ReadPrimitiveVectorFromSource(src, bitGroups);
+    rw::ReadVectorOfPOD(src, bitGroups);
     return DenseCBV::BuildFromBitGroups(move(bitGroups));
   }
   case CompressedBitVector::StorageStrategy::Sparse:
   {
     vector<uint64_t> setBits;
-    ReadPrimitiveVectorFromSource(src, setBits);
-    return make_unique<SparseCBV>(setBits);
+    rw::ReadVectorOfPOD(src, setBits);
+    return make_unique<SparseCBV>(move(setBits));
   }
   }
   return nullptr;
diff --git a/coding/reader.hpp b/coding/reader.hpp
index 7b61ebd20d..09a1b73ec8 100644
--- a/coding/reader.hpp
+++ b/coding/reader.hpp
@@ -247,13 +247,3 @@ TPrimitive ReadPrimitiveFromSource(TSource & source)
   source.Read(&primitive, sizeof(primitive));
   return SwapIfBigEndian(primitive);
 }
-
-template <typename TPrimitive, typename TSource>
-void ReadPrimitiveVectorFromSource(TSource && source, vector<TPrimitive> & result)
-{
-  // Do not overspecify the size type: uint32_t is enough.
-  size_t size = static_cast<size_t>(ReadPrimitiveFromSource<uint32_t>(source));
-  result.resize(size);
-  for (size_t i = 0; i < size; ++i)
-    result[i] = ReadPrimitiveFromSource<TPrimitive>(source);
-}