diff --git a/coding/bit_streams.cpp b/coding/bit_streams.cpp deleted file mode 100644 index c10547e50a..0000000000 --- a/coding/bit_streams.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include "coding/bit_streams.hpp" - -#include "coding/reader.hpp" -#include "coding/writer.hpp" - -BitSink::BitSink(Writer & writer) - : m_writer(writer), m_lastByte(0), m_size(0) {} - -BitSink::~BitSink() -{ - if (m_size % 8 > 0) m_writer.Write(&m_lastByte, 1); -} - -void BitSink::Write(uint64_t bits, uint32_t writeSize) -{ - if (writeSize == 0) return; - CHECK_LESS_OR_EQUAL(writeSize, 64, ()); - m_totalBits += writeSize; - uint32_t remSize = m_size % 8; - if (writeSize > 64 - remSize) - { - uint64_t writeData = (bits << remSize) | m_lastByte; - m_writer.Write(&writeData, sizeof(writeData)); - m_lastByte = uint8_t(bits >> (64 - remSize)); - m_size += writeSize; - } - else - { - if (remSize > 0) - { - bits <<= remSize; - bits |= m_lastByte; - writeSize += remSize; - m_size -= remSize; - } - uint32_t writeBytesSize = writeSize / 8; - m_writer.Write(&bits, writeBytesSize); - m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1); - m_size += writeSize; - } -} - - -BitSource::BitSource(Reader & reader) - : m_reader(reader), m_serialCur(0), m_serialEnd(reader.Size()), - m_bits(0), m_bitsSize(0), m_totalBitsRead(0) {} - -uint64_t BitSource::Read(uint32_t readSize) -{ - uint32_t requestedReadSize = readSize; - if (readSize == 0) return 0; - CHECK_LESS_OR_EQUAL(readSize, 64, ()); - // First read, sets bits that are in the m_bits buffer. - uint32_t firstReadSize = readSize <= m_bitsSize ? readSize : m_bitsSize; - uint64_t result = m_bits & (~uint64_t(0) >> (64 - firstReadSize)); - m_bits >>= firstReadSize; - m_bitsSize -= firstReadSize; - readSize -= firstReadSize; - // Second read, does an extra read using m_reader. - if (readSize > 0) - { - size_t read_byte_size = m_serialCur + sizeof(m_bits) <= m_serialEnd ? sizeof(m_bits) : m_serialEnd - m_serialCur; - m_reader.Read(m_serialCur, &m_bits, read_byte_size); - m_serialCur += read_byte_size; - m_bitsSize += read_byte_size * 8; - if (readSize > m_bitsSize) CHECK_LESS_OR_EQUAL(readSize, m_bitsSize, ()); - result |= (m_bits & (~uint64_t(0) >> (64 - readSize))) << firstReadSize; - m_bits >>= readSize; - m_bitsSize -= readSize; - readSize = 0; - } - m_totalBitsRead += requestedReadSize; - return result; -} diff --git a/coding/bit_streams.hpp b/coding/bit_streams.hpp index 93d585f696..2751403cbf 100644 --- a/coding/bit_streams.hpp +++ b/coding/bit_streams.hpp @@ -1,42 +1,127 @@ -// Author: Artyom Polkovnikov. -// Bits source and sink for sequential read and write of bits. - #pragma once +#include "std/algorithm.hpp" #include "std/cstdint.hpp" +#include "std/limits.hpp" -// Forward declarations. -class Reader; -class Writer; +#include "base/assert.hpp" +#include "base/logging.hpp" -class BitSink +namespace +{ +uint64_t const kByteMask = (static_cast(1) << CHAR_BIT) - 1; +} // namespace + +template +class BitWriter { public: - BitSink(Writer & writer); - // Note! Last byte is flushed in destructor. - ~BitSink(); - uint64_t NumBitsWritten() const { return m_size; } - // Write writeSize number of bits from least significant side of bits number. - void Write(uint64_t bits, uint32_t writeSize); + BitWriter(TWriter & writer) : m_writer(writer), m_buf(0), m_bitsWritten(0) {} + + ~BitWriter() + { + try + { + Flush(); + } + catch (...) + { + LOG(LWARNING, ("Caught an exception when flushing BitWriter.")); + } + } + + // Writes up to CHAR_BIT-1 last bits if they have not been written yet + // and pads them with zeros. + void Flush() + { + if (m_bitsWritten % CHAR_BIT != 0) + m_writer.Write(&m_buf, 1); + } + + // Returns the number of bits that have been sent to BitWriter, + // including those that are in m_buf and are possibly + // not flushed yet. + uint64_t BitsWritten() const { return m_bitsWritten; } + + // Writes n bits starting with the least significant bit. + // They are written one byte at a time so endianness is of no concern. + // All the other bits except for the first n must be set to zero. + void Write(uint8_t bits, uint32_t n) + { + if (n == 0) + return; + CHECK_LESS_OR_EQUAL(n, CHAR_BIT, ()); + uint32_t bufferedBits = m_bitsWritten % CHAR_BIT; + m_bitsWritten += n; + if (n + bufferedBits > 8) + { + uint8_t b = (bits << bufferedBits) | m_buf; + m_writer.Write(&b, 1); + m_buf = bits >> (8 - bufferedBits); + } + else + { + if (bufferedBits > 0) + { + bits = (bits << bufferedBits) | m_buf; + n += bufferedBits; + } + if (n == CHAR_BIT) + { + m_writer.Write(&bits, 1); + bits = 0; + } + m_buf = bits; + } + } + private: - Writer & m_writer; - uint8_t m_lastByte; - uint64_t m_size; - uint64_t m_totalBits; + TWriter & m_writer; + uint8_t m_buf; + uint64_t m_bitsWritten; }; -class BitSource +template +class BitReader { public: - BitSource(Reader & reader); - uint64_t NumBitsRead() const { return m_totalBitsRead; } - // Read readSize number of bits, return it as least significant bits of 64-bit number. - uint64_t Read(uint32_t readSize); + BitReader(TReader & reader) : m_reader(reader), m_bitsRead(0), m_bufferedBits(0), m_buf(0) {} + + // Returns the total number of bits read from this BitReader. + uint32_t BitsRead() const { return m_bitsRead; } + + // Reads n bits and returns them as the least significant bits of an 8-bit number. + // The underlying m_reader is supposed to be byte-aligned (which is the + // case when it reads from the place that was written with BitWriter) because + // Read may use one lookahead byte. + uint8_t Read(uint32_t n) + { + if (n == 0) + return 0; + CHECK_LESS_OR_EQUAL(n, CHAR_BIT, ()); + m_bitsRead += n; + uint8_t result = 0; + if (n <= m_bufferedBits) + { + result = m_buf & (kByteMask >> (CHAR_BIT - n)); + m_bufferedBits -= n; + m_buf >>= n; + } + else + { + uint8_t nextByte; + m_reader.Read(&nextByte, 1); + uint32_t low = n - m_bufferedBits; + result = ((nextByte & (kByteMask >> (CHAR_BIT - low))) << m_bufferedBits) | m_buf; + m_buf = nextByte >> low; + m_bufferedBits = CHAR_BIT - low; + } + return result; + } + private: - Reader & m_reader; - uint64_t m_serialCur; - uint64_t m_serialEnd; - uint64_t m_bits; - uint32_t m_bitsSize; - uint64_t m_totalBitsRead; + TReader & m_reader; + uint32_t m_bitsRead; + uint32_t m_bufferedBits; + uint8_t m_buf; }; diff --git a/coding/coding.pro b/coding/coding.pro index 8b359ec28c..9c48d33449 100644 --- a/coding/coding.pro +++ b/coding/coding.pro @@ -36,7 +36,6 @@ SOURCES += \ arithmetic_codec.cpp \ compressed_bit_vector.cpp \ # compressed_varnum_vector.cpp \ - bit_streams.cpp \ png_memory_encoder.cpp \ HEADERS += \ diff --git a/coding/coding_tests/bit_streams_test.cpp b/coding/coding_tests/bit_streams_test.cpp index f9ad0d23ed..d64acd398a 100644 --- a/coding/coding_tests/bit_streams_test.cpp +++ b/coding/coding_tests/bit_streams_test.cpp @@ -9,30 +9,36 @@ #include "std/vector.hpp" -UNIT_TEST(BitStream_ReadWrite) +UNIT_TEST(BitStreams_Smoke) { + uniform_int_distribution randomBytesDistribution(0, 255); mt19937 rng(0); - uint32_t const NUMS_CNT = 1000; - vector< pair > nums; - for (uint32_t i = 0; i < NUMS_CNT; ++i) + vector> nums; + for (size_t i = 0; i < 100; ++i) { - uint32_t numBits = rng() % 65; - uint64_t num = rng() & ((uint64_t(1) << numBits) - 1); + uint32_t numBits = randomBytesDistribution(rng) % 8; + uint8_t num = randomBytesDistribution(rng) >> (CHAR_BIT - numBits); nums.push_back(make_pair(num, numBits)); } - - vector encodedBits; + for (size_t i = 0; i < 100; ++i) { - MemWriter< vector > encodedBitsWriter(encodedBits); - BitSink bitsSink(encodedBitsWriter); - for (uint32_t i = 0; i < nums.size(); ++i) bitsSink.Write(nums[i].first, nums[i].second); + uint32_t numBits = 8; + uint8_t num = randomBytesDistribution(rng); + nums.push_back(make_pair(num, numBits)); } + + vector encodedBits; + MemWriter> encodedBitsWriter(encodedBits); + BitWriter>> bitSink(encodedBitsWriter); + for (size_t i = 0; i < nums.size(); ++i) + bitSink.Write(nums[i].first, nums[i].second); + MemReader encodedBitsReader(encodedBits.data(), encodedBits.size()); - BitSource bitsSource(encodedBitsReader); - for (uint32_t i = 0; i < nums.size(); ++i) + ReaderSource reader(encodedBitsReader); + BitReader> bitsSource(reader); + for (size_t i = 0; i < nums.size(); ++i) { - uint64_t num = bitsSource.Read(nums[i].second); - TEST_EQUAL(num, nums[i].first, ()); + uint8_t num = bitsSource.Read(nums[i].second); + TEST_EQUAL(num, nums[i].first, (i)); } } - diff --git a/coding/compressed_bit_vector.cpp b/coding/compressed_bit_vector.cpp index 238809845b..d4dd942ade 100644 --- a/coding/compressed_bit_vector.cpp +++ b/coding/compressed_bit_vector.cpp @@ -187,8 +187,8 @@ void BuildCompressedBitVector(Writer & writer, vector const & posOnes, writer.Write(serialSizesEnc.data(), serialSizesEnc.size()); } { - // Second Stage. Encode all bits of all diffs using BitSink. - BitSink bitWriter(writer); + // Second Stage. Encode all bits of all diffs using BitWriter. + BitWriter bitWriter(writer); int64_t prevOnePos = -1; uint64_t totalReadBits = 0; uint64_t totalReadCnts = 0; @@ -319,8 +319,8 @@ void BuildCompressedBitVector(Writer & writer, vector const & posOnes, } { - // Second stage, encode all ranges bits using BitSink. - BitSink bitWriter(writer); + // Second stage, encode all ranges bits using BitWriter. + BitWriter bitWriter(writer); int64_t prevOnePos = -1; uint64_t onesRangeLen = 0; for (uint32_t i = 0; i < posOnes.size(); ++i) @@ -404,8 +404,11 @@ vector DecodeCompressedBitVector(Reader & reader) { ArithmeticDecoder arithDec(*arithDecReader, distrTable); for (uint64_t i = 0; i < cntElements; ++i) bitsUsedVec.push_back(arithDec.Decode()); decodeOffset += encSizesBytesize; - unique_ptr bitReaderReader(reader.CreateSubReader(decodeOffset, serialSize - decodeOffset)); - BitSource bitReader(*bitReaderReader); + unique_ptr bitMemReader( + reader.CreateSubReader(decodeOffset, serialSize - decodeOffset)); + ReaderPtr readerPtr(bitMemReader.get()); + ReaderSource> bitReaderSource(readerPtr); + BitReader>> bitReader(bitReaderSource); int64_t prevOnePos = -1; for (uint64_t i = 0; i < cntElements; ++i) { @@ -456,8 +459,11 @@ vector DecodeCompressedBitVector(Reader & reader) { vector bitsSizes1; for (uint64_t i = 0; i < cntElements1; ++i) bitsSizes1.push_back(arith_dec1.Decode()); decodeOffset += enc1SizesBytesize; - unique_ptr bitReaderReader(reader.CreateSubReader(decodeOffset, serialSize - decodeOffset)); - BitSource bitReader(*bitReaderReader); + unique_ptr bitMemReader( + reader.CreateSubReader(decodeOffset, serialSize - decodeOffset)); + ReaderPtr readerPtr(bitMemReader.get()); + ReaderSource> bitReaderSource(readerPtr); + BitReader>> bitReader(bitReaderSource); uint64_t sum = 0, i0 = 0, i1 = 0; while (i0 < cntElements0 && i1 < cntElements1) {