diff --git a/base/mem_trie.hpp b/base/mem_trie.hpp index fdd35af365..efc5a90cc0 100644 --- a/base/mem_trie.hpp +++ b/base/mem_trie.hpp @@ -3,6 +3,7 @@ #include "base/macros.hpp" #include "base/stl_add.hpp" +#include #include #include #include @@ -15,7 +16,12 @@ namespace my template class MemTrie { +private: + struct Node; + public: + using Char = typename String::value_type; + MemTrie() = default; MemTrie(MemTrie && rhs) { *this = std::move(rhs); } @@ -23,10 +29,38 @@ public: { m_root = std::move(rhs.m_root); m_numNodes = rhs.m_numNodes; - rhs.m_numNodes = 1; + rhs.Clear(); return *this; } + // A read-only iterator wrapping a Node. Any modification to the + // underlying trie is assumed to invalidate the iterator. + class Iterator + { + public: + Iterator(MemTrie::Node const & node) : m_node(node) {} + + // Iterates over all possible moves from this Iterator's node + // and calls |toDo| with two arguments: + // (Char of the move, Iterator wrapping the node of the move). + template + void ForEachMove(ToDo && toDo) const + { + for (auto const & move : m_node.m_moves) + toDo(move.first, Iterator(*move.second)); + } + + // Calls |toDo| for every value in this Iterator's node. + template + void ForEachInNode(ToDo && toDo) const + { + std::for_each(m_node.m_values.begin(), m_node.m_values.end(), std::forward(toDo)); + } + + private: + MemTrie::Node const & m_node; + }; + // Adds a key-value pair to the trie. void Add(String const & key, Value const & value) { @@ -69,12 +103,20 @@ public: ForEachInSubtree(*root, prefix, std::forward(toDo)); } + void Clear() + { + m_root.Clear(); + m_numNodes = 1; + } + size_t GetNumNodes() const { return m_numNodes; } + Iterator GetRootIterator() const { return Iterator(m_root); } + Node const & GetRoot() const { return m_root; } private: struct Node { - using Char = typename String::value_type; + friend class MemTrie::Iterator; Node() = default; Node(Node && /* rhs */) = default; @@ -98,6 +140,12 @@ private: void AddValue(Value const & value) { m_values.push_back(value); } + void Clear() + { + m_moves.clear(); + m_values.clear(); + } + std::map> m_moves; std::vector m_values; diff --git a/base/string_utils.cpp b/base/string_utils.cpp index d1ac84b24a..f2546161fb 100644 --- a/base/string_utils.cpp +++ b/base/string_utils.cpp @@ -237,19 +237,14 @@ bool IsASCIIString(std::string const & str) bool IsASCIIDigit(UniChar c) { return c >= '0' && c <= '9'; } bool IsASCIILatin(UniChar c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); } + bool StartsWith(UniString const & s, UniString const & p) { - if (p.size() > s.size()) - return false; - for (size_t i = 0; i < p.size(); ++i) - { - if (s[i] != p[i]) - return false; - } - return true; + return StartsWith(s.begin(), s.end(), p.begin(), p.end()); } bool StartsWith(std::string const & s1, char const * s2) { return (s1.compare(0, strlen(s2), s2) == 0); } + bool EndsWith(std::string const & s1, char const * s2) { size_t const n = s1.size(); diff --git a/base/string_utils.hpp b/base/string_utils.hpp index af8670c5a4..8fc2228b40 100644 --- a/base/string_utils.hpp +++ b/base/string_utils.hpp @@ -439,6 +439,17 @@ std::string to_string_dac(double d, int dac); inline std::string to_string_with_digits_after_comma(double d, int dac) { return to_string_dac(d, dac); } //@} +template +bool StartsWith(IterT1 beg, IterT1 end, IterT2 begPrefix, IterT2 endPrefix) +{ + while (beg != end && begPrefix != endPrefix && *beg == *begPrefix) + { + ++beg; + ++begPrefix; + } + return begPrefix == endPrefix; +} + bool StartsWith(UniString const & s, UniString const & p); bool StartsWith(std::string const & s1, char const * s2); diff --git a/indexer/categories_holder.cpp b/indexer/categories_holder.cpp index c3871b2628..4cd965ef1a 100644 --- a/indexer/categories_holder.cpp +++ b/indexer/categories_holder.cpp @@ -69,7 +69,7 @@ bool ParseEmoji(CategoriesHolder::Category::Name & name) return false; } - name.m_name = ToUtf8(UniString(1, static_cast(c))); + name.m_name = ToUtf8(UniString(1 /* numChars */, static_cast(c))); if (IsASCIIString(ToUtf8(search::NormalizeAndSimplifyString(name.m_name)))) { @@ -203,6 +203,8 @@ void CategoriesHolder::AddCategory(Category & cat, vector & types) auto const locale = synonym.m_locale; ASSERT_NOT_EQUAL(locale, kUnsupportedLocaleCode, ()); + auto const localePrefix = String(1, static_cast(locale)); + auto const uniName = search::NormalizeAndSimplifyString(synonym.m_name); vector tokens; @@ -213,10 +215,7 @@ void CategoriesHolder::AddCategory(Category & cat, vector & types) if (!ValidKeyToken(token)) continue; for (uint32_t const t : types) - { - auto it = m_name2type.emplace(locale, make_unique()).first; - it->second->Add(token, t); - } + m_name2type.Add(localePrefix + token, t); } } } @@ -243,7 +242,7 @@ bool CategoriesHolder::ValidKeyToken(String const & s) void CategoriesHolder::LoadFromStream(istream & s) { m_type2cat.clear(); - m_name2type.clear(); + m_name2type.Clear(); m_groupTranslations.clear(); State state = EParseTypes; diff --git a/indexer/categories_holder.hpp b/indexer/categories_holder.hpp index 84800bf9b0..7475cdf235 100644 --- a/indexer/categories_holder.hpp +++ b/indexer/categories_holder.hpp @@ -4,6 +4,7 @@ #include "base/stl_helpers.hpp" #include "base/string_utils.hpp" +#include "std/algorithm.hpp" #include "std/deque.hpp" #include "std/iostream.hpp" #include "std/map.hpp" @@ -56,7 +57,8 @@ private: Type2CategoryCont m_type2cat; // Maps locale and category token to the list of corresponding types. - map> m_name2type; + // Locale is treated as a special symbol prepended to the token. + Trie m_name2type; GroupTranslations m_groupTranslations; @@ -109,10 +111,9 @@ public: template void ForEachTypeByName(int8_t locale, String const & name, ToDo && toDo) const { - auto const it = m_name2type.find(locale); - if (it == m_name2type.end()) - return; - it->second->ForEachInNode(name, my::MakeIgnoreFirstArgument(forward(toDo))); + auto const localePrefix = String(1, static_cast(locale)); + m_name2type.ForEachInNode(localePrefix + name, + my::MakeIgnoreFirstArgument(forward(toDo))); } inline GroupTranslations const & GetGroupTranslations() const { return m_groupTranslations; } @@ -125,12 +126,14 @@ public: /// @returns raw classificator type if it's not localized in categories.txt. string GetReadableFeatureType(uint32_t type, int8_t locale) const; + // Exposes the tries that map category tokens to types. + Trie const & GetNameToTypesTrie() const { return m_name2type; } bool IsTypeExist(uint32_t type) const; inline void Swap(CategoriesHolder & r) { m_type2cat.swap(r.m_type2cat); - m_name2type.swap(r.m_name2type); + std::swap(m_name2type, r.m_name2type); } // Converts any language |locale| from UI to the corresponding diff --git a/search/CMakeLists.txt b/search/CMakeLists.txt index 5b2c22322f..870853153e 100644 --- a/search/CMakeLists.txt +++ b/search/CMakeLists.txt @@ -123,6 +123,7 @@ set( token_slice.hpp types_skipper.cpp types_skipper.hpp + utils.cpp utils.hpp viewport_search_callback.cpp viewport_search_callback.hpp diff --git a/search/common.hpp b/search/common.hpp index 2bdc9300a6..abd8a88c71 100644 --- a/search/common.hpp +++ b/search/common.hpp @@ -5,16 +5,4 @@ namespace search /// Upper bound for max count of tokens for indexing and scoring. int constexpr MAX_TOKENS = 32; int constexpr MAX_SUGGESTS_COUNT = 5; - -template -bool StartsWith(IterT1 beg, IterT1 end, IterT2 begPrefix, IterT2 endPrefix) -{ - while (beg != end && begPrefix != endPrefix && *beg == *begPrefix) - { - ++beg; - ++begPrefix; - } - return begPrefix == endPrefix; -} - } // namespace search diff --git a/search/feature_offset_match.hpp b/search/feature_offset_match.hpp index 1c57683304..f05cf5eff9 100644 --- a/search/feature_offset_match.hpp +++ b/search/feature_offset_match.hpp @@ -221,7 +221,7 @@ struct SearchTrieRequest QueryParams::Langs m_langs; }; -// Calls |toDo| for each feature accepted but at least one DFA. +// Calls |toDo| for each feature accepted by at least one DFA. // // *NOTE* |toDo| may be called several times for the same feature. template diff --git a/search/geocoder.cpp b/search/geocoder.cpp index 6a279ed36e..41a65e01ef 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -315,18 +315,6 @@ size_t OrderCountries(m2::RectD const & pivot, vector> & inf auto const sep = stable_partition(infos.begin(), infos.end(), intersects); return distance(infos.begin(), sep); } - -size_t GetMaxErrorsForToken(UniString const & token) -{ - bool const digitsOnly = all_of(token.begin(), token.end(), isdigit); - if (digitsOnly) - return 0; - if (token.size() < 4) - return 0; - if (token.size() < 8) - return 1; - return 2; -} } // namespace // Geocoder::Params -------------------------------------------------------------------------------- diff --git a/search/keyword_matcher.cpp b/search/keyword_matcher.cpp index 3b7ddcee60..073a9f80b8 100644 --- a/search/keyword_matcher.cpp +++ b/search/keyword_matcher.cpp @@ -1,9 +1,10 @@ -#include "keyword_matcher.hpp" +#include "search/keyword_matcher.hpp" #include "indexer/search_delimiters.hpp" #include "indexer/search_string_utils.hpp" #include "base/stl_add.hpp" +#include "base/string_utils.hpp" #include "std/algorithm.hpp" #include "std/sstream.hpp" @@ -67,7 +68,7 @@ KeywordMatcher::ScoreT KeywordMatcher::Score(StringT const * tokens, size_t coun bPrefixMatched = false; for (int j = 0; j < count && !bPrefixMatched; ++j) if (!isNameTokenMatched[j] && - StartsWith(tokens[j].begin(), tokens[j].end(), m_prefix.begin(), m_prefix.end())) + strings::StartsWith(tokens[j].begin(), tokens[j].end(), m_prefix.begin(), m_prefix.end())) { isNameTokenMatched[j] = bPrefixMatched = true; int8_t const tokenMatchDistance = int(m_keywords.size()) - j; diff --git a/search/processor.cpp b/search/processor.cpp index ea1042eb60..1382b9db5c 100644 --- a/search/processor.cpp +++ b/search/processor.cpp @@ -128,8 +128,8 @@ void SendStatistics(SearchParams const & params, m2::RectD const & viewport, Res GetPlatform().GetMarketingService().SendMarketingEvent(marketing::kSearchEmitResultsAndCoords, {}); } -// Removes all full-token stop words from |params|, unless |params| -// consists of all such tokens. +// Removes all full-token stop words from |params|. +// Does nothing if all tokens in |params| are non-prefix stop words. void RemoveStopWordsIfNeeded(QueryParams & params) { size_t numStopWords = 0; @@ -331,6 +331,7 @@ int8_t Processor::GetLanguage(int id) const { return m_ranker.GetLanguage(GetLangIndex(id)); } + m2::PointD Processor::GetPivotPoint() const { bool const viewportSearch = m_mode == Mode::Viewport; @@ -408,9 +409,16 @@ TLocales Processor::GetCategoryLocales() const } template -void Processor::ForEachCategoryType(StringSliceBase const & slice, ToDo && todo) const +void Processor::ForEachCategoryType(StringSliceBase const & slice, ToDo && toDo) const { - ::search::ForEachCategoryType(slice, GetCategoryLocales(), m_categories, forward(todo)); + ::search::ForEachCategoryType(slice, GetCategoryLocales(), m_categories, forward(toDo)); +} + +template +void Processor::ForEachCategoryTypeFuzzy(StringSliceBase const & slice, ToDo && toDo) const +{ + ::search::ForEachCategoryTypeFuzzy(slice, GetCategoryLocales(), m_categories, + forward(toDo)); } void Processor::Search(SearchParams const & params, m2::RectD const & viewport) @@ -671,11 +679,9 @@ void Processor::InitParams(QueryParams & params) } } }; - ForEachCategoryType(QuerySliceOnRawStrings(m_tokens, m_prefix), addSyms); - auto & langs = params.GetLangs(); - for (int i = 0; i < LANG_COUNT; ++i) - langs.Insert(GetLanguage(i)); + // todo(@m, @y). Shall we match prefix tokens for categories? + ForEachCategoryTypeFuzzy(QuerySliceOnRawStrings(m_tokens, m_prefix), addSyms); RemoveStopWordsIfNeeded(params); @@ -687,6 +693,12 @@ void Processor::InitParams(QueryParams & params) if (IsStreetSynonym(token.m_original)) params.GetTypeIndices(i).clear(); } + + for (size_t i = 0; i < params.GetNumTokens(); ++i) + my::SortUnique(params.GetTypeIndices(i)); + + for (int i = 0; i < LANG_COUNT; ++i) + params.GetLangs().Insert(GetLanguage(i)); } void Processor::InitGeocoder(Geocoder::Params & params) diff --git a/search/processor.hpp b/search/processor.hpp index 5e784954fd..3ad34f7ad1 100644 --- a/search/processor.hpp +++ b/search/processor.hpp @@ -139,7 +139,10 @@ protected: TLocales GetCategoryLocales() const; template - void ForEachCategoryType(StringSliceBase const & slice, ToDo && todo) const; + void ForEachCategoryType(StringSliceBase const & slice, ToDo && toDo) const; + + template + void ForEachCategoryTypeFuzzy(StringSliceBase const & slice, ToDo && toDo) const; m2::PointD GetPivotPoint() const; m2::RectD GetPivotRect() const; diff --git a/search/ranker.cpp b/search/ranker.cpp index 84b4048f2c..b2e708a162 100644 --- a/search/ranker.cpp +++ b/search/ranker.cpp @@ -8,6 +8,7 @@ #include "indexer/feature_algo.hpp" #include "base/logging.hpp" +#include "base/string_utils.hpp" #include "std/algorithm.hpp" #include "std/iterator.hpp" @@ -433,7 +434,7 @@ void Ranker::MatchForSuggestions(strings::UniString const & token, int8_t locale if ((suggest.m_prefixLength <= token.size()) && (token != s) && // do not push suggestion if it already equals to token (suggest.m_locale == locale) && // push suggestions only for needed language - StartsWith(s.begin(), s.end(), token.begin(), token.end())) + strings::StartsWith(s.begin(), s.end(), token.begin(), token.end())) { string const utf8Str = strings::ToUtf8(s); Result r(utf8Str, prologue + utf8Str + " "); diff --git a/search/search.pro b/search/search.pro index abb050f3c0..2c11aa354f 100644 --- a/search/search.pro +++ b/search/search.pro @@ -135,4 +135,5 @@ SOURCES += \ streets_matcher.cpp \ token_slice.cpp \ types_skipper.cpp \ + utils.cpp \ viewport_search_callback.cpp diff --git a/search/search_integration_tests/processor_test.cpp b/search/search_integration_tests/processor_test.cpp index ef8fdc3bc4..63aa3b5342 100644 --- a/search/search_integration_tests/processor_test.cpp +++ b/search/search_integration_tests/processor_test.cpp @@ -754,6 +754,9 @@ UNIT_CLASS_TEST(ProcessorTest, FuzzyMatch) TestPOI bar(m2::PointD(0, 0), "Черчилль", "ru"); bar.SetTypes({{"amenity", "pub"}}); + TestPOI metro(m2::PointD(5.0, 5.0), "Liceu", "es"); + metro.SetTypes({{"railway", "subway_entrance"}}); + BuildWorld([&](TestMwmBuilder & builder) { builder.Add(country); builder.Add(city); @@ -762,6 +765,7 @@ UNIT_CLASS_TEST(ProcessorTest, FuzzyMatch) auto id = BuildCountry(countryName, [&](TestMwmBuilder & builder) { builder.Add(street); builder.Add(bar); + builder.Add(metro); }); SetViewport(m2::RectD(m2::PointD(-1.0, -1.0), m2::PointD(1.0, 1.0))); @@ -778,6 +782,17 @@ UNIT_CLASS_TEST(ProcessorTest, FuzzyMatch) TEST(ResultsMatch("масква ленинргадский чирчиль", "ru", TRules{}), ()); TEST(ResultsMatch("моксва ленинргадский черчиль", "ru", rules), ()); + + TEST(ResultsMatch("food", "ru", rules), ()); + TEST(ResultsMatch("foood", "ru", rules), ()); + TEST(ResultsMatch("fod", "ru", TRules{}), ()); + + TRules rulesMetro = {ExactMatch(id, metro)}; + TEST(ResultsMatch("transporte", "es", rulesMetro), ()); + TEST(ResultsMatch("transport", "es", rulesMetro), ()); + TEST(ResultsMatch("transpurt", "en", rulesMetro), ()); + TEST(ResultsMatch("transpurrt", "es", rulesMetro), ()); + TEST(ResultsMatch("transportation", "en", TRules{}), ()); } } diff --git a/search/utils.hpp b/search/utils.hpp index c36f74736a..e64fc20a5c 100644 --- a/search/utils.hpp +++ b/search/utils.hpp @@ -3,14 +3,72 @@ #include "search/token_slice.hpp" #include "indexer/categories_holder.hpp" +#include "indexer/search_delimiters.hpp" +#include "indexer/search_string_utils.hpp" #include "base/buffer_vector.hpp" +#include "base/levenshtein_dfa.hpp" +#include "base/stl_helpers.hpp" #include "base/string_utils.hpp" +#include +#include +#include +#include +#include + namespace search { +// todo(@m, @y). Unite with the similar function in search/feature_offset_match.hpp. +template +bool MatchInTrie(Trie const & /* trie */, typename Trie::Iterator const & trieStartIt, + DFA const & dfa, ToDo && toDo) +{ + using Char = typename Trie::Char; + using TrieIt = typename Trie::Iterator; + using DFAIt = typename DFA::Iterator; + using State = pair; + + std::queue q; + + { + auto it = dfa.Begin(); + if (it.Rejects()) + return false; + q.emplace(trieStartIt, it); + } + + bool found = false; + + while (!q.empty()) + { + auto const p = q.front(); + q.pop(); + + auto const & trieIt = p.first; + auto const & dfaIt = p.second; + + if (dfaIt.Accepts()) + { + trieIt.ForEachInNode(toDo); + found = true; + } + + trieIt.ForEachMove([&](Char const & c, TrieIt const & nextTrieIt) { + auto nextDfaIt = dfaIt; + nextDfaIt.Move(c); + if (!nextDfaIt.Rejects()) + q.emplace(nextTrieIt, nextDfaIt); + }); + } + + return found; +} + using TLocales = buffer_vector; +size_t GetMaxErrorsForToken(strings::UniString const & token); + template void ForEachCategoryType(StringSliceBase const & slice, TLocales const & locales, CategoriesHolder const & categories, ToDo && todo) @@ -18,16 +76,50 @@ void ForEachCategoryType(StringSliceBase const & slice, TLocales const & locales for (size_t i = 0; i < slice.Size(); ++i) { auto const & token = slice.Get(i); - for (size_t j = 0; j < locales.size(); ++j) - categories.ForEachTypeByName(locales[j], token, bind(todo, i, _1)); + for (int8_t const locale : locales) + categories.ForEachTypeByName(locale, token, std::bind(todo, i, std::placeholders::_1)); // Special case processing of 2 codepoints emoji (e.g. black guy on a bike). // Only emoji synonyms can have one codepoint. if (token.size() > 1) { categories.ForEachTypeByName(CategoriesHolder::kEnglishCode, strings::UniString(1, token[0]), - bind(todo, i, _1)); + std::bind(todo, i, std::placeholders::_1)); } } } + +// Unlike ForEachCategoryType which extracts types by a token +// from |slice| by exactly matching it to a token in the name +// of a category, in the worst case this function has to loop through the tokens +// in all category synonyms in all |locales| in order to find a token +// whose edit distance is close enough to the required token from |slice|. +template +void ForEachCategoryTypeFuzzy(StringSliceBase const & slice, TLocales const & locales, + CategoriesHolder const & categories, ToDo && todo) +{ + using Trie = my::MemTrie; + + auto const & trie = categories.GetNameToTypesTrie(); + auto const & trieRootIt = trie.GetRootIterator(); + vector sortedLocales(locales.begin(), locales.end()); + my::SortUnique(sortedLocales); + + for (size_t i = 0; i < slice.Size(); ++i) + { + auto const & token = slice.Get(i); + // todo(@m, @y). We build dfa twice for each token: here and in geocoder.cpp. + // A possible optimization is to build each dfa once and save it. Note that + // dfas for the prefix tokens differ, i.e. we ignore slice.IsPrefix(i) here. + strings::LevenshteinDFA const dfa(token, 1 /* prefixCharsToKeep */, GetMaxErrorsForToken(token)); + + trieRootIt.ForEachMove([&](Trie::Char const & c, Trie::Iterator const & moveIt) { + if (std::binary_search(sortedLocales.begin(), sortedLocales.end(), static_cast(c))) + { + MatchInTrie(trie /* passed to infer the iterator's type */, moveIt, dfa, + std::bind(todo, i, std::placeholders::_1)); + } + }); + } +} } // namespace search