diff --git a/coding/bit_streams.cpp b/coding/bit_streams.cpp
new file mode 100644
index 0000000000..32a05b1a25
--- /dev/null
+++ b/coding/bit_streams.cpp
@@ -0,0 +1,89 @@
+#include "bit_streams.hpp"
+
+#include "reader.hpp"
+#include "writer.hpp"
+
+// Creates a sink that appends bits to |writer|.
+// m_totalBits is zero-initialized too, because Write() accumulates into it.
+BitSink::BitSink(Writer & writer)
+  : m_writer(writer), m_lastByte(0), m_size(0), m_totalBits(0) {}
+
+// Flushes the trailing partial byte, if any; its unused high bits are zero.
+BitSink::~BitSink()
+{
+  if (m_size % 8 > 0) m_writer.Write(&m_lastByte, 1);
+}
+
+// Appends the writeSize least significant bits of |bits| to the stream.
+// Whole bytes are passed to m_writer right away; the leftover (< 8) bits wait in
+// m_lastByte for the next Write() or for the destructor.
+// writeSize plus the number of pending bits must not exceed 64.
+// NOTE(review): the leading bytes of |bits| are written as they lie in memory, so
+// least-significant-byte-first output presumably relies on a little-endian host.
+void BitSink::Write(uint64_t bits, uint32_t writeSize)
+{
+  if (writeSize == 0) return;
+  m_totalBits += writeSize;
+  uint32_t remSize = m_size % 8;
+  CHECK_LESS_OR_EQUAL(writeSize, 64 - remSize, ());
+  if (remSize > 0)
+  {
+    // Put the pending bits below the new ones so that they are emitted first.
+    bits <<= remSize;
+    bits |= m_lastByte;
+    writeSize += remSize;
+    m_size -= remSize;
+  }
+  uint32_t writeBytesSize = writeSize / 8;
+  m_writer.Write(&bits, writeBytesSize);
+  // When all 8 bytes are written nothing is left over; shifting a 64-bit value
+  // by 64 is undefined behavior, so that case must not reach the shift.
+  m_lastByte = writeBytesSize < 8
+      ? (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1)
+      : 0;
+  m_size += writeSize;
+}
+
+
+// Creates a source that reads bits from |reader|, starting at offset 0.
+BitSource::BitSource(Reader & reader)
+  : m_reader(reader), m_serialCur(0), m_serialEnd(reader.Size()),
+    m_bits(0), m_bitsSize(0), m_totalBitsRead(0) {}
+
+// Returns the next readSize (at most 64) bits of the stream in the least
+// significant bits of the result.
+uint64_t BitSource::Read(uint32_t readSize)
+{
+  uint32_t requestedReadSize = readSize;
+  if (readSize == 0) return 0;
+  CHECK_LESS_OR_EQUAL(readSize, 64, ());
+  // First read, sets bits that are in the m_bits buffer.
+  // A shift count of 64 is undefined for a 64-bit operand, so the empty-buffer
+  // (firstReadSize == 0) and whole-buffer (firstReadSize == 64) cases are explicit.
+  uint32_t firstReadSize = readSize <= m_bitsSize ? readSize : m_bitsSize;
+  uint64_t result = firstReadSize > 0 ? (m_bits & (~uint64_t(0) >> (64 - firstReadSize))) : 0;
+  m_bits = firstReadSize < 64 ? (m_bits >> firstReadSize) : 0;
+  m_bitsSize -= firstReadSize;
+  readSize -= firstReadSize;
+  // Second read, does an extra read using m_reader.
+  if (readSize > 0)
+  {
+    // The buffer is empty here; refill it with up to sizeof(m_bits) bytes.
+    uint32_t readByteSize = m_serialCur + sizeof(m_bits) <= m_serialEnd
+        ? static_cast<uint32_t>(sizeof(m_bits))
+        : static_cast<uint32_t>(m_serialEnd - m_serialCur);
+    // A short read at the end of the stream fills only part of m_bits, so clear it first.
+    m_bits = 0;
+    m_reader.Read(m_serialCur, &m_bits, readByteSize);
+    m_serialCur += readByteSize;
+    m_bitsSize += readByteSize * 8;
+    CHECK_LESS_OR_EQUAL(readSize, m_bitsSize, ());
+    // Here 1 <= readSize <= 64 and firstReadSize < 64, so both shifts below are defined.
+    result |= (m_bits & (~uint64_t(0) >> (64 - readSize))) << firstReadSize;
+    m_bits = readSize < 64 ? (m_bits >> readSize) : 0;
+    m_bitsSize -= readSize;
+    readSize = 0;
+  }
+  m_totalBitsRead += requestedReadSize;
+  return result;
+}
diff --git a/coding/bit_streams.hpp b/coding/bit_streams.hpp
new file mode 100644
index 0000000000..9ba330d3d9
--- /dev/null
+++ b/coding/bit_streams.hpp
@@ -0,0 +1,42 @@
+// Author: Artyom Polkovnikov.
+// Bits source and sink for sequential read and write of bits.
+
+#pragma once
+
+#include "../std/stdint.hpp"
+
+// Forward declarations.
+class Reader;
+class Writer;
+
+class BitSink
+{
+public:
+  BitSink(Writer & writer);
+  // Note! Last byte is flushed in destructor.
+  ~BitSink();
+  uint64_t NumBitsWritten() const { return m_size; }
+  // Write writeSize number of bits from least significant side of bits number.
+  void Write(uint64_t bits, uint32_t writeSize);
+private:
+  Writer & m_writer;
+  uint8_t m_lastByte;
+  uint64_t m_size;
+  uint64_t m_totalBits;
+};
+
+class BitSource
+{
+public:
+  BitSource(Reader & reader);
+  uint64_t NumBitsRead() const { return m_totalBitsRead; }
+  // Read readSize number of bits, return it as least significant bits of 64-bit number.
+ uint64_t Read(uint32_t readSize); +private: + Reader & m_reader; + uint64_t m_serialCur; + uint64_t m_serialEnd; + uint64_t m_bits; + uint32_t m_bitsSize; + uint64_t m_totalBitsRead; +}; diff --git a/coding/coding.pro b/coding/coding.pro index 56a81f24cf..141e512226 100644 --- a/coding/coding.pro +++ b/coding/coding.pro @@ -36,6 +36,7 @@ SOURCES += \ arithmetic_codec.cpp \ compressed_bit_vector.cpp \ compressed_varnum_vector.cpp \ + bit_streams.cpp \ HEADERS += \ internal/xmlparser.hpp \ @@ -101,3 +102,5 @@ HEADERS += \ arithmetic_codec.hpp \ compressed_bit_vector.hpp \ compressed_varnum_vector.hpp \ + varint_misc.hpp \ + bit_streams.hpp \ diff --git a/coding/compressed_bit_vector.cpp b/coding/compressed_bit_vector.cpp index f63720b542..6573f573ed 100644 --- a/coding/compressed_bit_vector.cpp +++ b/coding/compressed_bit_vector.cpp @@ -1,6 +1,7 @@ #include "compressed_bit_vector.hpp" #include "arithmetic_codec.hpp" +#include "bit_streams.hpp" #include "reader.hpp" #include "writer.hpp" #include "varint_misc.hpp" @@ -17,79 +18,6 @@ namespace { } } -class BitWriter -{ -public: - BitWriter(Writer & writer) - : m_writer(writer), m_lastByte(0), m_size(0) {} - ~BitWriter() { if (m_size % 8 > 0) m_writer.Write(&m_lastByte, 1); } - uint64_t NumBitsWritten() const { return m_size; } - void Write(uint64_t bits, uint32_t writeSize) - { - if (writeSize == 0) return; - m_totalBits += writeSize; - uint32_t remSize = m_size % 8; - CHECK_LESS_OR_EQUAL(writeSize, 64 - remSize, ()); - if (remSize > 0) - { - bits <<= remSize; - bits |= m_lastByte; - writeSize += remSize; - m_size -= remSize; - } - uint32_t writeBytesSize = writeSize / 8; - m_writer.Write(&bits, writeBytesSize); - m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1); - m_size += writeSize; - } -private: - Writer & m_writer; - uint8_t m_lastByte; - uint64_t m_size; - uint64_t m_totalBits; -}; - -class BitReader { -public: - BitReader(Reader & reader) - : m_reader(reader), m_serialCur(0), 
m_serialEnd(reader.Size()), - m_bits(0), m_bitsSize(0), m_totalBitsRead(0) {} - uint64_t NumBitsRead() const { return m_totalBitsRead; } - uint64_t Read(uint32_t readSize) - { - m_totalBitsRead += readSize; - if (readSize == 0) return 0; - CHECK_LESS_OR_EQUAL(readSize, 64, ()); - // First read, sets bits that are in the m_bits buffer. - uint32_t firstReadSize = readSize <= m_bitsSize ? readSize : m_bitsSize; - uint64_t result = m_bits & (~uint64_t(0) >> (64 - firstReadSize)); - m_bits >>= firstReadSize; - m_bitsSize -= firstReadSize; - readSize -= firstReadSize; - // Second read, does an extra read using m_reader. - if (readSize > 0) - { - uint32_t readByteSize = m_serialCur + sizeof(m_bits) <= m_serialEnd ? sizeof(m_bits) : m_serialEnd - m_serialCur; - m_reader.Read(m_serialCur, &m_bits, readByteSize); - m_serialCur += readByteSize; - m_bitsSize += readByteSize * 8; - if (readSize > m_bitsSize) CHECK_LESS_OR_EQUAL(readSize, m_bitsSize, ()); - result |= (m_bits & (~uint64_t(0) >> (64 - readSize))) << firstReadSize; - m_bits >>= readSize; - m_bitsSize -= readSize; - readSize = 0; - } - return result; - } -private: - Reader & m_reader; - uint64_t m_serialCur; - uint64_t m_serialEnd; - uint64_t m_bits; - uint32_t m_bitsSize; - uint64_t m_totalBitsRead; -}; - void BuildCompressedBitVector(Writer & writer, vector const & posOnes, int chosenEncType) { uint32_t const BLOCK_SIZE = 7; @@ -256,8 +184,8 @@ void BuildCompressedBitVector(Writer & writer, vector const & posOnes, writer.Write(serialSizesEnc.data(), serialSizesEnc.size()); } { - // Second Stage. Encode all bits of all diffs using BitWriter. - BitWriter bitWriter(writer); + // Second Stage. Encode all bits of all diffs using BitSink. + BitSink bitWriter(writer); int64_t prevOnePos = -1; uint64_t totalReadBits = 0; uint64_t totalReadCnts = 0; @@ -388,8 +316,8 @@ void BuildCompressedBitVector(Writer & writer, vector const & posOnes, } { - // Second stage, encode all ranges bits using BitWriter. 
- BitWriter bitWriter(writer); + // Second stage, encode all ranges bits using BitSink. + BitSink bitWriter(writer); int64_t prevOnePos = -1; uint64_t onesRangeLen = 0; for (uint32_t i = 0; i < posOnes.size(); ++i) @@ -474,7 +402,7 @@ vector DecodeCompressedBitVector(Reader & reader) { for (uint64_t i = 0; i < cntElements; ++i) bitsUsedVec.push_back(arithDec.Decode()); decodeOffset += encSizesBytesize; Reader * bitReaderReader = reader.CreateSubReader(decodeOffset, serialSize - decodeOffset); - BitReader bitReader(*bitReaderReader); + BitSource bitReader(*bitReaderReader); int64_t prevOnePos = -1; for (uint64_t i = 0; i < cntElements; ++i) { @@ -526,7 +454,7 @@ vector DecodeCompressedBitVector(Reader & reader) { for (uint64_t i = 0; i < cntElements1; ++i) bitsSizes1.push_back(arith_dec1.Decode()); decodeOffset += enc1SizesBytesize; Reader * bitReaderReader = reader.CreateSubReader(decodeOffset, serialSize - decodeOffset); - BitReader bitReader(*bitReaderReader); + BitSource bitReader(*bitReaderReader); uint64_t sum = 0, i0 = 0, i1 = 0; while (i0 < cntElements0 && i1 < cntElements1) { diff --git a/coding/compressed_varnum_vector.cpp b/coding/compressed_varnum_vector.cpp index 41acd61f4f..ac5defe269 100644 --- a/coding/compressed_varnum_vector.cpp +++ b/coding/compressed_varnum_vector.cpp @@ -1,4 +1,5 @@ #include "arithmetic_codec.hpp" +#include "bit_streams.hpp" #include "compressed_varnum_vector.hpp" #include "reader.hpp" #include "writer.hpp" @@ -18,80 +19,6 @@ namespace { } } -class BitWriter -{ -public: - BitWriter(Writer & writer) - : m_writer(writer), m_lastByte(0), m_size(0) {} - ~BitWriter() { if (m_size % 8 > 0) m_writer.Write(&m_lastByte, 1); } - u64 NumBitsWritten() const { return m_size; } - void Write(u64 bits, u32 writeSize) - { - if (writeSize == 0) return; - m_totalBits += writeSize; - u32 remSize = m_size % 8; - CHECK_LESS_OR_EQUAL(writeSize, 64 - remSize, ()); - if (remSize > 0) - { - bits <<= remSize; - bits |= m_lastByte; - writeSize += remSize; 
- m_size -= remSize; - } - u32 writeBytesSize = writeSize / 8; - m_writer.Write(&bits, writeBytesSize); - m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1); - m_size += writeSize; - } -private: - Writer & m_writer; - u8 m_lastByte; - u64 m_size; - u64 m_totalBits; -}; - -class BitReader -{ -public: - BitReader(Reader & reader) - : m_reader(reader), m_serialCur(0), m_serialEnd(reader.Size()), - m_bits(0), m_bitsSize(0), m_totalBitsRead(0) {} - u64 NumBitsRead() const { return m_totalBitsRead; } - u64 Read(u32 readSize) - { - m_totalBitsRead += readSize; - if (readSize == 0) return 0; - CHECK_LESS_OR_EQUAL(readSize, 64, ()); - // First read, sets bits that are in the m_bits buffer. - u32 firstReadSize = readSize <= m_bitsSize ? readSize : m_bitsSize; - u64 result = m_bits & (~u64(0) >> (64 - firstReadSize)); - m_bits >>= firstReadSize; - m_bitsSize -= firstReadSize; - readSize -= firstReadSize; - // Second read, does an extra read using m_reader. - if (readSize > 0) - { - u32 read_byte_size = m_serialCur + sizeof(m_bits) <= m_serialEnd ? sizeof(m_bits) : m_serialEnd - m_serialCur; - m_reader.Read(m_serialCur, &m_bits, read_byte_size); - m_serialCur += read_byte_size; - m_bitsSize += read_byte_size * 8; - if (readSize > m_bitsSize) CHECK_LESS_OR_EQUAL(readSize, m_bitsSize, ()); - result |= (m_bits & (~u64(0) >> (64 - readSize))) << firstReadSize; - m_bits >>= readSize; - m_bitsSize -= readSize; - readSize = 0; - } - return result; - } -private: - Reader & m_reader; - u64 m_serialCur; - u64 m_serialEnd; - u64 m_bits; - u32 m_bitsSize; - u64 m_totalBitsRead; -}; - void BuildCompressedVarnumVector(Writer & writer, NumsSourceFuncT numsSource, u64 numsCnt, bool supportSums) { // Encode header. 
@@ -130,7 +57,7 @@ void BuildCompressedVarnumVector(Writer & writer, NumsSourceFuncT numsSource, u6 ArithmeticEncoder arithEncSizes(distr_table); { MemWriter< vector > encoded_bits_writer(encodedBits); - BitWriter bitsWriter(encoded_bits_writer); + BitSink bitsWriter(encoded_bits_writer); for (u64 ichunkNum = 0; ichunkNum < NUM_ELEM_PER_TABLE_ENTRY && inum < numsCnt; ++ichunkNum, ++inum) { u64 num = numsSource(inum); @@ -162,7 +89,7 @@ struct CompressedVarnumVectorReader::DecodeContext unique_ptr m_sizesArithDecReader; unique_ptr m_sizesArithDec; unique_ptr m_numsBitsReaderReader; - unique_ptr m_numsBitsReader; + unique_ptr m_numsBitsReader; u64 m_numsLeftInChunk; }; @@ -221,7 +148,7 @@ void CompressedVarnumVectorReader::SetDecodeContext(u64 tableEntryIndex) m_decodeCtx->m_sizesArithDecReader.reset(m_reader.CreateSubReader(decodeOffset, encodedSizesSize)); m_decodeCtx->m_sizesArithDec.reset(new ArithmeticDecoder(*m_decodeCtx->m_sizesArithDecReader, m_distrTable)); m_decodeCtx->m_numsBitsReaderReader.reset(m_reader.CreateSubReader(decodeOffset + encodedSizesSize, m_numsEncodedOffset + m_tablePos[tableEntryIndex + 1] - decodeOffset - encodedSizesSize)); - m_decodeCtx->m_numsBitsReader.reset(new BitReader(*m_decodeCtx->m_numsBitsReaderReader)); + m_decodeCtx->m_numsBitsReader.reset(new BitSource(*m_decodeCtx->m_numsBitsReaderReader)); m_decodeCtx->m_numsLeftInChunk = min((tableEntryIndex + 1) * m_numElemPerTableEntry, m_numsCnt) - tableEntryIndex * m_numElemPerTableEntry; }