diff --git a/coding/bit_streams.cpp b/coding/bit_streams.cpp index 32a05b1a25..deb23a9342 100644 --- a/coding/bit_streams.cpp +++ b/coding/bit_streams.cpp @@ -14,20 +14,30 @@ BitSink::~BitSink() void BitSink::Write(uint64_t bits, uint32_t writeSize) { if (writeSize == 0) return; + CHECK_LESS_OR_EQUAL(writeSize, 64, ()); m_totalBits += writeSize; uint32_t remSize = m_size % 8; - CHECK_LESS_OR_EQUAL(writeSize, 64 - remSize, ()); - if (remSize > 0) + if (writeSize > 64 - remSize) { - bits <<= remSize; - bits |= m_lastByte; - writeSize += remSize; - m_size -= remSize; + uint64_t writeData = (bits << remSize) | m_lastByte; + m_writer.Write(&writeData, sizeof(writeData)); + m_lastByte = uint8_t(bits >> (64 - remSize)); + m_size += writeSize; + } + else + { + if (remSize > 0) + { + bits <<= remSize; + bits |= m_lastByte; + writeSize += remSize; + m_size -= remSize; + } + uint32_t writeBytesSize = writeSize / 8; + m_writer.Write(&bits, writeBytesSize); + m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1); + m_size += writeSize; } - uint32_t writeBytesSize = writeSize / 8; - m_writer.Write(&bits, writeBytesSize); - m_lastByte = (bits >> (writeBytesSize * 8)) & ((1 << (writeSize % 8)) - 1); - m_size += writeSize; } diff --git a/coding/coding_tests/bit_streams_test.cpp b/coding/coding_tests/bit_streams_test.cpp index eab3b408d4..730dd23745 100644 --- a/coding/coding_tests/bit_streams_test.cpp +++ b/coding/coding_tests/bit_streams_test.cpp @@ -25,7 +25,7 @@ UNIT_TEST(BitStream_ReadWrite) vector< pair > nums; for (uint32_t i = 0; i < NUMS_CNT; ++i) { - uint32_t numBits = GetRand64() % 57; + uint32_t numBits = GetRand64() % 65; uint64_t num = GetRand64() & (uint64_t(-1) >> (64 - numBits)); // Right bit shift by 64 doesn't always work correctly, // this is a workaround.