diff --git a/base/base_tests/suffix_array_tests.cpp b/base/base_tests/suffix_array_tests.cpp index 08123892f9..ef2824ef99 100644 --- a/base/base_tests/suffix_array_tests.cpp +++ b/base/base_tests/suffix_array_tests.cpp @@ -60,6 +60,8 @@ UNIT_TEST(Skew_Classic) { char const * s = "mississippi"; size_t const n = strlen(s); + TEST_EQUAL(n, 11, ()); + vector pos(n); Skew(n, reinterpret_cast(s), pos.data()); diff --git a/base/suffix_array.cpp b/base/suffix_array.cpp index d9c434b04d..7e5f5355ea 100644 --- a/base/suffix_array.cpp +++ b/base/suffix_array.cpp @@ -22,30 +22,36 @@ bool LEQ(size_t a1, size_t a2, size_t a3, size_t b1, size_t b2, size_t b3) return LEQ(a2, a3, b2, b3); } +// Actually this is a counting sort, but the name RadixSort is used +// here to keep the correspondence with the article about Skew|DC3. template -void RadixSort(size_t n, size_t const * keys, size_t maxValue, Values const & values, +void RadixSort(size_t numKeys, size_t const * keys, size_t numValues, Values const & values, size_t * resultKeys) { - std::vector count(maxValue); - for (size_t i = 0; i < n; ++i) + vector count(numValues, 0); + for (size_t i = 0; i < numKeys; ++i) + { + auto const value = values[keys[i]]; + ASSERT_LESS(value, count.size(), ()); ++count[values[keys[i]]]; - for (size_t i = 1; i < maxValue; ++i) + } + for (size_t i = 1; i < numValues; ++i) count[i] += count[i - 1]; - for (size_t i = n - 1; i < n; --i) + for (size_t i = numKeys - 1; i < numKeys; --i) resultKeys[--count[values[keys[i]]]] = keys[i]; } -bool InLeftHalf(size_t n0, size_t pos) { return pos < n0; } +bool InLeftHalf(size_t middle, size_t pos) { return pos < middle; } -size_t RestoreIndex(size_t n0, size_t pos) +size_t RestoreIndex(size_t middle, size_t pos) { - return InLeftHalf(n0, pos) ? pos * 3 + 1 : (pos - n0) * 3 + 2; + return InLeftHalf(middle, pos) ? pos * 3 + 1 : (pos - middle) * 3 + 2; } struct SkewWrapper { SkewWrapper(size_t n, uint8_t const * s) : m_n(n), m_s(s) {} - size_t size() const { return m_n; } + size_t operator[](size_t i) const { if (i < m_n) @@ -61,17 +67,18 @@ struct SkewWrapper template struct Slice { - Slice(Container const & c, size_t n, size_t offset) : m_c(c), m_n(n), m_offset(offset) {} + Slice(Container const & c, size_t offset) : m_c(c), m_offset(offset) {} + size_t operator[](size_t i) const { return m_c[i + m_offset]; } - const Container & m_c; - const size_t m_n; - const size_t m_offset; + + Container const & m_c; + size_t const m_offset; }; template Slice MakeSlice(Container const & c, size_t offset) { - return Slice(c, c.size(), offset); + return Slice(c, offset); } // Builds suffix array over the string s, where for all i < n: 0 < s[i] <= k. @@ -87,7 +94,7 @@ Slice MakeSlice(Container const & c, size_t offset) template void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) { - size_t const kInvalidId = std::numeric_limits::max(); + size_t const kInvalidId = numeric_limits::max(); if (n == 0) return; @@ -98,21 +105,22 @@ void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) return; } - // The number of =1 (mod 3) suffixes is the same as the number of =0 - // (mod 3) suffixes. - const size_t n0 = (n + 2) / 3; // Number of =0 (mod 3) suffixes. - const size_t n1 = (n + 1) / 3; // Number of =1 (mod 3) suffixes. - const size_t n2 = n / 3; // Number of =2 (mod 3) suffixes. + size_t const n0 = (n + 2) / 3; // Number of =0 (mod 3) suffixes. + size_t const n1 = (n + 1) / 3; // Number of =1 (mod 3) suffixes. + size_t const n2 = n / 3; // Number of =2 (mod 3) suffixes. - const size_t n02 = n0 + n2; + size_t const n02 = n0 + n2; - const bool fake1 = n0 != n1; + size_t const fake1 = n0 != n1 ? 1 : 0; + + // The total number of =1 (mod 3) suffixes (including the fake one) + // is the same as the number of =0 (mod 3) suffixes. ASSERT_EQUAL(n1 + fake1, n0, ()); - ASSERT_EQUAL(fake1, (n % 3 == 1), ()); + ASSERT_EQUAL(fake1, static_cast(n % 3 == 1), ()); // Generate positions of =(1|2) (mod 3) suffixes. - std::vector s12(n02 + 3); - std::vector sa12(n02 + 3); + vector s12(n02 + 3); + vector sa12(n02 + 3); // (n0 - n1) is needed in case when n == 0 (mod 3). We need a fake // =1 (mod 3) suffix for proper sorting of =0 (mod 3) suffixes. @@ -125,6 +133,10 @@ void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) s12[j++] = i; } + // Following three lines perform a stable sorting of all triples + // where i =(1|2) (mod 3), including + // possible fake1 suffix. Final order of these triples is written to + // |sa12|. RadixSort(n02, s12.data(), maxValue + 1, MakeSlice(s, 2), sa12.data()); RadixSort(n02, sa12.data(), maxValue + 1, MakeSlice(s, 1), s12.data()); RadixSort(n02, s12.data(), maxValue + 1, s, sa12.data()); @@ -157,7 +169,7 @@ void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) { // When not all triples unique, we need to build a suffix array // for them. - RawSkew(n02, name, s12, sa12.data()); + RawSkew(n02 /* n */, name /* maxValue */, s12, sa12.data()); for (size_t i = 0; i < n02; ++i) s12[sa12[i]] = i + 1; } @@ -172,8 +184,8 @@ void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) // in s12 are unique. // Need to do a stable sort for all =0 (mod 3) suffixes. - std::vector s0(n0); - std::vector sa0(n0); + vector s0(n0); + vector sa0(n0); for (size_t i = 0, j = 0; i < n02; ++i) { if (sa12[i] < n0) @@ -195,39 +207,28 @@ void RawSkew(size_t n, size_t maxValue, S const & s, size_t * sa) size_t k = 0; while (i12 != n02 && i0 != n0) { - const size_t p0 = sa0[i0]; - const size_t p12 = RestoreIndex(n0, sa12[i12]); + size_t const p0 = sa0[i0]; + size_t const p12 = RestoreIndex(n0, sa12[i12]); ASSERT_LESS(p12 / 3, n0, ()); - if (InLeftHalf(n0, sa12[i12])) + bool const isLEQ = + InLeftHalf(n0, sa12[i12]) + ? LEQ(s[p12], s12[sa12[i12] + n0], s[p0], s12[p0 / 3]) + : LEQ(s[p12], s[p12 + 1], s12[sa12[i12] - n0 + 1], s[p0], s[p0 + 1], s12[p0 / 3 + n0]); + + if (isLEQ) { - if (LEQ(s[p12], s12[sa12[i12] + n0], s[p0], s12[p0 / 3])) - { - // Suffix =(1|2) (mod 3) is smaller. - sa[k++] = p12; - ++i12; - } - else - { - sa[k++] = p0; - ++i0; - } + // Suffix =(1|2) (mod 3) is smaller. + sa[k++] = p12; + ++i12; } else { - if (LEQ(s[p12], s[p12 + 1], s12[sa12[i12] - n0 + 1], s[p0], s[p0 + 1], s12[p0 / 3 + n0])) - { - // Suffix =(1|2) (mod 3) is smaller. - sa[k++] = p12; - ++i12; - } - else - { - sa[k++] = p0; - ++i0; - } + sa[k++] = p0; + ++i0; } } + for (; i12 != n02; ++k, ++i12) sa[k] = RestoreIndex(n0, sa12[i12]); for (; i0 != n0; ++k, ++i0) @@ -240,13 +241,14 @@ namespace base { void Skew(size_t n, uint8_t const * s, size_t * sa) { - RawSkew(n, 0xFF /* maxValue */, SkewWrapper(n, s), sa); + auto const maxValue = static_cast(numeric_limits::max()); + RawSkew(n, maxValue, SkewWrapper(n, s), sa); } -void Skew(std::string const & s, std::vector & sa) +void Skew(string const & s, vector & sa) { auto const n = s.size(); sa.assign(n, 0); - Skew(n, reinterpret_cast(s.data()), sa.data()); + Skew(n, reinterpret_cast(s.data()), sa.data()); } } // namespace base