diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index 9112a18640..fe64857e8a 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -8,6 +8,8 @@ set( base.hpp bits.hpp buffer_vector.hpp + bwt.cpp + bwt.hpp cache.hpp cancellable.hpp checked_cast.hpp diff --git a/base/base.pro b/base/base.pro index c9fdc71316..cb2cdd8888 100644 --- a/base/base.pro +++ b/base/base.pro @@ -9,6 +9,7 @@ include($$ROOT_DIR/common.pri) SOURCES += \ base.cpp \ + bwt.cpp \ condition.cpp \ deferred_task.cpp \ exception.cpp \ @@ -43,6 +44,7 @@ HEADERS += \ base.hpp \ bits.hpp \ buffer_vector.hpp \ + bwt.hpp \ cache.hpp \ cancellable.hpp \ checked_cast.hpp \ diff --git a/base/base_tests/CMakeLists.txt b/base/base_tests/CMakeLists.txt index 2de3cf2b09..277983f38d 100644 --- a/base/base_tests/CMakeLists.txt +++ b/base/base_tests/CMakeLists.txt @@ -7,6 +7,7 @@ set( assert_test.cpp bits_test.cpp buffer_vector_test.cpp + bwt_tests.cpp cache_test.cpp collection_cast_test.cpp condition_test.cpp diff --git a/base/base_tests/base_tests.pro b/base/base_tests/base_tests.pro index 6b7bce0be3..1f2025da66 100644 --- a/base/base_tests/base_tests.pro +++ b/base/base_tests/base_tests.pro @@ -17,6 +17,7 @@ SOURCES += \ assert_test.cpp \ bits_test.cpp \ buffer_vector_test.cpp \ + bwt_tests.cpp \ cache_test.cpp \ collection_cast_test.cpp \ condition_test.cpp \ diff --git a/base/base_tests/bwt_tests.cpp b/base/base_tests/bwt_tests.cpp new file mode 100644 index 0000000000..2338da3cc2 --- /dev/null +++ b/base/base_tests/bwt_tests.cpp @@ -0,0 +1,67 @@ +#include "testing/testing.hpp" + +#include "base/bwt.hpp" + +#include +#include + +using namespace base; +using namespace std; + +namespace +{ +string RevRevBWT(string const & s) +{ + string r; + auto const start = BWT(s, r); + + string rr; + RevBWT(start, r, rr); + return rr; +} + +UNIT_TEST(BWT_Smoke) +{ + { + TEST_EQUAL(BWT(0 /* n */, nullptr /* s */, nullptr /* r */), 0, ()); + } + + { + string r; + TEST_EQUAL(BWT(string() /* s */, r /* r */), 0, ()); 
+ } + + { + string const s = "aaaaaa"; + string r; + TEST_EQUAL(BWT(s, r), 5, ()); /* the whole string sorts after all its shorter suffixes */ + TEST_EQUAL(r, s, ()); + } + + { + string const s = "mississippi"; + string r; + TEST_EQUAL(BWT(s, r), 4, ()); + TEST_EQUAL(r, "pssmipissii", ()); + } +} + +UNIT_TEST(RevBWT_Smoke) +{ + string const strings[] = {"abaaba", "mississippi", "a b b", "Again and again and again"}; + for (auto const & s : strings) + TEST_EQUAL(s, RevRevBWT(s), ()); + + for (size_t i = 0; i < 100; ++i) + { + string const s(i, '\0'); + TEST_EQUAL(s, RevRevBWT(s), ()); + } + + for (size_t i = 0; i < 100; ++i) + { + string const s(i, 'a' + (i % 3)); + TEST_EQUAL(s, RevRevBWT(s), ()); + } +} +} // namespace diff --git a/base/base_tests/suffix_array_tests.cpp b/base/base_tests/suffix_array_tests.cpp index ef2824ef99..8d368b18a7 100644 --- a/base/base_tests/suffix_array_tests.cpp +++ b/base/base_tests/suffix_array_tests.cpp @@ -43,16 +43,24 @@ UNIT_TEST(Skew_Simple) TEST_EQUAL(pos[3], 0, ()); } + for (size_t length = 0; length < 100; ++length) { - for (size_t length = 0; length < 100; ++length) - { - string const s(length, 'a'); - vector pos; - Skew(s, pos); - TEST_EQUAL(pos.size(), s.size(), ()); - for (size_t i = 0; i < pos.size(); ++i) - TEST_EQUAL(pos[i], pos.size() - i - 1, ()); - } + string const s(length, 'a'); + vector pos; + Skew(s, pos); + TEST_EQUAL(pos.size(), s.size(), ()); + for (size_t i = 0; i < pos.size(); ++i) + TEST_EQUAL(pos[i], pos.size() - i - 1, ()); + } + + for (size_t length = 0; length < 100; ++length) + { + string const s(length, '\0'); + vector pos; + Skew(s, pos); + TEST_EQUAL(pos.size(), s.size(), ()); + for (size_t i = 0; i < pos.size(); ++i) + TEST_EQUAL(pos[i], pos.size() - i - 1, ()); } } diff --git a/base/bwt.cpp b/base/bwt.cpp new file mode 100644 index 0000000000..c416f673ad --- /dev/null +++ b/base/bwt.cpp @@ -0,0 +1,194 @@ +#include "base/bwt.hpp" + +#include "base/assert.hpp" +#include "base/suffix_array.hpp" + +#include +#include +#include +#include + +using namespace std; 
+ +namespace +{ +size_t const kNumBytes = 256; + +// Fake trailing '$' for the BWT, used for original string +// reconstruction. It lies outside of the [0, kNumBytes) byte range. +uint32_t const kEOS = 256; + +// FirstColumn represents the first column in the BWT matrix. During +// reverse BWT we need the canonical first column, which has '$' as +// its first element, so this wrapper is used. The other characters +// in the first column are sorted, so we don't need to store them +// explicitly: it's enough to store the start positions of the +// corresponding groups of equal consecutive characters. +class FirstColumn +{ +public: + FirstColumn(size_t n, uint8_t const * s) : m_n(n), m_starts({}) + { + for (size_t i = 0; i < n; ++i) + ++m_starts[s[i]]; + + size_t offset = 0; + for (size_t i = 0; i < m_starts.size(); ++i) + { + auto const count = m_starts[i]; + m_starts[i] = offset; + offset += count; + } + } + + size_t Size() const { return m_n + 1; } + + uint32_t operator[](size_t i) const + { + ASSERT_LESS(i, Size(), ()); + if (i == 0) + return kEOS; + + --i; + auto it = upper_bound(m_starts.begin(), m_starts.end(), i); + ASSERT(it != m_starts.begin(), ()); + --it; + return static_cast(distance(m_starts.begin(), it)); + } + + // Returns the 0-based rank of the i-th symbol among symbols with + // the same value. The '$' at i == 0 always has rank 0. + size_t Rank(size_t i) const + { + ASSERT_LESS(i, Size(), ()); + if (i == 0) + return 0; + + --i; + auto it = upper_bound(m_starts.begin(), m_starts.end(), i); + if (it == m_starts.begin()) + return i; + --it; + return i - *it; + } + +private: + size_t const m_n; + array m_starts; +}; + +// LastColumn represents the last column in the BWT matrix. During +// reverse BWT we need the canonical last column, so |s| is viewed +// as s[start] + s[0, start) + '$' + s[start + 1, n). 
+class LastColumn +{ +public: + LastColumn(size_t n, size_t start, uint8_t const * s) : m_n(n), m_start(start), m_s(s) + { + for (size_t i = 0; i < Size(); ++i) + { + auto const b = (*this)[i]; + if (b == kEOS) + continue; + ASSERT_LESS(b, kNumBytes, ()); + m_table[b].push_back(i); + } + } + + size_t Size() const { return m_n + 1; } + + uint32_t operator[](size_t i) const + { + if (i == 0) + { + ASSERT_LESS(m_start, m_n, ()); + return m_s[m_start]; + } + + if (i == m_start + 1) + return kEOS; + + ASSERT_LESS_OR_EQUAL(i, m_n, ()); + return m_s[i - 1]; + } + + // Returns the index of the |rank|-th (0-based) |byte| in the canonical BWT last column. + // NOTE(review): for kEOS this returns 0, but operator[] places '$' at m_start + 1 - confirm. + size_t Select(uint32_t byte, size_t rank) + { + if (byte == kEOS) + { + ASSERT_EQUAL(rank, 0, ()); + return 0; + } + + ASSERT_LESS(rank, m_table[byte].size(), (byte, rank)); + return m_table[byte][rank]; + } + +private: + size_t const m_n; + size_t const m_start; + uint8_t const * const m_s; + array, kNumBytes> m_table; +}; +} // namespace + +namespace base +{ +size_t BWT(size_t n, uint8_t const * s, uint8_t * r) +{ + vector sa(n); + Skew(n, s, sa.data()); + + size_t result = 0; + for (size_t i = 0; i < n; ++i) + { + if (sa[i] != 0) + { + r[i] = s[sa[i] - 1]; // the byte that precedes the suffix ends its rotation + } + else + { + result = i; // sa[i] == 0: this row is the original string + r[i] = s[n - 1]; + } + } + return result; +} + +size_t BWT(string const & s, string & r) +{ + auto const n = s.size(); + r.assign(n, '\0'); + return BWT(n, reinterpret_cast(s.data()), reinterpret_cast(&r[0])); +} + +void RevBWT(size_t n, size_t start, uint8_t const * s, uint8_t * r) +{ + if (n == 0) + return; + + FirstColumn first(n, s); + LastColumn last(n, start, s); + + auto curr = start + 1; // canonical row of the original string, row 0 starts with '$' + for (size_t i = 0; i < n; ++i) + { + ASSERT_LESS(curr, first.Size(), ()); + ASSERT(first[curr] != kEOS, ()); + + r[i] = first[curr]; + curr = last.Select(r[i], first.Rank(curr)); // row whose last symbol is this occurrence + } + + ASSERT_EQUAL(first[curr], kEOS, ()); +} + +void RevBWT(size_t start, string const & s, string & r) +{ + auto const n = s.size(); + r.assign(n, '\0'); + RevBWT(n, 
start, reinterpret_cast(s.data()), reinterpret_cast(&r[0])); +} +} // namespace base diff --git a/base/bwt.hpp b/base/bwt.hpp new file mode 100644 index 0000000000..f1d01dbc20 --- /dev/null +++ b/base/bwt.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +namespace base +{ +// Computes the Burrows-Wheeler transform of the string |s| and stores +// the result in |r|. Note: the size of |r| must be |n|. +// Returns the index of the original string among all sorted +// rotations of |s|. +// +// *NOTE* in contrast to popular explanations of BWT, we do not append +// to |s| a trailing '$' that is less than any other character in |s|. +// The reason is that |s| can be an arbitrary byte string, with zero +// bytes inside, so implementing this trailing '$' is expensive +// and, actually, not needed. +// +// For example, if |s| is "abaaba", canonical BWT is: +// +// Sorted rotations: canonical BWT: +// $abaaba a +// a$abaab b +// aaba$ab b +// aba$aba a +// * abaaba$ $ +// ba$abaa a +// baaba$a a +// +// where '*' denotes the original string. +// +// Our implementation sorts rotations as if there were an +// implicit '$' that is less than any other byte in |s|, but does not +// return this '$'. Therefore, the order of rotations will be the same +// as above, without the first '$abaaba': +// +// Sorted rotations: our BWT: +// aabaab b +// aabaab b +// abaaba a +// * abaaba a +// baabaa a +// baabaa a +// +// where '*' denotes the index of the original string. As one can see, +// there are two 'abaaba' strings, but as mentioned, rotations are +// sorted as if there were an implicit '$' at the end of the original +// string. It's possible to get from "our BWT" to the "canonical BWT", +// see the code for details. +// +// Complexity: O(n) time and O(n) memory. +size_t BWT(size_t n, uint8_t const * s, uint8_t * r); +size_t BWT(std::string const & s, std::string & r); + +// Inverse Burrows-Wheeler transform: restores into |r| the string whose BWT is |s|; +// |start| is the index returned by BWT(). The size of |r| must be |n|. +// Complexity: O(n) time and O(n) memory. 
+void RevBWT(size_t n, size_t start, uint8_t const * s, uint8_t * r); +void RevBWT(size_t start, std::string const & s, std::string & r); +} // namespace base diff --git a/base/suffix_array.hpp b/base/suffix_array.hpp index 41cf35196a..47a05bf904 100644 --- a/base/suffix_array.hpp +++ b/base/suffix_array.hpp @@ -9,7 +9,7 @@ namespace base // Builds suffix array for the string |s| and stores result in the // |sa| array. Size of |sa| must be not less than |n|. // -// Time complexity: O(n) +// Complexity: O(n) time and O(n) memory. void Skew(size_t n, uint8_t const * s, size_t * sa); void Skew(std::string const & s, std::vector & sa); } // namespace base