diff --git a/base/base_tests/mem_trie_test.cpp b/base/base_tests/mem_trie_test.cpp index f6e3f5f6b8..892b242d10 100644 --- a/base/base_tests/mem_trie_test.cpp +++ b/base/base_tests/mem_trie_test.cpp @@ -3,17 +3,19 @@ #include "base/mem_trie.hpp" #include +#include #include #include #include +using namespace base; using namespace std; namespace { using Key = string; using Value = int; -using Trie = my::MemTrie>; +using Trie = MemTrie>; using Data = vector>; Data GetTrieContents(Trie const & trie) @@ -27,55 +29,84 @@ Data GetTrieContents(Trie const & trie) class MemTrieTest { public: - Data GetActualContents() const { return ::GetTrieContents(m_trie); } + Data GetActualContents() const + { + Data data; + m_trie.ForEachInTrie([&data](Key const & k, Value const & v) { data.emplace_back(k, v); }); + sort(data.begin(), data.end()); + return data; + } - Data GetExpectedContents() const { return m_data; } + Data GetExpectedContents() const { return {m_data.cbegin(), m_data.cend()}; } + + vector GetValuesByKey(Key const & key) const + { + vector values; + m_trie.ForEachInNode(key, [&](Value const & value) { values.push_back(value); }); + sort(values.begin(), values.end()); + return values; + } + + Data GetContentsByPrefix(Key const & prefix) const + { + Data data; + m_trie.ForEachInSubtree(prefix, + [&data](Key const & k, Value const & v) { data.emplace_back(k, v); }); + sort(data.begin(), data.end()); + return data; + } + + bool HasKey(Key const & key) const { return m_trie.HasKey(key); } + bool HasPrefix(Key const & prefix) const { return m_trie.HasPrefix(prefix); } size_t GetNumNodes() const { return m_trie.GetNumNodes(); } void Add(Key const & key, Value const & value) { m_trie.Add(key, value); - - auto const kv = make_pair(key, value); - auto it = lower_bound(m_data.begin(), m_data.end(), kv); - m_data.insert(it, kv); + m_data.insert(make_pair(key, value)); } void Erase(Key const & key, Value const & value) { m_trie.Erase(key, value); - - auto const kv = make_pair(key, value); - auto it = lower_bound(m_data.begin(), m_data.end(), kv); - if (it != m_data.end() && *it == kv) - m_data.erase(it); + m_data.erase(make_pair(key, value)); } protected: Trie m_trie; - Data m_data; + multiset> m_data; }; UNIT_CLASS_TEST(MemTrieTest, Basic) { TEST_EQUAL(GetNumNodes(), 1, ()); + TEST(!HasKey(""), ()); + TEST(!HasKey("a"), ()); + TEST(!HasPrefix(""), ()); + TEST(!HasPrefix("a"), ()); + Data const data = {{"roger", 3}, {"amy", 1}, {"emma", 1}, {"ann", 1}, {"rob", 1}, {"roger", 2}, {"", 0}, {"roger", 1}}; for (auto const & kv : data) Add(kv.first, kv.second); - TEST_EQUAL(GetNumNodes(), 16, ()); + + TEST_EQUAL(GetNumNodes(), 8, ()); + TEST(HasKey(""), ()); + TEST(!HasKey("a"), ()); + TEST(HasPrefix(""), ()); + TEST(HasPrefix("a"), ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); - TEST_EQUAL(GetNumNodes(), 16, ()); + TEST_EQUAL(GetNumNodes(), 8, ()); Trie newTrie(move(m_trie)); TEST_EQUAL(m_trie.GetNumNodes(), 1, ()); TEST(GetTrieContents(m_trie).empty(), ()); - TEST_EQUAL(newTrie.GetNumNodes(), 16, ()); + TEST_EQUAL(newTrie.GetNumNodes(), 8, ()); TEST_EQUAL(GetTrieContents(newTrie), GetExpectedContents(), ()); } @@ -83,32 +114,78 @@ UNIT_CLASS_TEST(MemTrieTest, KeysRemoval) { TEST_EQUAL(GetNumNodes(), 1, ()); + TEST(!HasKey("r"), ()); + TEST(!HasPrefix("r"), ()); + TEST(!HasKey("ro"), ()); + TEST(!HasPrefix("ro"), ()); + Data const data = {{"bobby", 1}, {"robby", 2}, {"rob", 3}, {"r", 4}, {"robert", 5}, {"bob", 6}}; for (auto const & kv : data) Add(kv.first, kv.second); - TEST_EQUAL(GetNumNodes(), 14, ()); + TEST(HasKey("r"), ()); + TEST(HasPrefix("r"), ()); + TEST(!HasKey("ro"), ()); + TEST(HasPrefix("ro"), ()); + + TEST_EQUAL(GetNumNodes(), 7, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); Erase("r", 3); - TEST_EQUAL(GetNumNodes(), 14, ()); + TEST_EQUAL(GetNumNodes(), 7, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); Erase("r", 4); - TEST_EQUAL(GetNumNodes(), 14, ()); + TEST_EQUAL(GetNumNodes(), 6, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); Erase("robert", 5); - TEST_EQUAL(GetNumNodes(), 11, ()); + TEST_EQUAL(GetNumNodes(), 5, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); Erase("rob", 3); - TEST_EQUAL(GetNumNodes(), 11, ()); + TEST_EQUAL(GetNumNodes(), 4, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); Erase("robby", 2); - TEST_EQUAL(GetNumNodes(), 6, ()); + TEST_EQUAL(GetNumNodes(), 3, ()); TEST_EQUAL(GetExpectedContents(), GetActualContents(), ()); } + +UNIT_CLASS_TEST(MemTrieTest, ForEachInNode) +{ + Add("abracadabra", 0); + Add("abra", 1); + Add("abra", 2); + Add("abrau", 3); + + TEST_EQUAL(GetValuesByKey("a"), vector{}, ()); + TEST_EQUAL(GetValuesByKey("abrac"), vector{}, ()); + TEST_EQUAL(GetValuesByKey("abracadabr"), vector{}, ()); + TEST_EQUAL(GetValuesByKey("void"), vector{}, ()); + + TEST_EQUAL(GetValuesByKey("abra"), vector({1, 2}), ()); + TEST_EQUAL(GetValuesByKey("abracadabra"), vector({0}), ()); + TEST_EQUAL(GetValuesByKey("abrau"), vector({3}), ()); +} + +UNIT_CLASS_TEST(MemTrieTest, ForEachInSubtree) +{ + Add("abracadabra", 0); + Add("abra", 1); + Add("abra", 2); + Add("abrau", 3); + + Data const all = {{"abra", 1}, {"abra", 2}, {"abracadabra", 0}, {"abrau", 3}}; + + TEST_EQUAL(GetContentsByPrefix(""), all, ()); + TEST_EQUAL(GetContentsByPrefix("a"), all, ()); + TEST_EQUAL(GetContentsByPrefix("abra"), all, ()); + TEST_EQUAL(GetContentsByPrefix("abracadabr"), Data({{"abracadabra", 0}}), ()); + TEST_EQUAL(GetContentsByPrefix("abracadabra"), Data({{"abracadabra", 0}}), ()); + TEST_EQUAL(GetContentsByPrefix("void"), Data{}, ()); + TEST_EQUAL(GetContentsByPrefix("abra"), all, ()); + TEST_EQUAL(GetContentsByPrefix("abrau"), Data({{"abrau", 3}}), ()); +} } // namespace diff --git a/base/mem_trie.hpp b/base/mem_trie.hpp index 723b5f2b05..234a4e0cd0 100644 --- a/base/mem_trie.hpp +++ b/base/mem_trie.hpp @@ -1,18 +1,19 @@ #pragma once #include "base/assert.hpp" +#include "base/buffer_vector.hpp" #include "base/macros.hpp" #include "base/stl_add.hpp" #include #include -#include #include #include #include +#include #include -namespace my +namespace base { template class MapMoves @@ -25,6 +26,13 @@ public: toDo(subtree.first, *subtree.second); } + template + void ForEach(ToDo && toDo) + { + for (auto const & subtree : m_subtrees) + toDo(subtree.first, *subtree.second); + } + Subtree * GetSubtree(Char const & c) const { auto const it = m_subtrees.find(c); @@ -48,20 +56,83 @@ public: return *node; } + void AddSubtree(Char const & c, std::unique_ptr subtree) + { + ASSERT(!GetSubtree(c), ()); + m_subtrees.emplace(c, move(subtree)); + } + void EraseSubtree(Char const & c) { m_subtrees.erase(c); } + size_t Size() const { return m_subtrees.size(); } bool Empty() const { return m_subtrees.empty(); } void Clear() { m_subtrees.clear(); } + void Swap(MapMoves & rhs) { m_subtrees.swap(rhs.m_subtrees); } + private: std::map> m_subtrees; }; -template +template +class VectorMoves +{ +public: + template + void ForEach(ToDo && toDo) const + { + for (auto const & subtree : m_subtrees) + toDo(subtree.first, *subtree.second); + } + + Subtree * GetSubtree(Char const & c) const + { + for (auto const & subtree : m_subtrees) + { + if (subtree.first == c) + return subtree.second.get(); + } + return nullptr; + } + + Subtree & GetOrCreateSubtree(Char const & c, bool & created) + { + for (size_t i = 0; i < m_subtrees.size(); ++i) + { + if (m_subtrees[i].first == c) + { + created = false; + return *m_subtrees[i].second; + } + } + + created = true; + m_subtrees.emplace_back(c, my::make_unique()); + return *m_subtrees.back().second; + } + + void AddSubtree(Char const & c, std::unique_ptr subtree) + { + ASSERT(!GetSubtree(c), ()); + m_subtrees.emplace_back(c, std::move(subtree)); + } + + bool Empty() const { return m_subtrees.empty(); } + + void Clear() { m_subtrees.clear(); } + + void Swap(VectorMoves & rhs) { m_subtrees.swap(rhs.m_subtrees); } + +private: + std::vector>> m_subtrees; +}; + +template struct VectorValues { - using value_type = Value; + using value_type = V; + using Value = V; template void Add(Args &&... args) @@ -79,13 +150,16 @@ struct VectorValues template void ForEach(ToDo && toDo) const { - std::for_each(m_values.begin(), m_values.end(), std::forward(toDo)); + for (auto const & value : m_values) + toDo(value); } bool Empty() const { return m_values.empty(); } void Clear() { m_values.clear(); } + void Swap(VectorValues & rhs) { m_values.swap(rhs.m_values); } + std::vector m_values; }; @@ -107,7 +181,6 @@ public: MemTrie & operator=(MemTrie && rhs) { m_root = std::move(rhs.m_root); - m_numNodes = rhs.m_numNodes; rhs.Clear(); return *this; } @@ -137,6 +210,7 @@ public: m_node.m_values.ForEach(std::forward(toDo)); } + String GetLabel() const { return m_node.m_edge.template As(); } ValuesHolder const & GetValues() const { return m_node.m_values; } private: @@ -147,20 +221,57 @@ public: template void Add(String const & key, Args &&... args) { - auto * cur = &m_root; - for (auto const & c : key) + auto * curr = &m_root; + + auto it = key.begin(); + while (it != key.end()) { bool created; - cur = &cur->GetMove(c, created); + curr = &curr->GetOrCreateMove(*it, created); + + auto & edge = curr->m_edge; + if (created) - ++m_numNodes; + { + edge.Assign(it, key.end()); + it = key.end(); + continue; + } + + ASSERT(!edge.Empty(), ()); + ASSERT_EQUAL(edge[0], *it, ()); + + size_t i = 0; + SkipEqual(edge, key.end(), i, it); + + if (i == edge.Size()) + { + // We may directly add value to the |curr| values, when edge + // equals to the rest of the |key|. Otherwise we need to jump + // to the next iteration of the loop and continue traversal of + // the trie. + continue; + } + + // We need to split the edge to |curr|. + auto node = my::make_unique(); + + ASSERT_LESS(i, edge.Size(), ()); + node->m_edge = edge.Drop(i); + + ASSERT(!edge.Empty(), ()); + auto const next = edge[0]; + + node->Swap(*curr); + curr->AddChild(next, std::move(node)); } - cur->Add(std::forward(args)...); + + curr->AddValue(std::forward(args)...); } void Erase(String const & key, Value const & value) { - return Erase(m_root, 0 /* level */, key, value); + Erase(m_root, key.begin(), key.end(), value); } // Traverses all key-value pairs in the trie and calls |toDo| on each of them. @@ -177,31 +288,106 @@ public: template void ForEachInNode(String const & prefix, ToDo && toDo) const { - if (auto const * root = MoveTo(prefix)) - ForEachInNode(*root, prefix, std::forward(toDo)); + MoveTo(prefix, true /* fullMatch */, + [&](Node const & node, Edge const & /* edge */, size_t /* offset */) { + node.m_values.ForEach(std::forward(toDo)); + }); } // Calls |toDo| for each key-value pair in a subtree that is // reachable by |prefix| from the trie root. Does nothing if such // subtree does not exist. template - void ForEachInSubtree(String prefix, ToDo && toDo) const + void ForEachInSubtree(String const & prefix, ToDo && toDo) const { - if (auto const * root = MoveTo(prefix)) - ForEachInSubtree(*root, prefix, std::forward(toDo)); + MoveTo(prefix, false /* fullMatch */, [&](Node const & node, Edge const & edge, size_t offset) { + String p = prefix; + for (; offset < edge.Size(); ++offset) + p.push_back(edge[offset]); + ForEachInSubtree(node, p, std::forward(toDo)); + }); } - void Clear() + bool HasKey(String const & key) const { - m_root.Clear(); - m_numNodes = 1; + bool exists = false; + MoveTo(key, true /* fullMatch */, + [&](Node const & node, Edge const & /* edge */, size_t /* offset */) { + exists = !node.m_values.Empty(); + }); + return exists; } - size_t GetNumNodes() const { return m_numNodes; } + bool HasPrefix(String const & prefix) const + { + bool exists = false; + MoveTo(prefix, false /* fullMatch */, [&](Node const & node, Edge const & /* edge */, + size_t /* offset */) { exists = !node.Empty(); }); + return exists; + } + + void Clear() { m_root.Clear(); } + + size_t GetNumNodes() const { return m_root.GetNumNodes(); } + Iterator GetRootIterator() const { return Iterator(m_root); } + Node const & GetRoot() const { return m_root; } private: + class Edge + { + public: + Edge() = default; + + template + Edge(It begin, It end) + { + Assign(begin, end); + } + + template + void Assign(It begin, It end) + { + m_label.assign(begin, end); + std::reverse(m_label.begin(), m_label.end()); + } + + Edge Drop(size_t n) + { + ASSERT_LESS_OR_EQUAL(n, Size(), ()); + + Edge prefix(m_label.rbegin(), m_label.rbegin() + n); + m_label.erase(m_label.begin() + Size() - n, m_label.end()); + return prefix; + } + + void Prepend(Edge const & prefix) + { + m_label.insert(m_label.end(), prefix.m_label.begin(), prefix.m_label.end()); + } + + Char operator[](size_t i) const + { + ASSERT_LESS(i, Size(), ()); + return *(m_label.rbegin() + i); + } + + size_t Size() const { return m_label.size(); } + + bool Empty() const { return Size() == 0; } + + void Swap(Edge & rhs) { m_label.swap(rhs.m_label); } + + template + Sequence As() const { return {m_label.rbegin(), m_label.rend()}; } + + friend std::string DebugPrint(Edge const & edge) { return edge.template As(); } + + private: + std::vector m_label; + }; + struct Node { Node() = default; @@ -209,17 +395,22 @@ private: Node & operator=(Node && /* rhs */) = default; - Node & GetMove(Char const & c, bool & created) + Node & GetOrCreateMove(Char const & c, bool & created) { return m_moves.GetOrCreateSubtree(c, created); } Node * GetMove(Char const & c) const { return m_moves.GetSubtree(c); } + void AddChild(Char const & c, std::unique_ptr node) + { + m_moves.AddSubtree(c, std::move(node)); + } + void EraseMove(Char const & c) { m_moves.EraseSubtree(c); } template - void Add(Args &&... args) + void AddValue(Args &&... args) { m_values.Add(std::forward(args)...); } @@ -237,51 +428,103 @@ private: m_values.Clear(); } + size_t GetNumNodes() const + { + size_t size = 1; + m_moves.ForEach( + [&size](Char /* c */, Node const & child) { size += child.GetNumNodes(); }); + return size; + } + + void Swap(Node & rhs) + { + m_edge.Swap(rhs.m_edge); + m_moves.Swap(rhs.m_moves); + m_values.Swap(rhs.m_values); + } + + Edge m_edge; Moves m_moves; ValuesHolder m_values; DISALLOW_COPY(Node); }; - Node const * MoveTo(String const & key) const + template + void MoveTo(String const & prefix, bool fullMatch, Fn && fn) const { auto const * cur = &m_root; - for (auto const & c : key) + + auto it = prefix.begin(); + + if (it == prefix.end()) { - cur = cur->GetMove(c); - if (!cur) - break; + fn(*cur, cur->m_edge, 0 /* offset */); + return; + } + + while (true) + { + ASSERT(it != prefix.end(), ()); + + cur = cur->GetMove(*it); + if (!cur) + return; + + auto const & edge = cur->m_edge; + size_t i = 0; + SkipEqual(edge, prefix.end(), i, it); + + if (i < edge.Size()) + { + if (it != prefix.end() || fullMatch) + return; + } + + ASSERT(i == edge.Size() || (it == prefix.end() && !fullMatch), ()); + + if (it == prefix.end()) + { + fn(*cur, edge, i /* offset */); + return; + } + + ASSERT(it != prefix.end() && i == edge.Size(), ()); } - return cur; } - void Erase(Node & root, size_t level, String const & key, Value const & value) + template + void Erase(Node & root, It cur, It end, Value const & value) { - if (level == key.size()) + if (cur == end) { root.EraseValue(value); + if (root.m_values.Empty() && root.m_moves.Size() == 1) + { + Node child; + root.m_moves.ForEach([&](Char const & /* c */, Node & node) { child.Swap(node); }); + child.m_edge.Prepend(root.m_edge); + root.Swap(child); + } return; } - ASSERT_LESS(level, key.size(), ()); - auto * child = root.GetMove(key[level]); + auto const symbol = *cur; + + auto * child = root.GetMove(symbol); if (!child) return; - Erase(*child, level + 1, key, value); - if (child->Empty()) - { - root.EraseMove(key[level]); - --m_numNodes; - } - } - // Calls |toDo| for each key-value pair in a |node| that is - // reachable by |prefix| from the trie root. - template - void ForEachInNode(Node const & node, String const & prefix, ToDo && toDo) const - { - node.m_values.ForEach( - std::bind(std::forward(toDo), std::ref(prefix), std::placeholders::_1)); + auto const & edge = child->m_edge; + size_t i = 0; + SkipEqual(edge, end, i, cur); + + if (i == edge.Size()) + { + Erase(*child, cur, end, value); + if (child->Empty()) + root.EraseMove(symbol); + } } // Calls |toDo| for each key-value pair in subtree where |node| is a @@ -290,19 +533,30 @@ private: template void ForEachInSubtree(Node const & node, String & prefix, ToDo && toDo) const { - ForEachInNode(node, prefix, toDo); + node.m_values.ForEach([&prefix, &toDo](Value const & value) { toDo(prefix, value); }); node.m_moves.ForEach([&](Char c, Node const & node) { - prefix.push_back(c); + auto const size = prefix.size(); + auto const edge = node.m_edge.template As(); + prefix.insert(prefix.end(), edge.begin(), edge.end()); ForEachInSubtree(node, prefix, toDo); - prefix.pop_back(); + prefix.resize(size); }); } + template + void SkipEqual(Edge const & edge, It end, size_t & i, It & cur) const + { + while (i < edge.Size() && cur != end && edge[i] == *cur) + { + ++i; + ++cur; + } + } + Node m_root; - size_t m_numNodes = 1; DISALLOW_COPY(MemTrie); }; -} // namespace my +} // namespace base diff --git a/indexer/categories_holder.hpp b/indexer/categories_holder.hpp index 4240e64f6e..9d3562f6a6 100644 --- a/indexer/categories_holder.hpp +++ b/indexer/categories_holder.hpp @@ -52,7 +52,7 @@ public: private: using String = strings::UniString; using Type2CategoryCont = multimap>; - using Trie = my::MemTrie>; + using Trie = base::MemTrie>; Type2CategoryCont m_type2cat; @@ -113,8 +113,7 @@ public: void ForEachTypeByName(int8_t locale, String const & name, ToDo && toDo) const { auto const localePrefix = String(1, static_cast(locale)); - m_name2type.ForEachInNode(localePrefix + name, - my::MakeIgnoreFirstArgument(forward(toDo))); + m_name2type.ForEachInNode(localePrefix + name, forward(toDo)); } inline GroupTranslations const & GetGroupTranslations() const { return m_groupTranslations; } diff --git a/indexer/categories_index.cpp b/indexer/categories_index.cpp index c67b8416cc..7b5632bbee 100644 --- a/indexer/categories_index.cpp +++ b/indexer/categories_index.cpp @@ -12,7 +12,7 @@ namespace { -void AddAllNonemptySubstrings(my::MemTrie> & trie, +void AddAllNonemptySubstrings(base::MemTrie> & trie, string const & s, uint32_t value) { ASSERT(!s.empty(), ()); @@ -37,7 +37,7 @@ void ForEachToken(string const & s, TF && fn) fn(strings::ToUtf8(token)); } -void TokenizeAndAddAllSubstrings(my::MemTrie> & trie, +void TokenizeAndAddAllSubstrings(base::MemTrie> & trie, string const & s, uint32_t value) { auto fn = [&](string const & token) diff --git a/indexer/categories_index.hpp b/indexer/categories_index.hpp index 498fbfb489..3c3aad877f 100644 --- a/indexer/categories_index.hpp +++ b/indexer/categories_index.hpp @@ -69,6 +69,6 @@ private: // here because this class may be used from Objectvie-C // so a default constructor is needed. CategoriesHolder const * m_catHolder = nullptr; - my::MemTrie> m_trie; + base::MemTrie> m_trie; }; } // namespace indexer diff --git a/indexer/indexer_tests/categories_test.cpp b/indexer/indexer_tests/categories_test.cpp index aec616c418..62cfaedfbb 100644 --- a/indexer/indexer_tests/categories_test.cpp +++ b/indexer/indexer_tests/categories_test.cpp @@ -404,6 +404,7 @@ UNIT_TEST(CategoriesIndex_AllCategories) index.AddAllCategoriesInAllLangs(); // Consider deprecating this method if this bound rises as high as a million. + LOG(LINFO, ("Num of nodes in the CategoriesIndex trie:", index.GetNumTrieNodes())); TEST_LESS(index.GetNumTrieNodes(), 400000, ()); } #endif diff --git a/indexer/search_string_utils.cpp b/indexer/search_string_utils.cpp index 37e75e4052..d1ff0ab62c 100644 --- a/indexer/search_string_utils.cpp +++ b/indexer/search_string_utils.cpp @@ -7,6 +7,8 @@ #include "3party/utfcpp/source/utf8/unchecked.h" +#include + using namespace std; using namespace strings; @@ -145,7 +147,13 @@ public: { using value_type = bool; - void Add(bool value) { m_value = m_value || value; } + BooleanSum() { Clear(); } + + void Add(bool value) + { + m_value = m_value || value; + m_empty = false; + } template void ForEach(ToDo && toDo) const @@ -153,52 +161,22 @@ public: toDo(m_value); } - void Clear() { m_value = false; } - - bool m_value = false; - }; - - template - class Moves - { - public: - template - void ForEach(ToDo && toDo) const + void Clear() { - for (auto const & subtree : m_subtrees) - toDo(subtree.first, *subtree.second); + m_value = false; + m_empty = true; } - Subtree * GetSubtree(Char const & c) const + bool Empty() const { return m_empty; } + + void Swap(BooleanSum & rhs) { - for (auto const & subtree : m_subtrees) - { - if (subtree.first == c) - return subtree.second.get(); - } - return nullptr; + swap(m_value, rhs.m_value); + swap(m_empty, rhs.m_empty); } - Subtree & GetOrCreateSubtree(Char const & c, bool & created) - { - for (size_t i = 0; i < m_subtrees.size(); ++i) - { - if (m_subtrees[i].first == c) - { - created = false; - return *m_subtrees[i].second; - } - } - - created = true; - m_subtrees.emplace_back(c, make_unique()); - return *m_subtrees.back().second; - } - - void Clear() { m_subtrees.clear(); } - - private: - buffer_vector>, 8> m_subtrees; + bool m_value; + bool m_empty; }; StreetsSynonymsHolder() @@ -273,28 +251,11 @@ public: } } - bool MatchPrefix(UniString const & s) const - { - bool found = false; - m_strings.ForEachInNode(s, [&](UniString const & prefix, bool /* value */) { - ASSERT_EQUAL(s, prefix, ()); - found = true; - }); - return found; - } - - bool FullMatch(UniString const & s) const - { - bool found = false; - m_strings.ForEachInNode(s, [&](UniString const & prefix, bool value) { - ASSERT_EQUAL(s, prefix, ()); - found = value; - }); - return found; - } + bool MatchPrefix(UniString const & s) const { return m_strings.HasPrefix(s); } + bool FullMatch(UniString const & s) const { return m_strings.HasKey(s); } private: - my::MemTrie m_strings; + base::MemTrie m_strings; }; StreetsSynonymsHolder g_streets; diff --git a/indexer/trie.hpp b/indexer/trie.hpp index 26a57428da..eb67580b93 100644 --- a/indexer/trie.hpp +++ b/indexer/trie.hpp @@ -3,9 +3,12 @@ #include "base/assert.hpp" #include "base/base.hpp" #include "base/buffer_vector.hpp" +#include "base/mem_trie.hpp" +#include "base/stl_add.hpp" #include #include +#include namespace trie { @@ -25,6 +28,12 @@ struct Iterator struct Edge { using EdgeLabel = buffer_vector; + + Edge() = default; + + template + Edge(It begin, It end): m_label(begin, end) {} + EdgeLabel m_label; }; @@ -37,6 +46,40 @@ struct Iterator List m_values; }; +template +class MemTrieIterator final : public trie::Iterator +{ +public: + using Base = trie::Iterator; + + using Char = typename String::value_type; + using InnerIterator = typename base::MemTrie::Iterator; + + explicit MemTrieIterator(InnerIterator const & inIt) + { + Base::m_values = inIt.GetValues(); + inIt.ForEachMove([&](Char c, InnerIterator it) { + auto const label = it.GetLabel(); + Base::m_edges.emplace_back(label.begin(), label.end()); + m_moves.push_back(it); + }); + } + + ~MemTrieIterator() override = default; + + // Iterator overrides: + std::unique_ptr Clone() const override { return my::make_unique(*this); } + + std::unique_ptr GoToEdge(size_t i) const override + { + ASSERT_LESS(i, m_moves.size(), ()); + return my::make_unique(m_moves[i]); + } + +private: + std::vector m_moves; +}; + template void ForEachRef(Iterator const & it, ToDo && toDo, String const & s) { diff --git a/search/base/inverted_list.hpp b/search/base/inverted_list.hpp index 81581fd224..fb10f79e53 100644 --- a/search/base/inverted_list.hpp +++ b/search/base/inverted_list.hpp @@ -50,6 +50,8 @@ public: void Clear() { m_ids.clear(); } + void Swap(InvertedList & rhs) { m_ids.swap(rhs.m_ids); } + private: std::vector m_ids; }; diff --git a/search/base/mem_search_index.hpp b/search/base/mem_search_index.hpp index dfbc53891c..9821c5f546 100644 --- a/search/base/mem_search_index.hpp +++ b/search/base/mem_search_index.hpp @@ -27,38 +27,8 @@ public: using Token = strings::UniString; using Char = Token::value_type; using List = InvertedList; - using Trie = my::MemTrie; - - class Iterator : public trie::Iterator - { - public: - using Base = trie::Iterator; - using InnerIterator = typename Trie::Iterator; - - explicit Iterator(InnerIterator const & innerIt) - { - Base::m_values = innerIt.GetValues(); - innerIt.ForEachMove([&](Char c, InnerIterator it) { - Base::m_edges.emplace_back(); - Base::m_edges.back().m_label.push_back(c); - m_moves.push_back(it); - }); - } - - ~Iterator() override = default; - - // trie::Iterator overrides: - std::unique_ptr Clone() const override { return my::make_unique(*this); } - - std::unique_ptr GoToEdge(size_t i) const override - { - ASSERT_LESS(i, m_moves.size(), ()); - return my::make_unique(m_moves[i]); - } - - private: - std::vector m_moves; - }; + using Trie = ::base::MemTrie; + using Iterator = trie::MemTrieIterator; void Add(Id const & id, Doc const & doc) { diff --git a/search/cities_boundaries_table.cpp b/search/cities_boundaries_table.cpp index bd1cff62e3..81b30649b0 100644 --- a/search/cities_boundaries_table.cpp +++ b/search/cities_boundaries_table.cpp @@ -83,7 +83,7 @@ bool CitiesBoundariesTable::Load() size_t boundary = 0; localities.ForEach([&](uint64_t fid) { ASSERT_LESS(boundary, all.size(), ()); - m_table[base::asserted_cast(fid)] = move(all[boundary]); + m_table[::base::asserted_cast(fid)] = move(all[boundary]); ++boundary; }); ASSERT_EQUAL(boundary, all.size(), ()); diff --git a/search/feature_offset_match.hpp b/search/feature_offset_match.hpp index b50f982730..5be155062c 100644 --- a/search/feature_offset_match.hpp +++ b/search/feature_offset_match.hpp @@ -194,18 +194,29 @@ private: template struct SearchTrieRequest { - bool IsLangExist(int8_t lang) const { return m_langs.Contains(lang); } + template + void SetLangs(Langs const & langs) + { + m_langs.clear(); + for (auto const & lang : langs) + { + if (lang >= 0 && lang <= numeric_limits::max()) + m_langs.insert(static_cast(lang)); + } + } + + bool HasLang(int8_t lang) const { return m_langs.find(lang) != m_langs.cend(); } void Clear() { m_names.clear(); m_categories.clear(); - m_langs.Clear(); + m_langs.clear(); } std::vector m_names; std::vector m_categories; - QueryParams::Langs m_langs; + std::unordered_set m_langs; }; // Calls |toDo| for each feature accepted by at least one DFA. @@ -253,7 +264,7 @@ void ForEachLangPrefix(SearchTrieRequest const & request, auto const & edge = trieRoot.m_edges[langIx].m_label; ASSERT_GREATER_OR_EQUAL(edge.size(), 1, ()); int8_t const lang = static_cast(edge[0]); - if (edge[0] < search::kCategoriesLang && request.IsLangExist(lang)) + if (edge[0] < search::kCategoriesLang && request.HasLang(lang)) { auto const langRoot = trieRoot.GoToEdge(langIx); TrieRootPrefix langPrefix(*langRoot, edge); diff --git a/search/geocoder.cpp b/search/geocoder.cpp index 386646bf80..dcf4990ddc 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -316,7 +316,7 @@ CBV DecimateCianResults(CBV const & cbv) size_t const kMaxCianResults = 10000; minstd_rand rng(0); auto survivedIds = - base::RandomSample(base::checked_cast(cbv.PopCount()), kMaxCianResults, rng); + base::RandomSample(::base::checked_cast(cbv.PopCount()), kMaxCianResults, rng); sort(survivedIds.begin(), survivedIds.end()); auto it = survivedIds.begin(); vector setBits; @@ -380,7 +380,7 @@ void Geocoder::SetParams(Params const & params) }); for (auto const & index : m_params.GetTypeIndices(i)) request.m_categories.emplace_back(FeatureTypeToString(index)); - request.m_langs = m_params.GetLangs(); + request.SetLangs(m_params.GetLangs()); } else { @@ -390,7 +390,7 @@ void Geocoder::SetParams(Params const & params) }); for (auto const & index : m_params.GetTypeIndices(i)) request.m_categories.emplace_back(FeatureTypeToString(index)); - request.m_langs = m_params.GetLangs(); + request.SetLangs(m_params.GetLangs()); } } @@ -780,7 +780,7 @@ void Geocoder::MatchCategories(BaseContext & ctx, bool aroundPivot) } auto emit = [&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto const featureId = ::base::asserted_cast(bit); Model::Type type; if (!GetTypeInGeocoding(ctx, featureId, type)) return; @@ -1009,7 +1009,7 @@ void Geocoder::CreateStreetsLayerAndMatchLowerLayers(BaseContext & ctx, vector sortedFeatures; sortedFeatures.reserve(base::checked_cast(prediction.m_features.PopCount())); prediction.m_features.ForEach([&sortedFeatures](uint64_t bit) { - sortedFeatures.push_back(base::asserted_cast(bit)); + sortedFeatures.push_back(::base::asserted_cast(bit)); }); layer.m_sortedFeatures = &sortedFeatures; @@ -1044,7 +1044,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) if (m_filter->NeedToFilter(m_postcodes.m_features)) filtered = m_filter->Filter(m_postcodes.m_features); filtered.ForEach([&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto const featureId = ::base::asserted_cast(bit); Model::Type type; if (GetTypeInGeocoding(ctx, featureId, type)) { @@ -1085,7 +1085,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) vector features; m_postcodes.m_features.ForEach([&features](uint64_t bit) { - features.push_back(base::asserted_cast(bit)); + features.push_back(::base::asserted_cast(bit)); }); layer.m_sortedFeatures = &features; return FindPaths(ctx); @@ -1103,7 +1103,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) // any. auto clusterize = [&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto const featureId = ::base::asserted_cast(bit); Model::Type type; if (!GetTypeInGeocoding(ctx, featureId, type)) return; @@ -1164,7 +1164,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) ends[i] = clusters[i].size(); filtered.ForEach([&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto const featureId = ::base::asserted_cast(bit); bool found = false; for (size_t i = 0; i < kNumClusters && !found; ++i) { @@ -1393,7 +1393,7 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) auto emitUnclassified = [&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto const featureId = ::base::asserted_cast(bit); Model::Type type; if (!GetTypeInGeocoding(ctx, featureId, type)) return; diff --git a/search/pre_ranker.cpp b/search/pre_ranker.cpp index 9d234a3331..c447b3e1d9 100644 --- a/search/pre_ranker.cpp +++ b/search/pre_ranker.cpp @@ -270,7 +270,7 @@ void PreRanker::FilterForViewportSearch() if (m <= old) { - for (size_t i : base::RandomSample(old, m, m_rng)) + for (size_t i : ::base::RandomSample(old, m, m_rng)) results.push_back(m_results[bucket[i]]); } else @@ -278,7 +278,7 @@ void PreRanker::FilterForViewportSearch() for (size_t i = 0; i < old; ++i) results.push_back(m_results[bucket[i]]); - for (size_t i : base::RandomSample(bucket.size() - old, m - old, m_rng)) + for (size_t i : ::base::RandomSample(bucket.size() - old, m - old, m_rng)) results.push_back(m_results[bucket[old + i]]); } } @@ -290,7 +290,7 @@ void PreRanker::FilterForViewportSearch() else { m_results.clear(); - for (size_t i : base::RandomSample(results.size(), BatchSize(), m_rng)) + for (size_t i : ::base::RandomSample(results.size(), BatchSize(), m_rng)) m_results.push_back(results[i]); } } diff --git a/search/retrieval.cpp b/search/retrieval.cpp index 2af4faa3ef..7af36d0622 100644 --- a/search/retrieval.cpp +++ b/search/retrieval.cpp @@ -160,7 +160,7 @@ bool MatchFeatureByNameAndType(FeatureType const & ft, SearchTrieRequest co bool matched = false; ft.ForEachName([&](int8_t lang, string const & name) { - if (name.empty() || !request.IsLangExist(lang)) + if (name.empty() || !request.HasLang(lang)) return base::ControlFlow::Continue; vector tokens; diff --git a/search/search_tests/locality_scorer_test.cpp b/search/search_tests/locality_scorer_test.cpp index 302420899c..8d04504d66 100644 --- a/search/search_tests/locality_scorer_test.cpp +++ b/search/search_tests/locality_scorer_test.cpp @@ -83,8 +83,7 @@ public: } else { - m_searchIndex.ForEachInNode(name, [&](UniString const & /* prefix */, - uint32_t featureId) { ids.push_back(featureId); }); + m_searchIndex.ForEachInNode(name, [&](uint32_t featureId) { ids.push_back(featureId); }); } }); @@ -115,7 +114,7 @@ protected: unordered_map> m_names; LocalityScorer m_scorer; - my::MemTrie> m_searchIndex; + base::MemTrie> m_searchIndex; }; } // namespace diff --git a/search/search_tests/mem_search_index_tests.cpp b/search/search_tests/mem_search_index_tests.cpp index 52a409630d..bc26ab47bd 100644 --- a/search/search_tests/mem_search_index_tests.cpp +++ b/search/search_tests/mem_search_index_tests.cpp @@ -67,7 +67,7 @@ public: { SearchTrieRequest request; request.m_names.emplace_back(token); - request.m_langs.Insert(StringUtf8Multilang::GetLangIndex(lang)); + request.m_langs.insert(StringUtf8Multilang::GetLangIndex(lang)); vector curr; MatchFeaturesInTrie(request, m_index.GetRootIterator(), diff --git a/search/utils.hpp b/search/utils.hpp index 07a051eaad..66e3067e6f 100644 --- a/search/utils.hpp +++ b/search/utils.hpp @@ -1,21 +1,23 @@ #pragma once #include "search/common.hpp" +#include "search/feature_offset_match.hpp" #include "search/token_slice.hpp" #include "indexer/categories_holder.hpp" #include "indexer/mwm_set.hpp" #include "indexer/search_delimiters.hpp" #include "indexer/search_string_utils.hpp" +#include "indexer/trie.hpp" #include "base/levenshtein_dfa.hpp" #include "base/stl_helpers.hpp" #include "base/string_utils.hpp" -#include +#include +#include #include #include -#include #include class Index; @@ -23,50 +25,6 @@ class MwmInfo; namespace search { -// todo(@m, @y). Unite with the similar function in search/feature_offset_match.hpp. -template -bool MatchInTrie(TrieIt const & trieStartIt, DFA const & dfa, ToDo && toDo) -{ - using Char = typename TrieIt::Char; - using DFAIt = typename DFA::Iterator; - using State = pair; - - std::queue q; - - { - auto it = dfa.Begin(); - if (it.Rejects()) - return false; - q.emplace(trieStartIt, it); - } - - bool found = false; - - while (!q.empty()) - { - auto const p = q.front(); - q.pop(); - - auto const & trieIt = p.first; - auto const & dfaIt = p.second; - - if (dfaIt.Accepts()) - { - trieIt.ForEachInNode(toDo); - found = true; - } - - trieIt.ForEachMove([&](Char const & c, TrieIt const & nextTrieIt) { - auto nextDfaIt = dfaIt; - nextDfaIt.Move(c); - if (!nextDfaIt.Rejects()) - q.emplace(nextTrieIt, nextDfaIt); - }); - } - - return found; -} - size_t GetMaxErrorsForToken(strings::UniString const & token); strings::LevenshteinDFA BuildLevenshteinDFA(strings::UniString const & s); @@ -100,23 +58,22 @@ template void ForEachCategoryTypeFuzzy(StringSliceBase const & slice, Locales const & locales, CategoriesHolder const & categories, ToDo && todo) { - using Trie = my::MemTrie>; + using Iterator = trie::MemTrieIterator>; auto const & trie = categories.GetNameToTypesTrie(); - auto const & trieRootIt = trie.GetRootIterator(); + Iterator const iterator(trie.GetRootIterator()); for (size_t i = 0; i < slice.Size(); ++i) { - auto const & token = slice.Get(i); // todo(@m, @y). We build dfa twice for each token: here and in geocoder.cpp. // A possible optimization is to build each dfa once and save it. Note that // dfas for the prefix tokens differ, i.e. we ignore slice.IsPrefix(i) here. - strings::LevenshteinDFA const dfa(BuildLevenshteinDFA(token)); + SearchTrieRequest request; + request.m_names.push_back(BuildLevenshteinDFA(slice.Get(i))); + request.SetLangs(locales); - trieRootIt.ForEachMove([&](Trie::Char const & c, Trie::Iterator const & trieStartIt) { - if (locales.Contains(static_cast(c))) - MatchInTrie(trieStartIt, dfa, std::bind(todo, i, std::placeholders::_1)); - }); + MatchFeaturesInTrie(request, iterator, [&](uint32_t /* type */) { return true; } /* filter */, + std::bind(todo, i, std::placeholders::_1)); } }