diff --git a/coding/arithmetic_codec.cpp b/coding/arithmetic_codec.cpp index 0e5a18a392..1054f1450b 100644 --- a/coding/arithmetic_codec.cpp +++ b/coding/arithmetic_codec.cpp @@ -6,19 +6,19 @@ #include "../base/assert.hpp" #include "../base/bits.hpp" -vector FreqsToDistrTable(vector const & origFreqs) +vector FreqsToDistrTable(vector const & origFreqs) { - u64 freqLowerBound = 0; + uint64_t freqLowerBound = 0; while (1) { // Resulting distr table is initialized with first zero value. - vector result(1, 0); - vector freqs; - u64 sum = 0; - u64 minFreq = ~u64(0); - for (u32 i = 0; i < origFreqs.size(); ++i) + vector result(1, 0); + vector freqs; + uint64_t sum = 0; + uint64_t minFreq = ~uint64_t(0); + for (uint32_t i = 0; i < origFreqs.size(); ++i) { - u32 freq = origFreqs[i]; + uint32_t freq = origFreqs[i]; if (freq > 0 && freq < minFreq) minFreq = freq; if (freq > 0 && freq < freqLowerBound) freq = freqLowerBound; freqs.push_back(freq); @@ -29,7 +29,7 @@ vector FreqsToDistrTable(vector const & origFreqs) // This flag shows that some interval with non-zero freq has // degraded to zero interval in normalized distribution table. bool hasDegradedZeroInterval = false; - for (u32 i = 1; i < result.size(); ++i) + for (uint32_t i = 1; i < result.size(); ++i) { result[i] = (result[i] << DISTR_SHIFT) / sum; if (freqs[i - 1] > 0 && (result[i] - result[i - 1] == 0)) @@ -40,50 +40,50 @@ vector FreqsToDistrTable(vector const & origFreqs) } if (!hasDegradedZeroInterval) { // Convert distr_table to 32-bit vector, although currently only 17 bits are used. - vector distr_table; - for (u32 i = 0; i < result.size(); ++i) distr_table.push_back(result[i]); + vector distr_table; + for (uint32_t i = 0; i < result.size(); ++i) distr_table.push_back(result[i]); return distr_table; } ++freqLowerBound; } } -ArithmeticEncoder::ArithmeticEncoder(vector const & distrTable) +ArithmeticEncoder::ArithmeticEncoder(vector const & distrTable) : m_begin(0), m_size(-1), m_distrTable(distrTable) {} -void ArithmeticEncoder::Encode(u32 symbol) +void ArithmeticEncoder::Encode(uint32_t symbol) { CHECK_LESS(symbol + 1, m_distrTable.size(), ()); - u32 distrBegin = m_distrTable[symbol]; - u32 distrEnd = m_distrTable[symbol + 1]; + uint32_t distrBegin = m_distrTable[symbol]; + uint32_t distrEnd = m_distrTable[symbol + 1]; CHECK_LESS(distrBegin, distrEnd, ()); - u32 prevBegin = m_begin; + uint32_t prevBegin = m_begin; m_begin += (m_size >> DISTR_SHIFT) * distrBegin; m_size = (m_size >> DISTR_SHIFT) * (distrEnd - distrBegin); if (m_begin < prevBegin) PropagateCarry(); - while (m_size < (u32(1) << 24)) + while (m_size < (uint32_t(1) << 24)) { - m_output.push_back(u8(m_begin >> 24)); + m_output.push_back(uint8_t(m_begin >> 24)); m_begin <<= 8; m_size <<= 8; } } -vector ArithmeticEncoder::Finalize() +vector ArithmeticEncoder::Finalize() { CHECK_GREATER(m_size, 0, ()); - u32 last = m_begin + m_size - 1; + uint32_t last = m_begin + m_size - 1; if (last < m_begin) { PropagateCarry(); } else { - u32 resultHiBits = bits::NumHiZeroBits32(m_begin ^ last) + 1; - u32 value = last & (~u32(0) << (32 - resultHiBits)); + uint32_t resultHiBits = bits::NumHiZeroBits32(m_begin ^ last) + 1; + uint32_t value = last & (~uint32_t(0) << (32 - resultHiBits)); while (value != 0) { - m_output.push_back(u8(value >> 24)); + m_output.push_back(uint8_t(value >> 24)); value <<= 8; } } @@ -104,31 +104,31 @@ void ArithmeticEncoder::PropagateCarry() ++m_output[i]; } -ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector const & distrTable) +ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector const & distrTable) : m_codeValue(0), m_size(-1), m_reader(reader), m_serialCur(0), m_serialEnd(reader.Size()), m_distrTable(distrTable) { - for (u32 i = 0; i < sizeof(m_codeValue); ++i) + for (uint32_t i = 0; i < sizeof(m_codeValue); ++i) { m_codeValue <<= 8; m_codeValue |= ReadCodeByte(); } } -u32 ArithmeticDecoder::Decode() +uint32_t ArithmeticDecoder::Decode() { - u32 l = 0, r = m_distrTable.size(), m = 0; - u32 shiftedSize = m_size >> DISTR_SHIFT; + uint32_t l = 0, r = m_distrTable.size(), m = 0; + uint32_t shiftedSize = m_size >> DISTR_SHIFT; while (r - l > 1) { m = (l + r) / 2; - u32 intervalBegin = shiftedSize * m_distrTable[m]; + uint32_t intervalBegin = shiftedSize * m_distrTable[m]; if (intervalBegin <= m_codeValue) l = m; else r = m; } - u32 symbol = l; + uint32_t symbol = l; m_codeValue -= shiftedSize * m_distrTable[symbol]; m_size = shiftedSize * (m_distrTable[symbol + 1] - m_distrTable[symbol]); - while (m_size < (u32(1) << 24)) + while (m_size < (uint32_t(1) << 24)) { m_codeValue <<= 8; m_size <<= 8; @@ -137,12 +137,12 @@ u32 ArithmeticDecoder::Decode() return symbol; } -u8 ArithmeticDecoder::ReadCodeByte() +uint8_t ArithmeticDecoder::ReadCodeByte() { if (m_serialCur >= m_serialEnd) return 0; else { - u8 result = 0; + uint8_t result = 0; m_reader.Read(m_serialCur, &result, 1); ++m_serialCur; return result; diff --git a/coding/arithmetic_codec.hpp b/coding/arithmetic_codec.hpp index 7f5c8bb5ab..5f7d9cee23 100644 --- a/coding/arithmetic_codec.hpp +++ b/coding/arithmetic_codec.hpp @@ -22,37 +22,33 @@ #include "../std/stdint.hpp" #include "../std/vector.hpp" -typedef uint8_t u8; -typedef uint32_t u32; -typedef uint64_t u64; - // Forward declarations. class Reader; // Default shift of distribution table, i.e. all distribution table frequencies are // normalized by this shift, i.e. distr table upper bound equals (1 << DISTR_SHIFT). -u32 const DISTR_SHIFT = 16; +uint32_t const DISTR_SHIFT = 16; // Converts symbols frequencies table to distribution table, used in Arithmetic codecs. -vector FreqsToDistrTable(vector const & freqs); +vector FreqsToDistrTable(vector const & freqs); class ArithmeticEncoder { public: // Provided distribution table. - ArithmeticEncoder(vector const & distrTable); + ArithmeticEncoder(vector const & distrTable); // Encode symbol using given distribution table and add that symbol to output. - void Encode(u32 symbol); + void Encode(uint32_t symbol); // Finalize encoding, flushes remaining bytes from the buffer to output. // Returns output vector of encoded bytes. - vector Finalize(); + vector Finalize(); private: // Propagates carry in case of overflow. void PropagateCarry(); private: - u32 m_begin; - u32 m_size; - vector m_output; - vector const & m_distrTable; + uint32_t m_begin; + uint32_t m_size; + vector m_output; + vector const & m_distrTable; }; class ArithmeticDecoder @@ -60,21 +56,21 @@ class ArithmeticDecoder public: // Decoder is given a reader to read input bytes, // distrTable - distribution table to decode symbols. - ArithmeticDecoder(Reader & reader, vector const & distrTable); + ArithmeticDecoder(Reader & reader, vector const & distrTable); // Decode next symbol from the encoded stream. - u32 Decode(); + uint32_t Decode(); private: // Read next code byte from encoded stream. - u8 ReadCodeByte(); + uint8_t ReadCodeByte(); private: // Current most significant part of code value. - u32 m_codeValue; + uint32_t m_codeValue; // Current interval size. - u32 m_size; + uint32_t m_size; // Reader and two bounds of encoded data within this Reader. Reader & m_reader; - u64 m_serialCur; - u64 m_serialEnd; + uint64_t m_serialCur; + uint64_t m_serialEnd; - vector const & m_distrTable; + vector const & m_distrTable; }; diff --git a/coding/coding_tests/arithmetic_codec_test.cpp b/coding/coding_tests/arithmetic_codec_test.cpp index f6d7ff53ac..3348ac9f84 100644 --- a/coding/coding_tests/arithmetic_codec_test.cpp +++ b/coding/coding_tests/arithmetic_codec_test.cpp @@ -7,35 +7,35 @@ UNIT_TEST(ArithmeticCodec) { PseudoRNG32 rng; - u32 const MAX_FREQ = 2048; - u32 const ALPHABET_SIZE = 256; - vector symbols; - vector freqs; + uint32_t const MAX_FREQ = 2048; + uint32_t const ALPHABET_SIZE = 256; + vector symbols; + vector freqs; // Generate random freqs. - for (u32 i = 0; i < ALPHABET_SIZE; ++i) { - u32 freq = rng.Generate() % MAX_FREQ; + for (uint32_t i = 0; i < ALPHABET_SIZE; ++i) { + uint32_t freq = rng.Generate() % MAX_FREQ; freqs.push_back(freq); } // Make at least one frequency zero for corner cases. freqs[freqs.size() / 2] = 0; // Generate symbols based on given freqs. - for (u32 i = 0; i < freqs.size(); ++i) { - u32 freq = freqs[i]; - for (u32 j = 0; j < freq; ++j) { - u32 pos = rng.Generate() % (symbols.size() + 1); + for (uint32_t i = 0; i < freqs.size(); ++i) { + uint32_t freq = freqs[i]; + for (uint32_t j = 0; j < freq; ++j) { + uint32_t pos = rng.Generate() % (symbols.size() + 1); symbols.insert(symbols.begin() + pos, 1, i); } } - vector distrTable = FreqsToDistrTable(freqs); + vector distrTable = FreqsToDistrTable(freqs); // Encode symbols. ArithmeticEncoder arithEnc(distrTable); - for (u32 i = 0; i < symbols.size(); ++i) arithEnc.Encode(symbols[i]); - vector encodedData = arithEnc.Finalize(); + for (uint32_t i = 0; i < symbols.size(); ++i) arithEnc.Encode(symbols[i]); + vector encodedData = arithEnc.Finalize(); // Decode symbols. MemReader reader(encodedData.data(), encodedData.size()); ArithmeticDecoder arithDec(reader, distrTable); - for (u32 i = 0; i < symbols.size(); ++i) { - u32 decodedSymbol = arithDec.Decode(); + for (uint32_t i = 0; i < symbols.size(); ++i) { + uint32_t decodedSymbol = arithDec.Decode(); TEST_EQUAL(symbols[i], decodedSymbol, ()); } }