diff --git a/base/base_tests/levenshtein_dfa_test.cpp b/base/base_tests/levenshtein_dfa_test.cpp
index 2d8bdc0910..bed77e2784 100644
--- a/base/base_tests/levenshtein_dfa_test.cpp
+++ b/base/base_tests/levenshtein_dfa_test.cpp
@@ -3,8 +3,10 @@
 #include "base/dfa_helpers.hpp"
 #include "base/levenshtein_dfa.hpp"
 
+#include <sstream>
 #include <string>
 
+using namespace std;
 using namespace strings;
 
 namespace
@@ -16,30 +18,73 @@ enum class Status
   Intermediate
 };
 
-Status GetStatus(LevenshteinDFA const & dfa, std::string const & s)
+// Final status of the automaton for an input paired with the ErrorsMade()
+// value of the iterator, so both can be compared with a single TEST_EQUAL.
+struct Result
+{
+  Result() = default;
+  Result(Status status, size_t errorsMade): m_status(status), m_errorsMade(errorsMade) {}
+
+  bool operator==(Result const & rhs) const
+  {
+    return m_status == rhs.m_status && m_errorsMade == rhs.m_errorsMade;
+  }
+
+  Status m_status = Status::Accepts;
+  size_t m_errorsMade = 0;
+};
+
+// Returns the name of |status| as a string.
+string DebugPrint(Status status)
+{
+  switch (status)
+  {
+  case Status::Accepts: return "Accepts";
+  case Status::Rejects: return "Rejects";
+  case Status::Intermediate: return "Intermediate";
+  }
+  // Every enumerator returns above. This line is reached only for a value
+  // outside the enum; without it control would fall off the end of a
+  // non-void function, which is undefined behaviour.
+  return "Unknown";
+}
+
+// Formats |result| as "Result [ status: <name>, errorsMade: <n> ]".
+string DebugPrint(Result const & result)
+{
+  ostringstream os;
+  os << "Result [ ";
+  os << "status: " << DebugPrint(result.m_status) << ", ";
+  os << "errorsMade: " << result.m_errorsMade << " ]";
+  return os.str();
+}
+
+// Moves an iterator obtained from |dfa|.Begin() over |s| and returns the
+// resulting status together with the iterator's ErrorsMade().
+Result GetResult(LevenshteinDFA const & dfa, std::string const & s)
 {
   auto it = dfa.Begin();
   DFAMove(it, s);
   if (it.Accepts())
-    return Status::Accepts;
+    return Result(Status::Accepts, it.ErrorsMade());
   if (it.Rejects())
-    return Status::Rejects;
-  return Status::Intermediate;
+    return Result(Status::Rejects, it.ErrorsMade());
+  return Result(Status::Intermediate, it.ErrorsMade());
 }
 
 bool Accepts(LevenshteinDFA const & dfa, std::string const & s)
 {
-  return GetStatus(dfa, s) == Status::Accepts;
+  return GetResult(dfa, s).m_status == Status::Accepts;
 }
 
 bool Rejects(LevenshteinDFA const & dfa, std::string const & s)
 {
-  return GetStatus(dfa, s) == Status::Rejects;
+  return GetResult(dfa, s).m_status == Status::Rejects;
 }
 
 bool Intermediate(LevenshteinDFA const & dfa, std::string const & s)
 {
-  return GetStatus(dfa, s) == Status::Intermediate;
+  return GetResult(dfa, s).m_status == Status::Intermediate;
 }
 
 UNIT_TEST(LevenshteinDFA_Smoke)
@@ -110,4 +155,42 @@ UNIT_TEST(LevenshteinDFA_Prefix)
     TEST(Accepts(dfa, "моксва"), ());
   }
 }
+
+// Checks the status and the ErrorsMade() value reported after whole inputs are consumed.
+UNIT_TEST(LevenshteinDFA_ErrorsMade)
+{
+  {
+    LevenshteinDFA dfa("москва", 1 /* prefixCharsToKeep */, 2 /* maxErrors */);
+
+    TEST_EQUAL(GetResult(dfa, "москва"), Result(Status::Accepts, 0 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "москв"), Result(Status::Accepts, 1 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "моск"), Result(Status::Accepts, 2 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "мос").m_status, Status::Intermediate, ());
+
+    TEST_EQUAL(GetResult(dfa, "моксав"), Result(Status::Accepts, 2 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "максав").m_status, Status::Rejects, ());
+
+    TEST_EQUAL(GetResult(dfa, "мсовк").m_status, Status::Intermediate, ());
+    TEST_EQUAL(GetResult(dfa, "мсовка"), Result(Status::Accepts, 2 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "мсовкб").m_status, Status::Rejects, ());
+  }
+
+  {
+    LevenshteinDFA dfa("aa", 0 /* prefixCharsToKeep */, 2 /* maxErrors */);
+    TEST_EQUAL(GetResult(dfa, "abab"), Result(Status::Accepts, 2 /* errorsMade */), ());
+  }
+
+  {
+    LevenshteinDFA dfa("mississippi", 0 /* prefixCharsToKeep */, 0 /* maxErrors */);
+    TEST_EQUAL(GetResult(dfa, "misisipi").m_status, Status::Rejects, ());
+    TEST_EQUAL(GetResult(dfa, "mississipp").m_status, Status::Intermediate, ());
+    TEST_EQUAL(GetResult(dfa, "mississippi"), Result(Status::Accepts, 0 /* errorsMade */), ());
+  }
+
+  {
+    LevenshteinDFA dfa("кафе", 1 /* prefixCharsToKeep */, 1 /* maxErrors */);
+    TEST_EQUAL(GetResult(dfa, "кафе"), Result(Status::Accepts, 0 /* errorsMade */), ());
+    TEST_EQUAL(GetResult(dfa, "кафер"), Result(Status::Accepts, 1 /* errorsMade */), ());
+  }
+}
 } // namespace
diff --git a/base/levenshtein_dfa.cpp
b/base/levenshtein_dfa.cpp
index 1290d6634c..40454192f8 100644
--- a/base/levenshtein_dfa.cpp
+++ b/base/levenshtein_dfa.cpp
@@ -191,10 +191,15 @@ LevenshteinDFA::LevenshteinDFA(UniString const & s, size_t prefixCharsToKeep, si
     ASSERT_EQUAL(id, m_transitions.size(), ());
     ASSERT_EQUAL(visited.count(state), 0, (state, id));
 
+    // The per-state vectors are indexed by state id, so they must grow in lockstep.
+    ASSERT_EQUAL(m_transitions.size(), m_accepting.size(), ());
+    ASSERT_EQUAL(m_transitions.size(), m_errorsMade.size(), ());
+
     states.emplace(state);
     visited[state] = id;
     m_transitions.emplace_back(m_alphabet.size());
     m_accepting.push_back(false);
+    m_errorsMade.push_back(ErrorsMade(state));
   };
 
   pushState(MakeStart(), kStartingState);
@@ -296,6 +301,23 @@ bool LevenshteinDFA::IsAccepting(State const & s) const
   return false;
 }
 
+// Returns the smallest number of made errors among the accepting positions
+// of |s|, or |m_maxErrors| when |s| has no accepting position.
+size_t LevenshteinDFA::ErrorsMade(State const & s) const
+{
+  size_t errorsMade = m_maxErrors;
+  for (auto const & p : s.m_positions)
+  {
+    if (!IsAccepting(p))
+      continue;
+    // NOTE(review): this reads as charging one error for each of the
+    // |m_size - p.m_offset| remaining characters - confirm against IsAccepting(Position).
+    auto const errorsLeft = p.m_errorsLeft - (m_size - p.m_offset);
+    errorsMade = std::min(errorsMade, m_maxErrors - errorsLeft);
+  }
+  return errorsMade;
+}
+
 size_t LevenshteinDFA::Move(size_t s, UniChar c) const
 {
   ASSERT_GREATER(m_alphabet.size(), 0, ());
diff --git a/base/levenshtein_dfa.hpp b/base/levenshtein_dfa.hpp
index 5b2ceeaa42..96ec173817 100644
--- a/base/levenshtein_dfa.hpp
+++ b/base/levenshtein_dfa.hpp
@@ -80,6 +80,9 @@ public:
     bool Accepts() const { return m_dfa.IsAccepting(m_s); }
     bool Rejects() const { return m_dfa.IsRejecting(m_s); }
 
+    // Returns what the DFA reports as ErrorsMade() for the current state |m_s|.
+    size_t ErrorsMade() const { return m_dfa.ErrorsMade(m_s); }
+
   private:
     friend class LevenshteinDFA;
 
@@ -115,6 +118,12 @@ private:
   inline bool IsRejecting(State const & s) const { return s.m_positions.empty(); }
   inline bool IsRejecting(size_t s) const { return s == kRejectingState; }
 
+  // Returns minimum number of made errors among accepting positions in |s|,
+  // or |m_maxErrors| if there are no accepting positions in |s|.
+  size_t ErrorsMade(State const & s) const;
+  // Returns the value stored in |m_errorsMade| for the state with id |s|.
+  size_t ErrorsMade(size_t s) const { return m_errorsMade[s]; }
+
   size_t Move(size_t s, UniChar c) const;
 
   size_t const m_size;
@@ -124,6 +133,8 @@ private:
 
   std::vector<std::vector<size_t>> m_transitions;
   std::vector<bool> m_accepting;
+  // m_errorsMade[id] is ErrorsMade() of the state that was pushed with |id|.
+  std::vector<size_t> m_errorsMade;
 };
 
 std::string DebugPrint(LevenshteinDFA::Position const & p);