[omim] [coding] BitReader and BitWriter.
This commit is contained in:
parent
c82b3e927f
commit
4b670d0671
5 changed files with 149 additions and 127 deletions
|
@ -1,74 +0,0 @@
|
|||
#include "coding/bit_streams.hpp"
|
||||
|
||||
#include "coding/reader.hpp"
|
||||
#include "coding/writer.hpp"
|
||||
|
||||
// Creates a bit sink that appends to |writer|.
// All counters start at zero. m_totalBits must be initialized here as well,
// because Write() increments it and would otherwise read an indeterminate value.
BitSink::BitSink(Writer & writer)
  : m_writer(writer), m_lastByte(0), m_size(0), m_totalBits(0) {}
|
||||
|
||||
// Emits the trailing partial byte, if the number of written bits is not a
// multiple of 8.
BitSink::~BitSink()
{
  bool const hasPartialByte = (m_size % 8) > 0;
  if (hasPartialByte)
    m_writer.Write(&m_lastByte, 1);
}
|
||||
|
||||
void BitSink::Write(uint64_t bits, uint32_t writeSize)
|
||||
{
|
||||
if (writeSize == 0) return;
|
||||
CHECK_LESS_OR_EQUAL(writeSize, 64, ());
|
||||
m_totalBits += writeSize;
|
||||
uint32_t remSize = m_size % 8;
|
||||
if (writeSize > 64 - remSize)
|
||||
{
|
||||
uint64_t writeData = (bits << remSize) | m_lastByte;
|
||||
m_writer.Write(&writeData, sizeof(writeData));
|
||||
m_lastByte = uint8_t(bits >> (64 - remSize));
|
||||
m_size += writeSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (remSize > 0)
|
||||
{
|
||||
bits <<= remSize;
|
||||
bits |= m_lastByte;
|
||||
writeSize += remSize;
|
||||
m_size -= remSize;
|
||||
}
|
||||
uint32_t writeBytesSize = writeSize / 8;
|
||||
m_writer.Write(&bits, writeBytesSize);
|
||||
m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1);
|
||||
m_size += writeSize;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Sets up sequential bit reading over |reader|: the byte cursor starts at 0,
// the end position is reader.Size(), and the bit buffer and counters are empty.
BitSource::BitSource(Reader & reader)
  : m_reader(reader)
  , m_serialCur(0)
  , m_serialEnd(reader.Size())
  , m_bits(0)
  , m_bitsSize(0)
  , m_totalBitsRead(0)
{
}
|
||||
|
||||
uint64_t BitSource::Read(uint32_t readSize)
|
||||
{
|
||||
uint32_t requestedReadSize = readSize;
|
||||
if (readSize == 0) return 0;
|
||||
CHECK_LESS_OR_EQUAL(readSize, 64, ());
|
||||
// First read, sets bits that are in the m_bits buffer.
|
||||
uint32_t firstReadSize = readSize <= m_bitsSize ? readSize : m_bitsSize;
|
||||
uint64_t result = m_bits & (~uint64_t(0) >> (64 - firstReadSize));
|
||||
m_bits >>= firstReadSize;
|
||||
m_bitsSize -= firstReadSize;
|
||||
readSize -= firstReadSize;
|
||||
// Second read, does an extra read using m_reader.
|
||||
if (readSize > 0)
|
||||
{
|
||||
size_t read_byte_size = m_serialCur + sizeof(m_bits) <= m_serialEnd ? sizeof(m_bits) : m_serialEnd - m_serialCur;
|
||||
m_reader.Read(m_serialCur, &m_bits, read_byte_size);
|
||||
m_serialCur += read_byte_size;
|
||||
m_bitsSize += read_byte_size * 8;
|
||||
if (readSize > m_bitsSize) CHECK_LESS_OR_EQUAL(readSize, m_bitsSize, ());
|
||||
result |= (m_bits & (~uint64_t(0) >> (64 - readSize))) << firstReadSize;
|
||||
m_bits >>= readSize;
|
||||
m_bitsSize -= readSize;
|
||||
readSize = 0;
|
||||
}
|
||||
m_totalBitsRead += requestedReadSize;
|
||||
return result;
|
||||
}
|
|
@ -1,42 +1,127 @@
|
|||
// Author: Artyom Polkovnikov.
|
||||
// Bits source and sink for sequential read and write of bits.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "std/algorithm.hpp"
|
||||
#include "std/cstdint.hpp"
|
||||
#include "std/limits.hpp"
|
||||
|
||||
// Forward declarations.
|
||||
class Reader;
|
||||
class Writer;
|
||||
#include "base/assert.hpp"
|
||||
#include "base/logging.hpp"
|
||||
|
||||
class BitSink
|
||||
namespace
{
// Mask with the CHAR_BIT lowest bits set, i.e. one whole byte
// (0xFF when a byte has 8 bits).
uint64_t const kByteMask = (uint64_t(1) << CHAR_BIT) - 1;
}  // namespace
|
||||
|
||||
template <typename TWriter>
|
||||
class BitWriter
|
||||
{
|
||||
public:
|
||||
BitSink(Writer & writer);
|
||||
// Note! Last byte is flushed in destructor.
|
||||
~BitSink();
|
||||
uint64_t NumBitsWritten() const { return m_size; }
|
||||
// Write writeSize number of bits from least significant side of bits number.
|
||||
void Write(uint64_t bits, uint32_t writeSize);
|
||||
// Binds the bit writer to |writer|; the bit buffer is empty and no bits have
// been counted yet.
BitWriter(TWriter & writer)
  : m_writer(writer)
  , m_buf(0)
  , m_bitsWritten(0)
{
}
|
||||
|
||||
// Flushes the pending partial byte on destruction. Exceptions thrown while
// flushing are caught and only logged, so they never leave the destructor.
~BitWriter()
{
  try { Flush(); }
  catch (...) { LOG(LWARNING, ("Caught an exception when flushing BitWriter.")); }
}
|
||||
|
||||
// Writes up to CHAR_BIT-1 last bits if they have not been written yet
|
||||
// and pads them with zeros.
|
||||
void Flush()
|
||||
{
|
||||
if (m_bitsWritten % CHAR_BIT != 0)
|
||||
m_writer.Write(&m_buf, 1);
|
||||
}
|
||||
|
||||
// Returns the number of bits that have been sent to BitWriter,
|
||||
// including those that are in m_buf and are possibly
|
||||
// not flushed yet.
|
||||
uint64_t BitsWritten() const { return m_bitsWritten; }
|
||||
|
||||
// Writes n bits starting with the least significant bit.
|
||||
// They are written one byte at a time so endianness is of no concern.
|
||||
// All the other bits except for the first n must be set to zero.
|
||||
void Write(uint8_t bits, uint32_t n)
|
||||
{
|
||||
if (n == 0)
|
||||
return;
|
||||
CHECK_LESS_OR_EQUAL(n, CHAR_BIT, ());
|
||||
uint32_t bufferedBits = m_bitsWritten % CHAR_BIT;
|
||||
m_bitsWritten += n;
|
||||
if (n + bufferedBits > 8)
|
||||
{
|
||||
uint8_t b = (bits << bufferedBits) | m_buf;
|
||||
m_writer.Write(&b, 1);
|
||||
m_buf = bits >> (8 - bufferedBits);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (bufferedBits > 0)
|
||||
{
|
||||
bits = (bits << bufferedBits) | m_buf;
|
||||
n += bufferedBits;
|
||||
}
|
||||
if (n == CHAR_BIT)
|
||||
{
|
||||
m_writer.Write(&bits, 1);
|
||||
bits = 0;
|
||||
}
|
||||
m_buf = bits;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Writer & m_writer;
|
||||
uint8_t m_lastByte;
|
||||
uint64_t m_size;
|
||||
uint64_t m_totalBits;
|
||||
TWriter & m_writer;
|
||||
uint8_t m_buf;
|
||||
uint64_t m_bitsWritten;
|
||||
};
|
||||
|
||||
class BitSource
|
||||
template <typename TReader>
|
||||
class BitReader
|
||||
{
|
||||
public:
|
||||
BitSource(Reader & reader);
|
||||
uint64_t NumBitsRead() const { return m_totalBitsRead; }
|
||||
// Read readSize number of bits, return it as least significant bits of 64-bit number.
|
||||
uint64_t Read(uint32_t readSize);
|
||||
// Binds the bit reader to |reader|; nothing is buffered and no bits have been
// counted yet.
BitReader(TReader & reader)
  : m_reader(reader)
  , m_bitsRead(0)
  , m_bufferedBits(0)
  , m_buf(0)
{
}
|
||||
|
||||
// Returns the total number of bits read from this BitReader.
|
||||
uint32_t BitsRead() const { return m_bitsRead; }
|
||||
|
||||
// Reads n bits and returns them as the least significant bits of an 8-bit number.
|
||||
// The underlying m_reader is supposed to be byte-aligned (which is the
|
||||
// case when it reads from the place that was written with BitWriter) because
|
||||
// Read may use one lookahead byte.
|
||||
uint8_t Read(uint32_t n)
|
||||
{
|
||||
if (n == 0)
|
||||
return 0;
|
||||
CHECK_LESS_OR_EQUAL(n, CHAR_BIT, ());
|
||||
m_bitsRead += n;
|
||||
uint8_t result = 0;
|
||||
if (n <= m_bufferedBits)
|
||||
{
|
||||
result = m_buf & (kByteMask >> (CHAR_BIT - n));
|
||||
m_bufferedBits -= n;
|
||||
m_buf >>= n;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t nextByte;
|
||||
m_reader.Read(&nextByte, 1);
|
||||
uint32_t low = n - m_bufferedBits;
|
||||
result = ((nextByte & (kByteMask >> (CHAR_BIT - low))) << m_bufferedBits) | m_buf;
|
||||
m_buf = nextByte >> low;
|
||||
m_bufferedBits = CHAR_BIT - low;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
Reader & m_reader;
|
||||
uint64_t m_serialCur;
|
||||
uint64_t m_serialEnd;
|
||||
uint64_t m_bits;
|
||||
uint32_t m_bitsSize;
|
||||
uint64_t m_totalBitsRead;
|
||||
TReader & m_reader;
|
||||
uint32_t m_bitsRead;
|
||||
uint32_t m_bufferedBits;
|
||||
uint8_t m_buf;
|
||||
};
|
||||
|
|
|
@ -36,7 +36,6 @@ SOURCES += \
|
|||
arithmetic_codec.cpp \
|
||||
compressed_bit_vector.cpp \
|
||||
# compressed_varnum_vector.cpp \
|
||||
bit_streams.cpp \
|
||||
png_memory_encoder.cpp \
|
||||
|
||||
HEADERS += \
|
||||
|
|
|
@ -9,30 +9,36 @@
|
|||
#include "std/vector.hpp"
|
||||
|
||||
|
||||
UNIT_TEST(BitStream_ReadWrite)
|
||||
UNIT_TEST(BitStreams_Smoke)
|
||||
{
|
||||
uniform_int_distribution<uint8_t> randomBytesDistribution(0, 255);
|
||||
mt19937 rng(0);
|
||||
uint32_t const NUMS_CNT = 1000;
|
||||
vector< pair<uint64_t, uint32_t> > nums;
|
||||
for (uint32_t i = 0; i < NUMS_CNT; ++i)
|
||||
vector<pair<uint8_t, uint32_t>> nums;
|
||||
for (size_t i = 0; i < 100; ++i)
|
||||
{
|
||||
uint32_t numBits = rng() % 65;
|
||||
uint64_t num = rng() & ((uint64_t(1) << numBits) - 1);
|
||||
uint32_t numBits = randomBytesDistribution(rng) % 8;
|
||||
uint8_t num = randomBytesDistribution(rng) >> (CHAR_BIT - numBits);
|
||||
nums.push_back(make_pair(num, numBits));
|
||||
}
|
||||
|
||||
vector<uint8_t> encodedBits;
|
||||
for (size_t i = 0; i < 100; ++i)
|
||||
{
|
||||
MemWriter< vector<uint8_t> > encodedBitsWriter(encodedBits);
|
||||
BitSink bitsSink(encodedBitsWriter);
|
||||
for (uint32_t i = 0; i < nums.size(); ++i) bitsSink.Write(nums[i].first, nums[i].second);
|
||||
uint32_t numBits = 8;
|
||||
uint8_t num = randomBytesDistribution(rng);
|
||||
nums.push_back(make_pair(num, numBits));
|
||||
}
|
||||
|
||||
vector<uint8_t> encodedBits;
|
||||
MemWriter<vector<uint8_t>> encodedBitsWriter(encodedBits);
|
||||
BitWriter<MemWriter<vector<uint8_t>>> bitSink(encodedBitsWriter);
|
||||
for (size_t i = 0; i < nums.size(); ++i)
|
||||
bitSink.Write(nums[i].first, nums[i].second);
|
||||
|
||||
MemReader encodedBitsReader(encodedBits.data(), encodedBits.size());
|
||||
BitSource bitsSource(encodedBitsReader);
|
||||
for (uint32_t i = 0; i < nums.size(); ++i)
|
||||
ReaderSource<MemReader> reader(encodedBitsReader);
|
||||
BitReader<ReaderSource<MemReader>> bitsSource(reader);
|
||||
for (size_t i = 0; i < nums.size(); ++i)
|
||||
{
|
||||
uint64_t num = bitsSource.Read(nums[i].second);
|
||||
TEST_EQUAL(num, nums[i].first, ());
|
||||
uint8_t num = bitsSource.Read(nums[i].second);
|
||||
TEST_EQUAL(num, nums[i].first, (i));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -187,8 +187,8 @@ void BuildCompressedBitVector(Writer & writer, vector<uint32_t> const & posOnes,
|
|||
writer.Write(serialSizesEnc.data(), serialSizesEnc.size());
|
||||
}
|
||||
{
|
||||
// Second Stage. Encode all bits of all diffs using BitSink.
|
||||
BitSink bitWriter(writer);
|
||||
// Second Stage. Encode all bits of all diffs using BitWriter.
|
||||
BitWriter<Writer> bitWriter(writer);
|
||||
int64_t prevOnePos = -1;
|
||||
uint64_t totalReadBits = 0;
|
||||
uint64_t totalReadCnts = 0;
|
||||
|
@ -319,8 +319,8 @@ void BuildCompressedBitVector(Writer & writer, vector<uint32_t> const & posOnes,
|
|||
}
|
||||
|
||||
{
|
||||
// Second stage, encode all ranges bits using BitSink.
|
||||
BitSink bitWriter(writer);
|
||||
// Second stage, encode all ranges bits using BitWriter.
|
||||
BitWriter<Writer> bitWriter(writer);
|
||||
int64_t prevOnePos = -1;
|
||||
uint64_t onesRangeLen = 0;
|
||||
for (uint32_t i = 0; i < posOnes.size(); ++i)
|
||||
|
@ -404,8 +404,11 @@ vector<uint32_t> DecodeCompressedBitVector(Reader & reader) {
|
|||
ArithmeticDecoder arithDec(*arithDecReader, distrTable);
|
||||
for (uint64_t i = 0; i < cntElements; ++i) bitsUsedVec.push_back(arithDec.Decode());
|
||||
decodeOffset += encSizesBytesize;
|
||||
unique_ptr<Reader> bitReaderReader(reader.CreateSubReader(decodeOffset, serialSize - decodeOffset));
|
||||
BitSource bitReader(*bitReaderReader);
|
||||
unique_ptr<Reader> bitMemReader(
|
||||
reader.CreateSubReader(decodeOffset, serialSize - decodeOffset));
|
||||
ReaderPtr<Reader> readerPtr(bitMemReader.get());
|
||||
ReaderSource<ReaderPtr<Reader>> bitReaderSource(readerPtr);
|
||||
BitReader<ReaderSource<ReaderPtr<Reader>>> bitReader(bitReaderSource);
|
||||
int64_t prevOnePos = -1;
|
||||
for (uint64_t i = 0; i < cntElements; ++i)
|
||||
{
|
||||
|
@ -456,8 +459,11 @@ vector<uint32_t> DecodeCompressedBitVector(Reader & reader) {
|
|||
vector<uint32_t> bitsSizes1;
|
||||
for (uint64_t i = 0; i < cntElements1; ++i) bitsSizes1.push_back(arith_dec1.Decode());
|
||||
decodeOffset += enc1SizesBytesize;
|
||||
unique_ptr<Reader> bitReaderReader(reader.CreateSubReader(decodeOffset, serialSize - decodeOffset));
|
||||
BitSource bitReader(*bitReaderReader);
|
||||
unique_ptr<Reader> bitMemReader(
|
||||
reader.CreateSubReader(decodeOffset, serialSize - decodeOffset));
|
||||
ReaderPtr<Reader> readerPtr(bitMemReader.get());
|
||||
ReaderSource<ReaderPtr<Reader>> bitReaderSource(readerPtr);
|
||||
BitReader<ReaderSource<ReaderPtr<Reader>>> bitReader(bitReaderSource);
|
||||
uint64_t sum = 0, i0 = 0, i1 = 0;
|
||||
while (i0 < cntElements0 && i1 < cntElements1)
|
||||
{
|
||||
|
|
Reference in a new issue