diff --git a/base/base_tests/levenshtein_dfa_test.cpp b/base/base_tests/levenshtein_dfa_test.cpp index 84374246de..99e0fbc716 100644 --- a/base/base_tests/levenshtein_dfa_test.cpp +++ b/base/base_tests/levenshtein_dfa_test.cpp @@ -22,16 +22,21 @@ enum class Status struct Result { Result() = default; - Result(Status status, size_t errorsMade = 0) : m_status(status), m_errorsMade(errorsMade) {} + Result(Status status, size_t errorsMade = 0, size_t prefixErrorsMade = 0) + : m_status(status), m_errorsMade(errorsMade), m_prefixErrorsMade(prefixErrorsMade) + { + } bool operator==(Result const & rhs) const { return m_status == rhs.m_status && - (m_errorsMade == rhs.m_errorsMade || m_status == Status::Rejects); + (m_errorsMade == rhs.m_errorsMade || m_status == Status::Rejects) && + (m_prefixErrorsMade == rhs.m_prefixErrorsMade || m_status == Status::Rejects); } Status m_status = Status::Accepts; size_t m_errorsMade = 0; + size_t m_prefixErrorsMade = 0; }; string DebugPrint(Status status) @@ -50,7 +55,8 @@ string DebugPrint(Result const & result) ostringstream os; os << "Result [ "; os << "status: " << DebugPrint(result.m_status) << ", "; - os << "errorsMade: " << result.m_errorsMade << " ]"; + os << "errorsMade: " << result.m_errorsMade << ", "; + os << "prefixErrorsMade: " << result.m_prefixErrorsMade << " ]"; return os.str(); } @@ -59,10 +65,10 @@ Result GetResult(LevenshteinDFA const & dfa, std::string const & s) auto it = dfa.Begin(); DFAMove(it, s); if (it.Accepts()) - return Result(Status::Accepts, it.ErrorsMade()); + return Result(Status::Accepts, it.ErrorsMade(), it.PrefixErrorsMade()); if (it.Rejects()) - return Result(Status::Rejects, it.ErrorsMade()); - return Result(Status::Intermediate, it.ErrorsMade()); + return Result(Status::Rejects, it.ErrorsMade(), it.PrefixErrorsMade()); + return Result(Status::Intermediate, it.ErrorsMade(), it.PrefixErrorsMade()); } bool Accepts(LevenshteinDFA const & dfa, std::string const & s) @@ -154,29 +160,39 @@ 
UNIT_TEST(LevenshteinDFA_ErrorsMade) { LevenshteinDFA dfa("москва", 1 /* prefixSize */, 2 /* maxErrors */); - TEST_EQUAL(GetResult(dfa, "москва"), Result(Status::Accepts, 0 /* errorsMade */), ()); - TEST_EQUAL(GetResult(dfa, "москв"), Result(Status::Accepts, 1 /* errorsMade */), ()); - TEST_EQUAL(GetResult(dfa, "моск"), Result(Status::Accepts, 2 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "москва"), + Result(Status::Accepts, 0 /* errorsMade */, 0 /* prefixErrorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "москв"), + Result(Status::Accepts, 1 /* errorsMade */, 0 /* prefixErrorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "моск"), + Result(Status::Accepts, 2 /* errorsMade */, 0 /* prefixErrorsMade */), ()); TEST_EQUAL(GetResult(dfa, "мос").m_status, Status::Intermediate, ()); + TEST_EQUAL(GetResult(dfa, "мос").m_prefixErrorsMade, 0, ()); - TEST_EQUAL(GetResult(dfa, "моксав"), Result(Status::Accepts, 2 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "моксав"), + Result(Status::Accepts, 2 /* errorsMade */, 2 /* prefixErrorsMade */), ()); TEST_EQUAL(GetResult(dfa, "максав").m_status, Status::Rejects, ()); TEST_EQUAL(GetResult(dfa, "мсовк").m_status, Status::Intermediate, ()); - TEST_EQUAL(GetResult(dfa, "мсовка"), Result(Status::Accepts, 2 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "мсовк").m_prefixErrorsMade, 2, ()); + TEST_EQUAL(GetResult(dfa, "мсовка"), + Result(Status::Accepts, 2 /* errorsMade */, 2 /* prefixErrorsMade */), ()); TEST_EQUAL(GetResult(dfa, "мсовкб").m_status, Status::Rejects, ()); } { LevenshteinDFA dfa("aa", 0 /* prefixSize */, 2 /* maxErrors */); - TEST_EQUAL(GetResult(dfa, "abab"), Result(Status::Accepts, 2 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "abab"), + Result(Status::Accepts, 2 /* errorsMade */, 2 /* prefixErrorsMade */), ()); } { LevenshteinDFA dfa("mississippi", 0 /* prefixSize */, 0 /* maxErrors */); TEST_EQUAL(GetResult(dfa, "misisipi").m_status, Status::Rejects, ()); TEST_EQUAL(GetResult(dfa, 
"mississipp").m_status, Status::Intermediate, ()); - TEST_EQUAL(GetResult(dfa, "mississippi"), Result(Status::Accepts, 0 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "mississipp").m_prefixErrorsMade, 0, ()); + TEST_EQUAL(GetResult(dfa, "mississippi"), + Result(Status::Accepts, 0 /* errorsMade */, 0 /* prefixErrorsMade */), ()); } { @@ -185,9 +201,9 @@ UNIT_TEST(LevenshteinDFA_ErrorsMade) size_t const maxErrors = 1; string const str = "yekaterinburg"; vector> const queries = { - {"yekaterinburg", Result(Status::Accepts, 0 /* errorsMade */)}, - {"ekaterinburg", Result(Status::Accepts, 1 /* errorsMade */)}, - {"jekaterinburg", Result(Status::Accepts, 1 /* errorsMade */)}, + {"yekaterinburg", Result(Status::Accepts, 0 /* errorsMade */, 0 /* prefixErrorsMade */)}, + {"ekaterinburg", Result(Status::Accepts, 1 /* errorsMade */, 1 /* prefixErrorsMade */)}, + {"jekaterinburg", Result(Status::Accepts, 1 /* errorsMade */, 1 /* prefixErrorsMade */)}, {"iekaterinburg", Result(Status::Rejects)}}; for (auto const & q : queries) @@ -199,8 +215,10 @@ UNIT_TEST(LevenshteinDFA_ErrorsMade) { LevenshteinDFA dfa("кафе", 1 /* prefixSize */, 1 /* maxErrors */); - TEST_EQUAL(GetResult(dfa, "кафе"), Result(Status::Accepts, 0 /* errorsMade */), ()); - TEST_EQUAL(GetResult(dfa, "кафер"), Result(Status::Accepts, 1 /* errorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "кафе"), + Result(Status::Accepts, 0 /* errorsMade */, 0 /* prefixErrorsMade */), ()); + TEST_EQUAL(GetResult(dfa, "кафер"), + Result(Status::Accepts, 1 /* errorsMade */, 1 /* prefixErrorsMade */), ()); } } @@ -214,26 +232,33 @@ UNIT_TEST(LevenshteinDFA_PrefixDFAModifier) TEST(!it.Accepts(), ()); TEST(!it.Rejects(), ()); + TEST_EQUAL(it.PrefixErrorsMade(), 0, ()); + // |maxErrors| for all non-accepting states. 
+ TEST_EQUAL(it.ErrorsMade(), 2, ()); DFAMove(it, "c"); TEST(it.Accepts(), ()); TEST(!it.Rejects(), ()); TEST_EQUAL(it.ErrorsMade(), 2, ()); + TEST_EQUAL(it.PrefixErrorsMade(), 0, ()); DFAMove(it, "d"); TEST(it.Accepts(), ()); TEST(!it.Rejects(), ()); TEST_EQUAL(it.ErrorsMade(), 1, ()); + TEST_EQUAL(it.PrefixErrorsMade(), 0, ()); DFAMove(it, "e"); TEST(it.Accepts(), ()); TEST(!it.Rejects(), ()); TEST_EQUAL(it.ErrorsMade(), 0, ()); + TEST_EQUAL(it.PrefixErrorsMade(), 0, ()); DFAMove(it, "fghijklmn"); TEST(it.Accepts(), ()); TEST(!it.Rejects(), ()); TEST_EQUAL(it.ErrorsMade(), 0, ()); + TEST_EQUAL(it.PrefixErrorsMade(), 0, ()); } } diff --git a/base/dfa_helpers.hpp b/base/dfa_helpers.hpp index 87ee1de2a6..12894f8f9d 100644 --- a/base/dfa_helpers.hpp +++ b/base/dfa_helpers.hpp @@ -62,6 +62,7 @@ public: bool Accepts() const { return m_accepts; } bool Rejects() const { return !Accepts() && m_it.Rejects(); } size_t ErrorsMade() const { return m_it.ErrorsMade(); } + size_t PrefixErrorsMade() const { return m_it.PrefixErrorsMade(); } private: friend class PrefixDFAModifier; diff --git a/base/levenshtein_dfa.cpp b/base/levenshtein_dfa.cpp index 070f104c9c..8c5055616f 100644 --- a/base/levenshtein_dfa.cpp +++ b/base/levenshtein_dfa.cpp @@ -237,6 +237,7 @@ LevenshteinDFA::LevenshteinDFA(UniString const & s, size_t prefixSize, m_transitions.emplace_back(m_alphabet.size()); m_accepting.push_back(false); m_errorsMade.push_back(ErrorsMade(state)); + m_prefixErrorsMade.push_back(PrefixErrorsMade(state)); }; pushState(MakeStart(), kStartingState); @@ -351,6 +352,14 @@ size_t LevenshteinDFA::ErrorsMade(State const & s) const return errorsMade; } +size_t LevenshteinDFA::PrefixErrorsMade(State const & s) const +{ + size_t errorsMade = m_maxErrors; + for (auto const & p : s.m_positions) + errorsMade = std::min(errorsMade, m_maxErrors - p.m_errorsLeft); + return errorsMade; +} + size_t LevenshteinDFA::Move(size_t s, UniChar c) const { ASSERT_GREATER(m_alphabet.size(), 0, ()); diff --git 
a/base/levenshtein_dfa.hpp b/base/levenshtein_dfa.hpp index 5505c94b23..fa52fa89cf 100644 --- a/base/levenshtein_dfa.hpp +++ b/base/levenshtein_dfa.hpp @@ -82,6 +82,7 @@ public: bool Rejects() const { return m_dfa.IsRejecting(m_s); } size_t ErrorsMade() const { return m_dfa.ErrorsMade(m_s); } + size_t PrefixErrorsMade() const { return m_dfa.PrefixErrorsMade(m_s); } private: friend class LevenshteinDFA; @@ -126,6 +127,10 @@ private: size_t ErrorsMade(State const & s) const; size_t ErrorsMade(size_t s) const { return m_errorsMade[s]; } + // Returns minimum number of errors already made. This number cannot decrease. + size_t PrefixErrorsMade(State const & s) const; + size_t PrefixErrorsMade(size_t s) const { return m_prefixErrorsMade[s]; } + size_t Move(size_t s, UniChar c) const; size_t const m_size; @@ -136,6 +141,7 @@ private: std::vector<std::vector<size_t>> m_transitions; std::vector<bool> m_accepting; std::vector<size_t> m_errorsMade; + std::vector<size_t> m_prefixErrorsMade; }; std::string DebugPrint(LevenshteinDFA::Position const & p); diff --git a/base/uni_string_dfa.hpp b/base/uni_string_dfa.hpp index 57765aab90..6187e8562a 100644 --- a/base/uni_string_dfa.hpp +++ b/base/uni_string_dfa.hpp @@ -19,6 +19,7 @@ public: bool Accepts() const { return !Rejects() && m_pos == m_s.size(); } bool Rejects() const { return m_rejected; } size_t ErrorsMade() const { return 0; } + size_t PrefixErrorsMade() const { return 0; } private: friend class UniStringDFA;