From ef8f9b155eace440038a87df7768d7b13ca8a5d8 Mon Sep 17 00:00:00 2001 From: Yuri Gorshenin Date: Thu, 31 Aug 2017 19:24:33 +0300 Subject: [PATCH] [search] Expose num typos to ranking info. --- base/levenshtein_dfa.hpp | 3 + search/geocoder.cpp | 5 +- search/ranker.cpp | 53 ++++++++----- search/ranking_info.cpp | 2 +- search/ranking_info.hpp | 4 +- search/ranking_utils.cpp | 33 ++++++++ search/ranking_utils.hpp | 77 +++++++++++++++++++ search/search_integration_tests/helpers.cpp | 5 +- search/search_integration_tests/helpers.hpp | 3 +- .../processor_test.cpp | 39 ++++++++++ search/utils.cpp | 5 ++ search/utils.hpp | 4 +- 12 files changed, 204 insertions(+), 29 deletions(-) diff --git a/base/levenshtein_dfa.hpp b/base/levenshtein_dfa.hpp index 96ec173817..88569d0b65 100644 --- a/base/levenshtein_dfa.hpp +++ b/base/levenshtein_dfa.hpp @@ -91,6 +91,9 @@ public: LevenshteinDFA const & m_dfa; }; + LevenshteinDFA(LevenshteinDFA const &) = default; + LevenshteinDFA(LevenshteinDFA &&) = default; + LevenshteinDFA(UniString const & s, size_t prefixCharsToKeep, size_t maxErrors); LevenshteinDFA(std::string const & s, size_t prefixCharsToKeep, size_t maxErrors); LevenshteinDFA(UniString const & s, size_t maxErrors); diff --git a/search/geocoder.cpp b/search/geocoder.cpp index c0e9366913..98317e4fcd 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -353,7 +353,7 @@ void Geocoder::SetParams(Params const & params) // Here and below, we use LevenshteinDFAs for fuzzy // matching. But due to performance reasons, we assume that the // first letter is always correct. - request.m_names.emplace_back(s, 1 /* prefixCharsToKeep */, GetMaxErrorsForToken(s)); + request.m_names.emplace_back(BuildLevenshteinDFA(s)); }); for (auto const & index : m_params.GetTypeIndices(i)) request.m_categories.emplace_back(FeatureTypeToString(index)); @@ -363,8 +363,7 @@ void Geocoder::SetParams(Params const & params) { auto & request = m_prefixTokenRequest; m_params.GetToken(i).ForEach([&request](UniString const & s) { - request.m_names.emplace_back( - LevenshteinDFA(s, 1 /* prefixCharsToKeep */, GetMaxErrorsForToken(s))); + request.m_names.emplace_back(BuildLevenshteinDFA(s)); }); for (auto const & index : m_params.GetTypeIndices(i)) request.m_categories.emplace_back(FeatureTypeToString(index)); diff --git a/search/ranker.cpp b/search/ranker.cpp index c180a9f789..405fd99459 100644 --- a/search/ranker.cpp +++ b/search/ranker.cpp @@ -18,27 +18,32 @@ namespace search { namespace { -template -void UpdateNameScore(string const & name, TSlice const & slice, NameScore & bestScore) +struct NameScores { - auto const score = GetNameScore(name, slice); - if (score > bestScore) - bestScore = score; + NameScore m_nameScore = NAME_SCORE_ZERO; + ErrorsMade m_errorsMade; +}; + +template +void UpdateNameScores(string const & name, TSlice const & slice, NameScores & bestScores) +{ + bestScores.m_nameScore = std::max(bestScores.m_nameScore, GetNameScore(name, slice)); + bestScores.m_errorsMade = ErrorsMade::Min(bestScores.m_errorsMade, GetErrorsMade(name, slice)); } template -void UpdateNameScore(vector const & tokens, TSlice const & slice, - NameScore & bestScore) +void UpdateNameScores(vector const & tokens, TSlice const & slice, + NameScores & bestScores) { - auto const score = GetNameScore(tokens, slice); - if (score > bestScore) - bestScore = score; + bestScores.m_nameScore = std::max(bestScores.m_nameScore, GetNameScore(tokens, slice)); + bestScores.m_errorsMade = ErrorsMade::Min(bestScores.m_errorsMade, GetErrorsMade(tokens, slice)); } -NameScore GetNameScore(FeatureType const & ft, Geocoder::Params const & params, - TokenRange const & range, Model::Type type) +NameScores GetNameScores(FeatureType const & ft, Geocoder::Params const & params, + TokenRange const & range, Model::Type type) { - NameScore bestScore = NAME_SCORE_ZERO; + NameScores bestScores; + TokenSlice slice(params, range); TokenSliceNoCategories sliceNoCategories(params, range); @@ -50,14 +55,14 @@ NameScore GetNameScore(FeatureType const & ft, Geocoder::Params const & params, vector tokens; PrepareStringForMatching(name, tokens); - UpdateNameScore(tokens, slice, bestScore); - UpdateNameScore(tokens, sliceNoCategories, bestScore); + UpdateNameScores(tokens, slice, bestScores); + UpdateNameScores(tokens, sliceNoCategories, bestScores); } if (type == Model::TYPE_BUILDING) - UpdateNameScore(ft.GetHouseNumber(), sliceNoCategories, bestScore); + UpdateNameScores(ft.GetHouseNumber(), sliceNoCategories, bestScores); - return bestScore; + return bestScores; } void RemoveDuplicatingLinear(vector & values) @@ -210,7 +215,11 @@ class PreResult2Maker info.m_distanceToPivot = MercatorBounds::DistanceOnEarth(center, pivot); info.m_rank = preInfo.m_rank; info.m_type = preInfo.m_type; - info.m_nameScore = GetNameScore(ft, m_params, preInfo.InnermostTokenRange(), info.m_type); + + auto const nameScores = GetNameScores(ft, m_params, preInfo.InnermostTokenRange(), info.m_type); + + auto nameScore = nameScores.m_nameScore; + auto errorsMade = nameScores.m_errorsMade; if (info.m_type != Model::TYPE_STREET && preInfo.m_geoParts.m_street != IntersectionResult::kInvalidId) @@ -219,12 +228,16 @@ class PreResult2Maker FeatureType street; if (LoadFeature(FeatureID(mwmId, preInfo.m_geoParts.m_street), street)) { - NameScore const nameScore = GetNameScore( + auto const nameScores = GetNameScores( street, m_params, preInfo.m_tokenRange[Model::TYPE_STREET], Model::TYPE_STREET); - info.m_nameScore = min(info.m_nameScore, nameScore); + nameScore = min(nameScore, nameScores.m_nameScore); + errorsMade += nameScores.m_errorsMade; } } + info.m_nameScore = nameScore; + info.m_errorsMade = errorsMade; + TokenSlice slice(m_params, preInfo.InnermostTokenRange()); feature::TypesHolder holder(ft); vector> matched(slice.Size()); diff --git a/search/ranking_info.cpp b/search/ranking_info.cpp index f3cef0300e..bd0a52230a 100644 --- a/search/ranking_info.cpp +++ b/search/ranking_info.cpp @@ -60,6 +60,7 @@ string DebugPrint(RankingInfo const & info) os << "m_distanceToPivot:" << info.m_distanceToPivot << ","; os << "m_rank:" << static_cast(info.m_rank) << ","; os << "m_nameScore:" << DebugPrint(info.m_nameScore) << ","; + os << "m_errorsMade:" << DebugPrint(info.m_errorsMade) << ","; os << "m_type:" << DebugPrint(info.m_type) << ","; os << "m_pureCats:" << info.m_pureCats << ","; os << "m_falseCats:" << info.m_falseCats; @@ -98,5 +99,4 @@ double RankingInfo::GetLinearModelRank() const return kDistanceToPivot * distanceToPivot + kRank * rank + kNameScore[nameScore] + kType[m_type] + m_falseCats * kFalseCats; } - } // namespace search diff --git a/search/ranking_info.hpp b/search/ranking_info.hpp index 23080043f5..fc9af152aa 100644 --- a/search/ranking_info.hpp +++ b/search/ranking_info.hpp @@ -23,6 +23,9 @@ struct RankingInfo // Score for the feature's name. NameScore m_nameScore = NAME_SCORE_ZERO; + // Number of typos. + ErrorsMade m_errorsMade; + // Search type for the feature. Model::Type m_type = Model::TYPE_COUNT; @@ -45,5 +48,4 @@ struct RankingInfo }; string DebugPrint(RankingInfo const & info); - } // namespace search diff --git a/search/ranking_utils.cpp b/search/ranking_utils.cpp index 97d0c5792f..1ab51786e6 100644 --- a/search/ranking_utils.cpp +++ b/search/ranking_utils.cpp @@ -1,8 +1,11 @@ #include "search/ranking_utils.hpp" +#include "base/dfa_helpers.hpp" + #include "std/transform_iterator.hpp" #include +#include using namespace strings; @@ -13,6 +16,18 @@ namespace UniString AsciiToUniString(char const * s) { return UniString(s, s + strlen(s)); } } // namespace +string DebugPrint(ErrorsMade const & errorsMade) +{ + ostringstream os; + os << "ErrorsMade [ "; + if (errorsMade.IsValid()) + os << errorsMade.m_errorsMade; + else + os << "invalid"; + os << " ]"; + return os.str(); +} + namespace impl { bool FullMatch(QueryParams::Token const & token, UniString const & text) @@ -35,6 +50,24 @@ bool PrefixMatch(QueryParams::Token const & token, UniString const & text) } return false; } + +ErrorsMade GetMinErrorsMade(std::vector const & tokens, + strings::UniString const & text) +{ + auto const dfa = BuildLevenshteinDFA(text); + + ErrorsMade errorsMade; + + for (auto const & token : tokens) + { + auto it = dfa.Begin(); + strings::DFAMove(it, token.begin(), token.end()); + if (it.Accepts()) + errorsMade = ErrorsMade::Min(errorsMade, ErrorsMade(it.ErrorsMade())); + } + + return errorsMade; +} } // namespace impl bool IsStopWord(UniString const & s) diff --git a/search/ranking_utils.hpp b/search/ranking_utils.hpp index 29a6ae0102..871cef6167 100644 --- a/search/ranking_utils.hpp +++ b/search/ranking_utils.hpp @@ -2,13 +2,16 @@ #include "search/model.hpp" #include "search/query_params.hpp" +#include "search/utils.hpp" #include "indexer/search_delimiters.hpp" #include "indexer/search_string_utils.hpp" +#include "base/levenshtein_dfa.hpp" #include "base/stl_add.hpp" #include "base/string_utils.hpp" +#include #include #include #include @@ -18,11 +21,59 @@ namespace search { class QueryParams; +struct ErrorsMade +{ + static size_t constexpr kInfiniteErrors = std::numeric_limits::max(); + + ErrorsMade() = default; + explicit ErrorsMade(size_t errorsMade) : m_errorsMade(errorsMade) {} + + bool IsValid() const { return m_errorsMade != kInfiniteErrors; } + + template + static ErrorsMade Combine(ErrorsMade const & lhs, ErrorsMade const & rhs, Fn && fn) + { + if (!lhs.IsValid()) + return rhs; + if (!rhs.IsValid()) + return lhs; + return ErrorsMade(fn(lhs.m_errorsMade, rhs.m_errorsMade)); + } + + static ErrorsMade Min(ErrorsMade const & lhs, ErrorsMade const & rhs) + { + return Combine(lhs, rhs, [](size_t u, size_t v) { return std::min(u, v); }); + } + + friend ErrorsMade operator+(ErrorsMade const & lhs, ErrorsMade const & rhs) + { + return Combine(lhs, rhs, [](size_t u, size_t v) { return u + v; }); + } + + ErrorsMade & operator+=(ErrorsMade const & rhs) + { + *this = *this + rhs; + return *this; + } + + bool operator==(ErrorsMade const & rhs) const { return m_errorsMade == rhs.m_errorsMade; } + + size_t m_errorsMade = kInfiniteErrors; +}; + +string DebugPrint(ErrorsMade const & errorsMade); + namespace impl { bool FullMatch(QueryParams::Token const & token, strings::UniString const & text); bool PrefixMatch(QueryParams::Token const & token, strings::UniString const & text); + +// Returns the minimum number of errors needed to match |text| with +// any of the |tokens|. If it's not possible in accordance with +// GetMaxErrorsForToken(|text|), returns kInfiniteErrors. +ErrorsMade GetMinErrorsMade(std::vector const & tokens, + strings::UniString const & text); } // namespace impl // The order and numeric values are important here. Please, check all @@ -92,4 +143,30 @@ NameScore GetNameScore(std::vector const & tokens, Slice con } string DebugPrint(NameScore score); + +// Returns total number of errors that were made during matching +// feature |tokens| by a query - query tokens are in |slice|. +template +ErrorsMade GetErrorsMade(std::vector const & tokens, Slice const & slice) +{ + ErrorsMade totalErrorsMade; + + for (size_t i = 0; i < slice.Size(); ++i) + { + ErrorsMade errorsMade; + slice.Get(i).ForEach([&](strings::UniString const & s) { + errorsMade = ErrorsMade::Min(errorsMade, impl::GetMinErrorsMade(tokens, s)); + }); + + totalErrorsMade += errorsMade; + } + + return totalErrorsMade; +} + +template +ErrorsMade GetErrorsMade(std::string const & s, Slice const & slice) +{ + return GetErrorsMade({strings::MakeUniString(s)}, slice); +} } // namespace search diff --git a/search/search_integration_tests/helpers.cpp b/search/search_integration_tests/helpers.cpp index c537163d49..758394801f 100644 --- a/search/search_integration_tests/helpers.cpp +++ b/search/search_integration_tests/helpers.cpp @@ -79,11 +79,12 @@ bool SearchTest::ResultsMatch(SearchParams const & params, TRules const & rules) return ResultsMatch(request.Results(), rules); } -unique_ptr SearchTest::MakeRequest(string const & query) +unique_ptr SearchTest::MakeRequest( + string const & query, string const & locale /* = "en" */) { SearchParams params; params.m_query = query; - params.m_inputLocale = "en"; + params.m_inputLocale = locale; params.m_mode = Mode::Everywhere; params.m_suggestsEnabled = false; diff --git a/search/search_integration_tests/helpers.hpp b/search/search_integration_tests/helpers.hpp index 199e8b11f3..f42579f4f7 100644 --- a/search/search_integration_tests/helpers.hpp +++ b/search/search_integration_tests/helpers.hpp @@ -113,7 +113,8 @@ public: bool ResultsMatch(SearchParams const & params, TRules const & rules); - unique_ptr MakeRequest(string const & query); + unique_ptr MakeRequest(string const & query, + string const & locale = "en"); size_t CountFeatures(m2::RectD const & rect); diff --git a/search/search_integration_tests/processor_test.cpp b/search/search_integration_tests/processor_test.cpp index d04f5ac266..41b0cce5e2 100644 --- a/search/search_integration_tests/processor_test.cpp +++ b/search/search_integration_tests/processor_test.cpp @@ -419,6 +419,45 @@ UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo) } } +UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo_ErrorsMade) +{ + string const countryName = "Wonderland"; + + TestCity chekhov(m2::PointD(0, 0), "Чехов", "ru", 100 /* rank */); + TestStreet pushkinskaya( + vector{m2::PointD(-0.5, -0.5), m2::PointD(0, 0), m2::PointD(0.5, 0.5)}, + "Улица Пушкинская", "ru"); + TestPOI lermontov(m2::PointD(0, 0), "Трактиръ Лермонтовъ", "ru"); + lermontov.SetTypes({{"amenity", "cafe"}}); + + auto worldId = BuildWorld([&](TestMwmBuilder & builder) { builder.Add(chekhov); }); + + auto wonderlandId = BuildCountry(countryName, [&](TestMwmBuilder & builder) { + builder.Add(pushkinskaya); + builder.Add(lermontov); + }); + + SetViewport(m2::RectD(m2::PointD(-1, -1), m2::PointD(1, 1))); + + auto checkErrors = [&](string const & query, ErrorsMade const & errorsMade) { + auto request = MakeRequest(query, "ru"); + auto const & results = request->Results(); + + TRules rules{ExactMatch(wonderlandId, lermontov)}; + TEST(ResultsMatch(results, rules), ()); + TEST_EQUAL(results.size(), 1, ()); + + TEST_EQUAL(results[0].GetRankingInfo().m_errorsMade, errorsMade, ()); + }; + + checkErrors("кафе лермонтов", ErrorsMade(1)); + checkErrors("трактир лермонтов", ErrorsMade(2)); + checkErrors("кафе", ErrorsMade()); + checkErrors("пушкенская трактир лермонтов", ErrorsMade(3)); + checkErrors("пушкенская кафе", ErrorsMade(1)); + checkErrors("пушкинская трактиръ лермонтовъ", ErrorsMade(0)); +} + UNIT_CLASS_TEST(ProcessorTest, TestHouseNumbers) { string const countryName = "HouseNumberLand"; diff --git a/search/utils.cpp b/search/utils.cpp index c584dc9d30..18caccfa58 100644 --- a/search/utils.cpp +++ b/search/utils.cpp @@ -15,4 +15,9 @@ size_t GetMaxErrorsForToken(strings::UniString const & token) return 1; return 2; } + +strings::LevenshteinDFA BuildLevenshteinDFA(strings::UniString const & s) +{ + return strings::LevenshteinDFA(s, 1 /* prefixCharsToKeep */, GetMaxErrorsForToken(s)); +} } // namespace search diff --git a/search/utils.hpp b/search/utils.hpp index 4fb6b17a36..bbde779ea8 100644 --- a/search/utils.hpp +++ b/search/utils.hpp @@ -66,6 +66,8 @@ using TLocales = buffer_vector; size_t GetMaxErrorsForToken(strings::UniString const & token); +strings::LevenshteinDFA BuildLevenshteinDFA(strings::UniString const & s); + template void ForEachCategoryType(StringSliceBase const & slice, TLocales const & locales, CategoriesHolder const & categories, ToDo && todo) @@ -108,7 +110,7 @@ void ForEachCategoryTypeFuzzy(StringSliceBase const & slice, TLocales const & lo // todo(@m, @y). We build dfa twice for each token: here and in geocoder.cpp. // A possible optimization is to build each dfa once and save it. Note that // dfas for the prefix tokens differ, i.e. we ignore slice.IsPrefix(i) here. - strings::LevenshteinDFA const dfa(token, 1 /* prefixCharsToKeep */, GetMaxErrorsForToken(token)); + strings::LevenshteinDFA const dfa(BuildLevenshteinDFA(token)); trieRootIt.ForEachMove([&](Trie::Char const & c, Trie::Iterator const & trieStartIt) { if (std::binary_search(sortedLocales.begin(), sortedLocales.end(), static_cast(c)))