[coding] [arithmetic_codec] Convert code style to MapsMe C++ style.

This commit is contained in:
Artyom Polkovnikov 2014-11-16 20:36:58 +03:00 committed by Alex Zolotarev
parent f08c7ae381
commit 5ae31bb3f3
3 changed files with 152 additions and 124 deletions

View file

@ -5,127 +5,148 @@
#include "../base/assert.hpp"
using std::vector;
namespace {
inline uint32_t NumHiZeroBits32(uint32_t n) {
uint32_t result = 0;
while ((n & (uint32_t(1) << 31)) == 0) { ++result; n <<= 1; }
inline u32 NumHiZeroBits32(u32 n)
{
u32 result = 0;
while ((n & (u32(1) << 31)) == 0) { ++result; n <<= 1; }
return result;
}
}
vector<uint32_t> FreqsToDistrTable(vector<uint32_t> const & orig_freqs) {
uint64_t freq_lower_bound = 0;
while (1) {
vector<u32> FreqsToDistrTable(vector<u32> const & origFreqs)
{
u64 freqLowerBound = 0;
while (1)
{
// Resulting distr table is initialized with first zero value.
vector<uint32_t> result(1, 0);
vector<uint32_t> freqs;
uint32_t sum = 0;
uint64_t min_freq = ~uint64_t(0);
for (uint32_t i = 0; i < orig_freqs.size(); ++i) {
uint32_t freq = orig_freqs[i];
if (freq > 0 && freq < min_freq) min_freq = freq;
if (freq > 0 && freq < freq_lower_bound) freq = freq_lower_bound;
vector<u32> result(1, 0);
vector<u32> freqs;
u32 sum = 0;
u64 minFreq = ~u64(0);
for (u32 i = 0; i < origFreqs.size(); ++i)
{
u32 freq = origFreqs[i];
if (freq > 0 && freq < minFreq) minFreq = freq;
if (freq > 0 && freq < freqLowerBound) freq = freqLowerBound;
freqs.push_back(freq);
sum += freq;
result.push_back(sum);
}
if (freq_lower_bound == 0) freq_lower_bound = min_freq;
if (freqLowerBound == 0) freqLowerBound = minFreq;
// This flag shows that some interval with non-zero freq has
// degraded to zero interval in normalized distribution table.
bool has_degraded_zero_interval = false;
for (uint32_t i = 1; i < result.size(); ++i) {
result[i] = (uint64_t(result[i]) << c_distr_shift) / uint64_t(sum);
if (freqs[i - 1] > 0 && (result[i] - result[i - 1] == 0)) {
has_degraded_zero_interval = true;
bool hasDegradedZeroInterval = false;
for (u32 i = 1; i < result.size(); ++i)
{
result[i] = (u64(result[i]) << DISTR_SHIFT) / u64(sum);
if (freqs[i - 1] > 0 && (result[i] - result[i - 1] == 0))
{
hasDegradedZeroInterval = true;
break;
}
}
if (!has_degraded_zero_interval) return result;
++freq_lower_bound;
if (!hasDegradedZeroInterval) return result;
++freqLowerBound;
}
}
ArithmeticEncoder::ArithmeticEncoder(vector<uint32_t> const & distr_table)
: begin_(0), size_(-1), distr_table_(distr_table) {}
ArithmeticEncoder::ArithmeticEncoder(vector<u32> const & distrTable)
: m_begin(0), m_size(-1), m_distrTable(distrTable) {}
void ArithmeticEncoder::Encode(uint32_t symbol) {
ASSERT_LESS(symbol + 1, distr_table_.size(), ());
uint32_t distr_begin = distr_table_[symbol];
uint32_t distr_end = distr_table_[symbol + 1];
ASSERT_LESS(distr_begin, distr_end, ());
uint32_t prev_begin = begin_;
begin_ += (size_ >> c_distr_shift) * distr_begin;
size_ = (size_ >> c_distr_shift) * (distr_end - distr_begin);
if (begin_ < prev_begin) PropagateCarry();
while (size_ < (uint32_t(1) << 24)) {
output_.push_back(uint8_t(begin_ >> 24));
begin_ <<= 8;
size_ <<= 8;
void ArithmeticEncoder::Encode(u32 symbol)
{
CHECK_LESS(symbol + 1, m_distrTable.size(), ());
u32 distrBegin = m_distrTable[symbol];
u32 distrEnd = m_distrTable[symbol + 1];
CHECK_LESS(distrBegin, distrEnd, ());
u32 prevBegin = m_begin;
m_begin += (m_size >> DISTR_SHIFT) * distrBegin;
m_size = (m_size >> DISTR_SHIFT) * (distrEnd - distrBegin);
if (m_begin < prevBegin) PropagateCarry();
while (m_size < (u32(1) << 24))
{
m_output.push_back(u8(m_begin >> 24));
m_begin <<= 8;
m_size <<= 8;
}
}
vector<uint8_t> ArithmeticEncoder::Finalize() {
ASSERT_GREATER(size_, 0, ());
uint32_t last = begin_ + size_ - 1;
if (last < begin_) {
vector<u8> ArithmeticEncoder::Finalize()
{
CHECK_GREATER(m_size, 0, ());
u32 last = m_begin + m_size - 1;
if (last < m_begin)
{
PropagateCarry();
} else {
uint32_t result_hi_bits = NumHiZeroBits32(begin_ ^ last) + 1;
uint32_t value = last & (~uint32_t(0) << (32 - result_hi_bits));
while (value != 0) {
output_.push_back(uint8_t(value >> 24));
}
else
{
u32 resultHiBits = NumHiZeroBits32(m_begin ^ last) + 1;
u32 value = last & (~u32(0) << (32 - resultHiBits));
while (value != 0)
{
m_output.push_back(u8(value >> 24));
value <<= 8;
}
}
begin_ = 0;
size_ = 0;
return output_;
m_begin = 0;
m_size = 0;
return m_output;
}
void ArithmeticEncoder::PropagateCarry() {
int i = output_.size() - 1;
while (i >= 0 && output_[i] == 0xFF) {
output_[i] = 0;
void ArithmeticEncoder::PropagateCarry()
{
int i = m_output.size() - 1;
while (i >= 0 && m_output[i] == 0xFF)
{
m_output[i] = 0;
--i;
}
ASSERT_GREATER_OR_EQUAL(i, 0, ());
++output_[i];
CHECK_GREATER_OR_EQUAL(i, 0, ());
++m_output[i];
}
ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector<uint32_t> const & distr_table)
: code_value_(0), size_(-1), reader_(reader), serial_cur_(0), serial_end_(reader.Size()), distr_table_(distr_table) {
for (uint32_t i = 0; i < sizeof(code_value_); ++i) {
code_value_ <<= 8;
code_value_ |= ReadCodeByte();
ArithmeticDecoder::ArithmeticDecoder(Reader & reader, vector<u32> const & distrTable)
: m_codeValue(0), m_size(-1), m_reader(reader), m_serialCur(0),
m_serialEnd(reader.Size()), m_distrTable(distrTable)
{
for (u32 i = 0; i < sizeof(m_codeValue); ++i)
{
m_codeValue <<= 8;
m_codeValue |= ReadCodeByte();
}
}
uint32_t ArithmeticDecoder::Decode() {
uint32_t l = 0, r = distr_table_.size(), m = 0;
while (r - l > 1) {
u32 ArithmeticDecoder::Decode()
{
u32 l = 0, r = m_distrTable.size(), m = 0;
while (r - l > 1)
{
m = (l + r) / 2;
uint32_t interval_begin = (size_ >> c_distr_shift) * distr_table_[m];
if (interval_begin <= code_value_) l = m; else r = m;
u32 intervalBegin = (m_size >> DISTR_SHIFT) * m_distrTable[m];
if (intervalBegin <= m_codeValue) l = m; else r = m;
}
uint32_t symbol = l;
code_value_ -= (size_ >> c_distr_shift) * distr_table_[symbol];
size_ = (size_ >> c_distr_shift) * (distr_table_[symbol + 1] - distr_table_[symbol]);
while (size_ < (uint32_t(1) << 24)) {
code_value_ <<= 8;
size_ <<= 8;
code_value_ |= ReadCodeByte();
u32 symbol = l;
m_codeValue -= (m_size >> DISTR_SHIFT) * m_distrTable[symbol];
m_size = (m_size >> DISTR_SHIFT) * (m_distrTable[symbol + 1] - m_distrTable[symbol]);
while (m_size < (u32(1) << 24))
{
m_codeValue <<= 8;
m_size <<= 8;
m_codeValue |= ReadCodeByte();
}
return symbol;
}
uint8_t ArithmeticDecoder::ReadCodeByte() {
if (serial_cur_ >= serial_end_) return 0;
else {
uint8_t result = 0;
reader_.Read(serial_cur_, &result, 1);
++serial_cur_;
u8 ArithmeticDecoder::ReadCodeByte()
{
if (m_serialCur >= m_serialEnd) return 0;
else
{
u8 result = 0;
m_reader.Read(m_serialCur, &result, 1);
++m_serialCur;
return result;
}
}

View file

@ -5,15 +5,15 @@
// // Compute freqs table by counting number of occurancies of each symbol.
// // Freqs table should have size equal to number of symbols in the alphabet.
// // Convert freqs table to distr table.
// vector<uint32_t> distr_table = FreqsToDistrTable(freqs);
// ArithmeticEncoder arith_enc(distr_table);
// vector<uint32_t> distrTable = FreqsToDistrTable(freqs);
// ArithmeticEncoder arith_enc(distrTable);
// // Encode any number of symbols.
// arith_enc.Encode(10); arith_enc.Encode(17); arith_enc.Encode(0); arith_enc.Encode(4);
// // Get encoded bytes.
// vector<uint8_t> encoded_data = arith_enc.Finalize();
// // Decode encoded bytes. Number of symbols should be provided outside.
// MemReader reader(encoded_data.data(), encoded_data.size());
// ArithmeticDecoder arith_dec(&reader, distr_table);
// ArithmeticDecoder arith_dec(&reader, distrTable);
// uint32_t sym1 = arith_dec.Decode(); uint32_t sym2 = arith_dec.Decode();
// uint32_t sym3 = arith_dec.Decode(); uint32_t sym4 = arith_dec.Decode();
@ -22,52 +22,59 @@
#include "../std/stdint.hpp"
#include "../std/vector.hpp"
typedef uint8_t u8;
typedef uint32_t u32;
typedef uint64_t u64;
// Forward declarations.
class Reader;
// Default shift of distribution table, i.e. all distribution table frequencies are
// normalized by this shift, i.e. distr table upper bound equals (1 << c_distr_shift).
uint32_t const c_distr_shift = 16;
// normalized by this shift, i.e. distr table upper bound equals (1 << DISTR_SHIFT).
u32 const DISTR_SHIFT = 16;
// Converts symbols frequencies table to distribution table, used in Arithmetic codecs.
std::vector<uint32_t> FreqsToDistrTable(std::vector<uint32_t> const & freqs);
vector<u32> FreqsToDistrTable(vector<u32> const & freqs);
class ArithmeticEncoder {
class ArithmeticEncoder
{
public:
// Provided distribution table.
ArithmeticEncoder(std::vector<uint32_t> const & distr_table);
ArithmeticEncoder(vector<u32> const & distrTable);
// Encode symbol using given distribution table and add that symbol to output.
void Encode(uint32_t symbol);
void Encode(u32 symbol);
// Finalize encoding, flushes remaining bytes from the buffer to output.
// Returns output vector of encoded bytes.
std::vector<uint8_t> Finalize();
vector<u8> Finalize();
private:
// Propagates carry in case of overflow.
void PropagateCarry();
private:
uint32_t begin_;
uint32_t size_;
std::vector<uint8_t> output_;
std::vector<uint32_t> const & distr_table_;
u32 m_begin;
u32 m_size;
vector<u8> m_output;
vector<u32> const & m_distrTable;
};
class ArithmeticDecoder {
class ArithmeticDecoder
{
public:
// Decoder is given a reader to read input bytes,
// distr_table - distribution table to decode symbols.
ArithmeticDecoder(Reader & reader, std::vector<uint32_t> const & distr_table);
// distrTable - distribution table to decode symbols.
ArithmeticDecoder(Reader & reader, vector<u32> const & distrTable);
// Decode next symbol from the encoded stream.
uint32_t Decode();
u32 Decode();
private:
// Read next code byte from encoded stream.
uint8_t ReadCodeByte();
u8 ReadCodeByte();
private:
// Current most significant part of code value.
uint32_t code_value_;
u32 m_codeValue;
// Current interval size.
uint32_t size_;
u32 m_size;
// Reader and two bounds of encoded data within this Reader.
Reader & reader_;
uint64_t serial_cur_;
uint64_t serial_end_;
Reader & m_reader;
u64 m_serialCur;
u64 m_serialEnd;
std::vector<uint32_t> const & distr_table_;
vector<u32> const & m_distrTable;
};

View file

@ -7,35 +7,35 @@
UNIT_TEST(ArithmeticCodec) {
PseudoRNG32 rng;
uint32_t const c_max_freq = 2048;
uint32_t const c_alphabet_size = 256;
vector<uint32_t> symbols;
vector<uint32_t> freqs;
u32 const MAX_FREQ = 2048;
u32 const ALPHABET_SIZE = 256;
vector<u32> symbols;
vector<u32> freqs;
// Generate random freqs.
for (uint32_t i = 0; i < c_alphabet_size; ++i) {
uint32_t freq = rng.Generate() % c_max_freq;
for (u32 i = 0; i < ALPHABET_SIZE; ++i) {
u32 freq = rng.Generate() % MAX_FREQ;
freqs.push_back(freq);
}
// Make at least one frequency zero for corner cases.
freqs[freqs.size() / 2] = 0;
// Generate symbols based on given freqs.
for (uint32_t i = 0; i < freqs.size(); ++i) {
uint32_t freq = freqs[i];
for (uint32_t j = 0; j < freq; ++j) {
uint32_t pos = rng.Generate() % (symbols.size() + 1);
for (u32 i = 0; i < freqs.size(); ++i) {
u32 freq = freqs[i];
for (u32 j = 0; j < freq; ++j) {
u32 pos = rng.Generate() % (symbols.size() + 1);
symbols.insert(symbols.begin() + pos, 1, i);
}
}
vector<uint32_t> distr_table = FreqsToDistrTable(freqs);
vector<u32> distrTable = FreqsToDistrTable(freqs);
// Encode symbols.
ArithmeticEncoder arith_enc(distr_table);
for (uint32_t i = 0; i < symbols.size(); ++i) arith_enc.Encode(symbols[i]);
vector<uint8_t> encoded_data = arith_enc.Finalize();
ArithmeticEncoder arithEnc(distrTable);
for (u32 i = 0; i < symbols.size(); ++i) arithEnc.Encode(symbols[i]);
vector<u8> encodedData = arithEnc.Finalize();
// Decode symbols.
MemReader reader(encoded_data.data(), encoded_data.size());
ArithmeticDecoder arith_dec(reader, distr_table);
for (uint32_t i = 0; i < symbols.size(); ++i) {
uint32_t decoded_symbol = arith_dec.Decode();
TEST_EQUAL(symbols[i], decoded_symbol, ());
MemReader reader(encodedData.data(), encodedData.size());
ArithmeticDecoder arithDec(reader, distrTable);
for (u32 i = 0; i < symbols.size(); ++i) {
u32 decodedSymbol = arithDec.Decode();
TEST_EQUAL(symbols[i], decodedSymbol, ());
}
}