From 5ae31bb3f374172a5e85f362fac0fbd0f4b32374 Mon Sep 17 00:00:00 2001 From: Artyom Polkovnikov Date: Sun, 16 Nov 2014 20:36:58 +0300 Subject: [PATCH] [coding] [arithmetic_codec] Convert code style to MapsMe C++ style. --- coding/arithmetic_codec.cpp | 181 ++++++++++-------- coding/arithmetic_codec.hpp | 57 +++--- coding/coding_tests/arithmetic_codec_test.cpp | 38 ++-- 3 files changed, 152 insertions(+), 124 deletions(-) diff --git a/coding/arithmetic_codec.cpp b/coding/arithmetic_codec.cpp index 21f9cdff2f..c9efc65c57 100644 --- a/coding/arithmetic_codec.cpp +++ b/coding/arithmetic_codec.cpp @@ -5,127 +5,148 @@ #include "../base/assert.hpp" -using std::vector; - namespace { - inline uint32_t NumHiZeroBits32(uint32_t n) { - uint32_t result = 0; - while ((n & (uint32_t(1) << 31)) == 0) { ++result; n <<= 1; } + inline u32 NumHiZeroBits32(u32 n) + { + u32 result = 0; + while ((n & (u32(1) << 31)) == 0) { ++result; n <<= 1; } return result; } } -vector FreqsToDistrTable(vector const & orig_freqs) { - uint64_t freq_lower_bound = 0; - while (1) { +vector FreqsToDistrTable(vector const & origFreqs) +{ + u64 freqLowerBound = 0; + while (1) + { // Resulting distr table is initialized with first zero value. - vector result(1, 0); - vector freqs; - uint32_t sum = 0; - uint64_t min_freq = ~uint64_t(0); - for (uint32_t i = 0; i < orig_freqs.size(); ++i) { - uint32_t freq = orig_freqs[i]; - if (freq > 0 && freq < min_freq) min_freq = freq; - if (freq > 0 && freq < freq_lower_bound) freq = freq_lower_bound; + vector result(1, 0); + vector freqs; + u32 sum = 0; + u64 minFreq = ~u64(0); + for (u32 i = 0; i < origFreqs.size(); ++i) + { + u32 freq = origFreqs[i]; + if (freq > 0 && freq < minFreq) minFreq = freq; + if (freq > 0 && freq < freqLowerBound) freq = freqLowerBound; freqs.push_back(freq); sum += freq; result.push_back(sum); } - if (freq_lower_bound == 0) freq_lower_bound = min_freq; + if (freqLowerBound == 0) freqLowerBound = minFreq; // This flag shows that some interval with non-zero freq has // degraded to zero interval in normalized distribution table. - bool has_degraded_zero_interval = false; - for (uint32_t i = 1; i < result.size(); ++i) { - result[i] = (uint64_t(result[i]) << c_distr_shift) / uint64_t(sum); - if (freqs[i - 1] > 0 && (result[i] - result[i - 1] == 0)) { - has_degraded_zero_interval = true; + bool hasDegradedZeroInterval = false; + for (u32 i = 1; i < result.size(); ++i) + { + result[i] = (u64(result[i]) << DISTR_SHIFT) / u64(sum); + if (freqs[i - 1] > 0 && (result[i] - result[i - 1] == 0)) + { + hasDegradedZeroInterval = true; break; } } - if (!has_degraded_zero_interval) return result; - ++freq_lower_bound; + if (!hasDegradedZeroInterval) return result; + ++freqLowerBound; } } -ArithmeticEncoder::ArithmeticEncoder(vector const & distr_table) - : begin_(0), size_(-1), distr_table_(distr_table) {} +ArithmeticEncoder::ArithmeticEncoder(vector const & distrTable) + : m_begin(0), m_size(-1), m_distrTable(distrTable) {} -void ArithmeticEncoder::Encode(uint32_t symbol) { - ASSERT_LESS(symbol + 1, distr_table_.size(), ()); - uint32_t distr_begin = distr_table_[symbol]; - uint32_t distr_end = distr_table_[symbol + 1]; - ASSERT_LESS(distr_begin, distr_end, ()); - uint32_t prev_begin = begin_; - begin_ += (size_ >> c_distr_shift) * distr_begin; - size_ = (size_ >> c_distr_shift) * (distr_end - distr_begin); - if (begin_ < prev_begin) PropagateCarry(); - while (size_ < (uint32_t(1) << 24)) { - output_.push_back(uint8_t(begin_ >> 24)); - begin_ <<= 8; - size_ <<= 8; +void ArithmeticEncoder::Encode(u32 symbol) +{ + CHECK_LESS(symbol + 1, m_distrTable.size(), ()); + u32 distrBegin = m_distrTable[symbol]; + u32 distrEnd = m_distrTable[symbol + 1]; + CHECK_LESS(distrBegin, distrEnd, ()); + u32 prevBegin = m_begin; + m_begin += (m_size >> DISTR_SHIFT) * distrBegin; + m_size = (m_size >> DISTR_SHIFT) * (distrEnd - distrBegin); + if (m_begin < prevBegin) PropagateCarry(); + while (m_size < (u32(1) << 24)) + { + m_output.push_back(u8(m_begin >> 24)); + m_begin <<= 8; + m_size <<= 8; } } -vector ArithmeticEncoder::Finalize() { - ASSERT_GREATER(size_, 0, ()); - uint32_t last = begin_ + size_ - 1; - if (last < begin_) { +vector ArithmeticEncoder::Finalize() +{ + CHECK_GREATER(m_size, 0, ()); + u32 last = m_begin + m_size - 1; + if (last < m_begin) + { PropagateCarry(); - } else { - uint32_t result_hi_bits = NumHiZeroBits32(begin_ ^ last) + 1; - uint32_t value = last & (~uint32_t(0) << (32 - result_hi_bits)); - while (value != 0) { - output_.push_back(uint8_t(value >> 24)); + } + else + { + u32 resultHiBits = NumHiZeroBits32(m_begin ^ last) + 1; + u32 value = last & (~u32(0) << (32 - resultHiBits)); + while (value != 0) + { + m_output.push_back(u8(value >> 24)); value <<= 8; } } - begin_ = 0; - size_ = 0; - return output_; + m_begin = 0; + m_size = 0; + return m_output; } -void ArithmeticEncoder::PropagateCarry() { - int i = output_.size() - 1; - while (i >= 0 && output_[i] == 0xFF) { - output_[i] = 0; +void ArithmeticEncoder::PropagateCarry() +{ + int i = m_output.size() - 1; + while (i >= 0 && m_output[i] == 0xFF) + { + m_output[i] = 0; --i; } - ASSERT_GREATER_OR_EQUAL(i, 0, ()); - ++output_[i]; + CHECK_GREATER_OR_EQUAL(i, 0, ()); + ++m_output[i]; } -ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector const & distr_table) - : code_value_(0), size_(-1), reader_(reader), serial_cur_(0), serial_end_(reader.Size()), distr_table_(distr_table) { - for (uint32_t i = 0; i < sizeof(code_value_); ++i) { - code_value_ <<= 8; - code_value_ |= ReadCodeByte(); +ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector const & distrTable) + : m_codeValue(0), m_size(-1), m_reader(reader), m_serialCur(0), + m_serialEnd(reader.Size()), m_distrTable(distrTable) +{ + for (u32 i = 0; i < sizeof(m_codeValue); ++i) + { + m_codeValue <<= 8; + m_codeValue |= ReadCodeByte(); } } -uint32_t ArithmeticDecoder::Decode() { - uint32_t l = 0, r = distr_table_.size(), m = 0; - while (r - l > 1) { +u32 ArithmeticDecoder::Decode() +{ + u32 l = 0, r = m_distrTable.size(), m = 0; + while (r - l > 1) + { m = (l + r) / 2; - uint32_t interval_begin = (size_ >> c_distr_shift) * distr_table_[m]; - if (interval_begin <= code_value_) l = m; else r = m; + u32 intervalBegin = (m_size >> DISTR_SHIFT) * m_distrTable[m]; + if (intervalBegin <= m_codeValue) l = m; else r = m; } - uint32_t symbol = l; - code_value_ -= (size_ >> c_distr_shift) * distr_table_[symbol]; - size_ = (size_ >> c_distr_shift) * (distr_table_[symbol + 1] - distr_table_[symbol]); - while (size_ < (uint32_t(1) << 24)) { - code_value_ <<= 8; - size_ <<= 8; - code_value_ |= ReadCodeByte(); + u32 symbol = l; + m_codeValue -= (m_size >> DISTR_SHIFT) * m_distrTable[symbol]; + m_size = (m_size >> DISTR_SHIFT) * (m_distrTable[symbol + 1] - m_distrTable[symbol]); + while (m_size < (u32(1) << 24)) + { + m_codeValue <<= 8; + m_size <<= 8; + m_codeValue |= ReadCodeByte(); } return symbol; } -uint8_t ArithmeticDecoder::ReadCodeByte() { - if (serial_cur_ >= serial_end_) return 0; - else { - uint8_t result = 0; - reader_.Read(serial_cur_, &result, 1); - ++serial_cur_; +u8 ArithmeticDecoder::ReadCodeByte() +{ + if (m_serialCur >= m_serialEnd) return 0; + else + { + u8 result = 0; + m_reader.Read(m_serialCur, &result, 1); + ++m_serialCur; return result; } } diff --git a/coding/arithmetic_codec.hpp b/coding/arithmetic_codec.hpp index 5512374875..7f5c8bb5ab 100644 --- a/coding/arithmetic_codec.hpp +++ b/coding/arithmetic_codec.hpp @@ -5,15 +5,15 @@ // // Compute freqs table by counting number of occurancies of each symbol. // // Freqs table should have size equal to number of symbols in the alphabet. // // Convert freqs table to distr table. -// vector distr_table = FreqsToDistrTable(freqs); -// ArithmeticEncoder arith_enc(distr_table); +// vector distrTable = FreqsToDistrTable(freqs); +// ArithmeticEncoder arith_enc(distrTable); // // Encode any number of symbols. // arith_enc.Encode(10); arith_enc.Encode(17); arith_enc.Encode(0); arith_enc.Encode(4); // // Get encoded bytes. // vector encoded_data = arith_enc.Finalize(); // // Decode encoded bytes. Number of symbols should be provided outside. // MemReader reader(encoded_data.data(), encoded_data.size()); -// ArithmeticDecoder arith_dec(&reader, distr_table); +// ArithmeticDecoder arith_dec(&reader, distrTable); // uint32_t sym1 = arith_dec.Decode(); uint32_t sym2 = arith_dec.Decode(); // uint32_t sym3 = arith_dec.Decode(); uint32_t sym4 = arith_dec.Decode(); @@ -22,52 +22,59 @@ #include "../std/stdint.hpp" #include "../std/vector.hpp" +typedef uint8_t u8; +typedef uint32_t u32; +typedef uint64_t u64; + +// Forward declarations. class Reader; // Default shift of distribution table, i.e. all distribution table frequencies are -// normalized by this shift, i.e. distr table upper bound equals (1 << c_distr_shift). -uint32_t const c_distr_shift = 16; +// normalized by this shift, i.e. distr table upper bound equals (1 << DISTR_SHIFT). +u32 const DISTR_SHIFT = 16; // Converts symbols frequencies table to distribution table, used in Arithmetic codecs. -std::vector FreqsToDistrTable(std::vector const & freqs); +vector FreqsToDistrTable(vector const & freqs); -class ArithmeticEncoder { +class ArithmeticEncoder +{ public: // Provided distribution table. - ArithmeticEncoder(std::vector const & distr_table); + ArithmeticEncoder(vector const & distrTable); // Encode symbol using given distribution table and add that symbol to output. - void Encode(uint32_t symbol); + void Encode(u32 symbol); // Finalize encoding, flushes remaining bytes from the buffer to output. // Returns output vector of encoded bytes. - std::vector Finalize(); + vector Finalize(); private: // Propagates carry in case of overflow. void PropagateCarry(); private: - uint32_t begin_; - uint32_t size_; - std::vector output_; - std::vector const & distr_table_; + u32 m_begin; + u32 m_size; + vector m_output; + vector const & m_distrTable; }; -class ArithmeticDecoder { +class ArithmeticDecoder +{ public: // Decoder is given a reader to read input bytes, - // distr_table - distribution table to decode symbols. - ArithmeticDecoder(Reader & reader, std::vector const & distr_table); + // distrTable - distribution table to decode symbols. + ArithmeticDecoder(Reader & reader, vector const & distrTable); // Decode next symbol from the encoded stream. - uint32_t Decode(); + u32 Decode(); private: // Read next code byte from encoded stream. - uint8_t ReadCodeByte(); + u8 ReadCodeByte(); private: // Current most significant part of code value. - uint32_t code_value_; + u32 m_codeValue; // Current interval size. - uint32_t size_; + u32 m_size; // Reader and two bounds of encoded data within this Reader. - Reader & reader_; - uint64_t serial_cur_; - uint64_t serial_end_; + Reader & m_reader; + u64 m_serialCur; + u64 m_serialEnd; - std::vector const & distr_table_; + vector const & m_distrTable; }; diff --git a/coding/coding_tests/arithmetic_codec_test.cpp b/coding/coding_tests/arithmetic_codec_test.cpp index 9af4451534..f6d7ff53ac 100644 --- a/coding/coding_tests/arithmetic_codec_test.cpp +++ b/coding/coding_tests/arithmetic_codec_test.cpp @@ -7,35 +7,35 @@ UNIT_TEST(ArithmeticCodec) { PseudoRNG32 rng; - uint32_t const c_max_freq = 2048; - uint32_t const c_alphabet_size = 256; - vector symbols; - vector freqs; + u32 const MAX_FREQ = 2048; + u32 const ALPHABET_SIZE = 256; + vector symbols; + vector freqs; // Generate random freqs. - for (uint32_t i = 0; i < c_alphabet_size; ++i) { - uint32_t freq = rng.Generate() % c_max_freq; + for (u32 i = 0; i < ALPHABET_SIZE; ++i) { + u32 freq = rng.Generate() % MAX_FREQ; freqs.push_back(freq); } // Make at least one frequency zero for corner cases. freqs[freqs.size() / 2] = 0; // Generate symbols based on given freqs. - for (uint32_t i = 0; i < freqs.size(); ++i) { - uint32_t freq = freqs[i]; - for (uint32_t j = 0; j < freq; ++j) { - uint32_t pos = rng.Generate() % (symbols.size() + 1); + for (u32 i = 0; i < freqs.size(); ++i) { + u32 freq = freqs[i]; + for (u32 j = 0; j < freq; ++j) { + u32 pos = rng.Generate() % (symbols.size() + 1); symbols.insert(symbols.begin() + pos, 1, i); } } - vector distr_table = FreqsToDistrTable(freqs); + vector distrTable = FreqsToDistrTable(freqs); // Encode symbols. - ArithmeticEncoder arith_enc(distr_table); - for (uint32_t i = 0; i < symbols.size(); ++i) arith_enc.Encode(symbols[i]); - vector encoded_data = arith_enc.Finalize(); + ArithmeticEncoder arithEnc(distrTable); + for (u32 i = 0; i < symbols.size(); ++i) arithEnc.Encode(symbols[i]); + vector encodedData = arithEnc.Finalize(); // Decode symbols. - MemReader reader(encoded_data.data(), encoded_data.size()); - ArithmeticDecoder arith_dec(reader, distr_table); - for (uint32_t i = 0; i < symbols.size(); ++i) { - uint32_t decoded_symbol = arith_dec.Decode(); - TEST_EQUAL(symbols[i], decoded_symbol, ()); + MemReader reader(encodedData.data(), encodedData.size()); + ArithmeticDecoder arithDec(reader, distrTable); + for (u32 i = 0; i < symbols.size(); ++i) { + u32 decodedSymbol = arithDec.Decode(); + TEST_EQUAL(symbols[i], decodedSymbol, ()); } }