diff --git a/base/mem_trie.hpp b/base/mem_trie.hpp index 4749c21707..723b5f2b05 100644 --- a/base/mem_trie.hpp +++ b/base/mem_trie.hpp @@ -137,6 +137,8 @@ public: m_node.m_values.ForEach(std::forward(toDo)); } + ValuesHolder const & GetValues() const { return m_node.m_values; } + private: MemTrie::Node const & m_node; }; diff --git a/generator/search_index_builder.cpp b/generator/search_index_builder.cpp index d5829e657f..7e848d3cbc 100644 --- a/generator/search_index_builder.cpp +++ b/generator/search_index_builder.cpp @@ -214,7 +214,7 @@ struct ValueBuilder void MakeValue(FeatureType const & ft, uint32_t index, FeatureWithRankAndCenter & v) const { - v.m_featureId = index; + v.m_id = index; // get BEST geometry rect of feature v.m_pt = feature::GetCenter(ft); @@ -229,7 +229,7 @@ struct ValueBuilder void MakeValue(FeatureType const & /* f */, uint32_t index, FeatureIndexValue & value) const { - value.m_featureId = index; + value.m_id = index; } }; diff --git a/indexer/trie.hpp b/indexer/trie.hpp index eebdd0929a..26a57428da 100644 --- a/indexer/trie.hpp +++ b/indexer/trie.hpp @@ -19,7 +19,8 @@ uint32_t constexpr kDefaultChar = 0; template struct Iterator { - using Value = typename ValueList::Value; + using List = ValueList; + using Value = typename List::Value; struct Edge { @@ -33,7 +34,7 @@ struct Iterator virtual std::unique_ptr> GoToEdge(size_t i) const = 0; buffer_vector m_edges; - ValueList m_values; + List m_values; }; template diff --git a/search/CMakeLists.txt b/search/CMakeLists.txt index f8646c9d05..425c6ab313 100644 --- a/search/CMakeLists.txt +++ b/search/CMakeLists.txt @@ -5,6 +5,8 @@ set( algos.hpp approximate_string_match.cpp approximate_string_match.hpp + base/inverted_list.hpp + base/mem_search_index.hpp cancel_exception.hpp categories_cache.cpp categories_cache.hpp diff --git a/search/base/inverted_list.hpp b/search/base/inverted_list.hpp new file mode 100644 index 0000000000..877d95937e --- /dev/null +++ b/search/base/inverted_list.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "base/assert.hpp" + +#include +#include +#include + +namespace search +{ +namespace base +{ +template +class InvertedList +{ +public: + using value_type = Id; + using Value = Id; + + bool Add(Id const & id) + { + auto it = std::lower_bound(m_ids.begin(), m_ids.end(), id); + if (it != m_ids.end() && *it == id) + return false; + m_ids.insert(it, id); + return true; + } + + bool Erase(Id const & id) + { + auto it = std::lower_bound(m_ids.begin(), m_ids.end(), id); + if (it == m_ids.end() || *it != id) + return false; + m_ids.erase(it); + return true; + } + + template + void ForEach(ToDo && toDo) const + { + for (auto const & id : m_ids) + toDo(id); + } + + size_t Size() const { return m_ids.size(); } + + bool Empty() const { return Size() == 0; } + + void Clear() { m_ids.clear(); } + +private: + std::vector m_ids; +}; +} // namespace base +} // namespace search diff --git a/search/base/mem_search_index.hpp b/search/base/mem_search_index.hpp new file mode 100644 index 0000000000..48411491ca --- /dev/null +++ b/search/base/mem_search_index.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include "search/base/inverted_list.hpp" + +#include "indexer/trie.hpp" + +#include "base/assert.hpp" +#include "base/mem_trie.hpp" +#include "base/string_utils.hpp" + +#include +#include + +namespace search +{ +namespace base +{ +template +class MemSearchIndex +{ +public: + using Token = strings::UniString; + using Char = Token::value_type; + using List = InvertedList; + using Trie = my::MemTrie; + + class Iterator : public trie::Iterator + { + public: + using Base = trie::Iterator; + using InnerIterator = typename Trie::Iterator; + + explicit Iterator(InnerIterator const & inIt) + { + Base::m_values = inIt.GetValues(); + inIt.ForEachMove([&](Char c, InnerIterator it) + { + Base::m_edges.emplace_back(); + Base::m_edges.back().m_label.push_back(c); + m_moves.push_back(it); + }); + } + + ~Iterator() override = default; + + // trie::Iterator overrides: + std::unique_ptr Clone() const override { return my::make_unique(*this); } + + std::unique_ptr GoToEdge(size_t i) const override + { + ASSERT_LESS(i, m_moves.size(), ()); + return my::make_unique(m_moves[i]); + } + + private: + std::vector m_moves; + }; + + void Add(Id const & id, Doc const & doc) + { + ForEachToken(id, doc, [&](Token const & token) { m_trie.Add(token, id); }); + } + + void Erase(Id const & id, Doc const & doc) + { + ForEachToken(id, doc, [&](Token const & token) { m_trie.Erase(token, id); }); + } + + Iterator GetRootIterator() const { return Iterator(m_trie.GetRootIterator()); } + +private: + template + void ForEachToken(Id const & id, Doc const & doc, Fn && fn) + { + doc.ForEachToken([&](int8_t lang, Token const & token) { + if (lang < 0) + return; + + Token t; + t.push_back(static_cast(lang)); + t.insert(t.end(), token.begin(), token.end()); + fn(t); + }); + } + + Trie m_trie; +}; +} // namespace base +} // namespace search diff --git a/search/common.hpp b/search/common.hpp index 37839e5c0d..41be032f6a 100644 --- a/search/common.hpp +++ b/search/common.hpp @@ -14,7 +14,7 @@ namespace search using QueryTokens = buffer_vector; using Locales = - base::SafeSmallSet(CategoriesHolder::kMaxSupportedLocaleIndex) + 1>; + ::base::SafeSmallSet(CategoriesHolder::kMaxSupportedLocaleIndex) + 1>; /// Upper bound for max count of tokens for indexing and scoring. int constexpr MAX_TOKENS = 32; diff --git a/search/feature_offset_match.hpp b/search/feature_offset_match.hpp index 4c709d1cb5..401cc399c1 100644 --- a/search/feature_offset_match.hpp +++ b/search/feature_offset_match.hpp @@ -29,8 +29,8 @@ namespace impl { namespace { -template -bool FindLangIndex(trie::Iterator> const & trieRoot, uint8_t lang, uint32_t & langIx) +template +bool FindLangIndex(trie::Iterator const & trieRoot, uint8_t lang, uint32_t & langIx) { ASSERT_LESS(trieRoot.m_edges.size(), numeric_limits::max(), ()); @@ -49,12 +49,12 @@ bool FindLangIndex(trie::Iterator> const & trieRoot, uint8_t la } } // namespace -template -bool MatchInTrie(trie::Iterator> const & trieRoot, +template +bool MatchInTrie(trie::Iterator const & trieRoot, strings::UniChar const * rootPrefix, size_t rootPrefixSize, DFA const & dfa, ToDo && toDo) { - using TrieDFAIt = shared_ptr>>; + using TrieDFAIt = shared_ptr>; using DFAIt = typename DFA::Iterator; using State = pair; @@ -102,20 +102,7 @@ bool MatchInTrie(trie::Iterator> const & trieRoot, template class OffsetIntersector { - struct Hash - { - size_t operator()(Value const & v) const { return v.m_featureId; } - }; - - struct Equal - { - bool operator()(Value const & v1, Value const & v2) const - { - return v1.m_featureId == v2.m_featureId; - } - }; - - using Set = unordered_set; + using Set = unordered_set; Filter const & m_filter; unique_ptr m_prevSet; @@ -129,7 +116,7 @@ public: if (m_prevSet && !m_prevSet->count(v)) return; - if (!m_filter(v.m_featureId)) + if (!m_filter(v.m_id)) return; m_set->insert(v); @@ -155,10 +142,11 @@ public: }; } // namespace impl -template +template struct TrieRootPrefix { - using Iterator = trie::Iterator>; + using Value = typename ValueList::Value; + using Iterator = trie::Iterator; Iterator const & m_root; strings::UniChar const * m_prefix; @@ -188,7 +176,7 @@ public: void operator()(Value const & v) { - if (m_filter(v.m_featureId)) + if (m_filter(v.m_id)) m_values.push_back(v); } @@ -224,8 +212,8 @@ struct SearchTrieRequest // Calls |toDo| for each feature accepted by at least one DFA. // // *NOTE* |toDo| may be called several times for the same feature. -template -void MatchInTrie(vector const & dfas, TrieRootPrefix const & trieRoot, ToDo && toDo) +template +void MatchInTrie(vector const & dfas, TrieRootPrefix const & trieRoot, ToDo && toDo) { for (auto const & dfa : dfas) impl::MatchInTrie(trieRoot.m_root, trieRoot.m_prefix, trieRoot.m_prefixSize, dfa, toDo); @@ -234,9 +222,9 @@ void MatchInTrie(vector const & dfas, TrieRootPrefix const & trieRoo // Calls |toDo| for each feature in categories branch matching to |request|. // // *NOTE* |toDo| may be called several times for the same feature. -template +template bool MatchCategoriesInTrie(SearchTrieRequest const & request, - trie::Iterator> const & trieRoot, ToDo && toDo) + trie::Iterator const & trieRoot, ToDo && toDo) { uint32_t langIx = 0; if (!impl::FindLangIndex(trieRoot, search::kCategoriesLang, langIx)) @@ -246,16 +234,16 @@ bool MatchCategoriesInTrie(SearchTrieRequest const & request, ASSERT_GREATER_OR_EQUAL(edge.size(), 1, ()); auto const catRoot = trieRoot.GoToEdge(langIx); - MatchInTrie(request.m_categories, TrieRootPrefix(*catRoot, edge), toDo); + MatchInTrie(request.m_categories, TrieRootPrefix(*catRoot, edge), toDo); return true; } // Calls |toDo| with trie root prefix and language code on each // language allowed by |request|. -template +template void ForEachLangPrefix(SearchTrieRequest const & request, - trie::Iterator> const & trieRoot, ToDo && toDo) + trie::Iterator const & trieRoot, ToDo && toDo) { ASSERT_LESS(trieRoot.m_edges.size(), numeric_limits::max(), ()); @@ -268,7 +256,7 @@ void ForEachLangPrefix(SearchTrieRequest const & request, if (edge[0] < search::kCategoriesLang && request.IsLangExist(lang)) { auto const langRoot = trieRoot.GoToEdge(langIx); - TrieRootPrefix langPrefix(*langRoot, edge); + TrieRootPrefix langPrefix(*langRoot, edge); toDo(langPrefix, lang); } } @@ -276,20 +264,23 @@ void ForEachLangPrefix(SearchTrieRequest const & request, // Calls |toDo| for each feature whose description matches to // |request|. Each feature will be passed to |toDo| only once. -template +template void MatchFeaturesInTrie(SearchTrieRequest const & request, - trie::Iterator> const & trieRoot, Filter const & filter, + trie::Iterator const & trieRoot, Filter const & filter, ToDo && toDo) { + using Value = typename ValueList::Value; + TrieValuesHolder categoriesHolder(filter); bool const categoriesMatched = MatchCategoriesInTrie(request, trieRoot, categoriesHolder); impl::OffsetIntersector intersector(filter); - ForEachLangPrefix(request, trieRoot, - [&request, &intersector](TrieRootPrefix & langRoot, int8_t /* lang */) { - MatchInTrie(request.m_names, langRoot, intersector); - }); + ForEachLangPrefix( + request, trieRoot, + [&request, &intersector](TrieRootPrefix & langRoot, int8_t /* lang */) { + MatchInTrie(request.m_names, langRoot, intersector); + }); if (categoriesMatched) categoriesHolder.ForEachValue(intersector); @@ -298,12 +289,12 @@ void MatchFeaturesInTrie(SearchTrieRequest const & request, intersector.ForEachResult(forward(toDo)); } -template -void MatchPostcodesInTrie(TokenSlice const & slice, - trie::Iterator> const & trieRoot, +template +void MatchPostcodesInTrie(TokenSlice const & slice, trie::Iterator const & trieRoot, Filter const & filter, ToDo && toDo) { using namespace strings; + using Value = typename ValueList::Value; uint32_t langIx = 0; if (!impl::FindLangIndex(trieRoot, search::kPostcodesLang, langIx)) @@ -319,13 +310,13 @@ void MatchPostcodesInTrie(TokenSlice const & slice, { vector> dfas; slice.Get(i).ForEach([&dfas](UniString const & s) { dfas.emplace_back(UniStringDFA(s)); }); - MatchInTrie(dfas, TrieRootPrefix(*postcodesRoot, edge), intersector); + MatchInTrie(dfas, TrieRootPrefix(*postcodesRoot, edge), intersector); } else { vector dfas; slice.Get(i).ForEach([&dfas](UniString const & s) { dfas.emplace_back(s); }); - MatchInTrie(dfas, TrieRootPrefix(*postcodesRoot, edge), intersector); + MatchInTrie(dfas, TrieRootPrefix(*postcodesRoot, edge), intersector); } intersector.NextStep(); diff --git a/search/query_params.hpp b/search/query_params.hpp index 3699777b9e..9780fc6464 100644 --- a/search/query_params.hpp +++ b/search/query_params.hpp @@ -23,7 +23,7 @@ class QueryParams public: using String = strings::UniString; using TypeIndices = vector; - using Langs = base::SafeSmallSet; + using Langs = ::base::SafeSmallSet; struct Token { diff --git a/search/retrieval.cpp b/search/retrieval.cpp index c412a4dc47..2485a271d5 100644 --- a/search/retrieval.cpp +++ b/search/retrieval.cpp @@ -49,7 +49,7 @@ public: { if ((++m_counter & 0xFF) == 0) BailIfCancelled(m_cancellable); - m_features.push_back(value.m_featureId); + m_features.push_back(value.m_id); } inline void operator()(uint32_t feature) { m_features.push_back(feature); } diff --git a/search/search_index_values.hpp b/search/search_index_values.hpp index 4605774abb..ff8d2f75fa 100644 --- a/search/search_index_values.hpp +++ b/search/search_index_values.hpp @@ -25,44 +25,70 @@ // A wrapper around feature index. struct FeatureIndexValue { - FeatureIndexValue() : m_featureId(0) {} + FeatureIndexValue() : m_id(0) {} - FeatureIndexValue(uint64_t featureId) : m_featureId(featureId) {} + FeatureIndexValue(uint64_t id) : m_id(id) {} - bool operator<(FeatureIndexValue const & o) const { return m_featureId < o.m_featureId; } + bool operator<(FeatureIndexValue const & o) const { return m_id < o.m_id; } - bool operator==(FeatureIndexValue const & o) const { return m_featureId == o.m_featureId; } + bool operator==(FeatureIndexValue const & o) const { return m_id == o.m_id; } - void Swap(FeatureIndexValue & o) { swap(m_featureId, o.m_featureId); } + void Swap(FeatureIndexValue & o) { swap(m_id, o.m_id); } - uint64_t m_featureId; + uint64_t m_id; }; +namespace std +{ +template <> +class hash +{ +public: + size_t operator()(FeatureIndexValue const & value) const + { + return std::hash{}(value.m_id); + } +}; +} // namespace std + struct FeatureWithRankAndCenter { - FeatureWithRankAndCenter() : m_pt(m2::PointD()), m_featureId(0), m_rank(0) {} + FeatureWithRankAndCenter() : m_pt(m2::PointD()), m_id(0), m_rank(0) {} - FeatureWithRankAndCenter(m2::PointD pt, uint32_t featureId, uint8_t rank) - : m_pt(pt), m_featureId(featureId), m_rank(rank) + FeatureWithRankAndCenter(m2::PointD pt, uint32_t id, uint8_t rank) + : m_pt(pt), m_id(id), m_rank(rank) { } - bool operator<(FeatureWithRankAndCenter const & o) const { return m_featureId < o.m_featureId; } + bool operator<(FeatureWithRankAndCenter const & o) const { return m_id < o.m_id; } - bool operator==(FeatureWithRankAndCenter const & o) const { return m_featureId == o.m_featureId; } + bool operator==(FeatureWithRankAndCenter const & o) const { return m_id == o.m_id; } void Swap(FeatureWithRankAndCenter & o) { swap(m_pt, o.m_pt); - swap(m_featureId, o.m_featureId); + swap(m_id, o.m_id); swap(m_rank, o.m_rank); } - m2::PointD m_pt; // Center point of the feature. - uint32_t m_featureId; // Feature identifier. - uint8_t m_rank; // Rank of the feature. + m2::PointD m_pt; // Center point of the feature. + uint32_t m_id; // Feature identifier. + uint8_t m_rank; // Rank of the feature. }; +namespace std +{ +template <> +class hash +{ +public: + size_t operator()(FeatureWithRankAndCenter const & value) const + { + return std::hash{}(value.m_id); + } +}; +} // namespace std + template class SingleValueSerializer; @@ -78,7 +104,7 @@ public: void Serialize(Sink & sink, Value const & v) const { serial::SavePoint(sink, v.m_pt, m_codingParams); - WriteToSink(sink, v.m_featureId); + WriteToSink(sink, v.m_id); WriteToSink(sink, v.m_rank); } @@ -93,7 +119,7 @@ public: void DeserializeFromSource(Source & source, Value & v) const { v.m_pt = serial::LoadPoint(source, m_codingParams); - v.m_featureId = ReadPrimitiveFromSource(source); + v.m_id = ReadPrimitiveFromSource(source); v.m_rank = ReadPrimitiveFromSource(source); } @@ -117,7 +143,7 @@ public: template void Serialize(Sink & sink, Value const & v) const { - WriteToSink(sink, v.m_featureId); + WriteToSink(sink, v.m_id); } template @@ -130,7 +156,7 @@ public: template void DeserializeFromSource(Source & source, Value & v) const { - v.m_featureId = ReadPrimitiveFromSource(source); + v.m_id = ReadPrimitiveFromSource(source); } }; @@ -159,7 +185,7 @@ public: { std::vector ids(values.size()); for (size_t i = 0; i < ids.size(); ++i) - ids[i] = values[i].m_featureId; + ids[i] = values[i].m_id; m_cbv = coding::CompressedBitVectorBuilder::FromBitPositions(move(ids)); } diff --git a/search/search_tests/CMakeLists.txt b/search/search_tests/CMakeLists.txt index 50a6d11560..48955c1a3b 100644 --- a/search/search_tests/CMakeLists.txt +++ b/search/search_tests/CMakeLists.txt @@ -15,6 +15,7 @@ set( locality_finder_test.cpp locality_scorer_test.cpp locality_selector_test.cpp + mem_search_index_tests.cpp point_rect_matcher_tests.cpp query_saver_tests.cpp ranking_tests.cpp diff --git a/search/search_tests/mem_search_index_tests.cpp b/search/search_tests/mem_search_index_tests.cpp new file mode 100644 index 0000000000..37bb07461e --- /dev/null +++ b/search/search_tests/mem_search_index_tests.cpp @@ -0,0 +1,132 @@ +#include "testing/testing.hpp" + +#include "search/base/mem_search_index.hpp" +#include "search/feature_offset_match.hpp" + +#include "indexer/search_string_utils.hpp" + +#include "coding/multilang_utf8_string.hpp" + +#include "base/string_utils.hpp" +#include "base/uni_string_dfa.hpp" + +#include +#include +#include +#include +#include + +using namespace search::base; +using namespace search; +using namespace std; +using namespace strings; + +struct Id +{ + explicit Id(uint64_t id) : m_id(id) {} + + bool operator==(Id const & rhs) const { return m_id == rhs.m_id; } + bool operator!=(Id const & rhs) const { return !(*this == rhs); } + bool operator<(Id const & rhs) const { return m_id < rhs.m_id; } + + uint64_t m_id; +}; + +string DebugPrint(Id const & id) { return DebugPrint(id.m_id); } + +template<> +class hash +{ +public: + size_t operator()(Id const & id) const { return std::hash{}(id.m_id); } +}; + +class Doc +{ +public: + Doc(string const & text, string const & lang) : m_lang(StringUtf8Multilang::GetLangIndex(lang)) + { + NormalizeAndTokenizeString(text, m_tokens); + } + + template + void ForEachToken(ToDo && toDo) const + { + for (auto const & token : m_tokens) + toDo(m_lang, token); + } + +private: + vector m_tokens; + int8_t m_lang; +}; + +class MemSearchIndexTest +{ +public: + void Add(Id const & id, Doc const & doc) { m_index.Add(id, doc); } + + void Erase(Id const & id, Doc const & doc) { m_index.Erase(id, doc); } + + vector StrictQuery(string const & query, string const & lang) const + { + vector prev; + bool full = true; + + vector tokens; + NormalizeAndTokenizeString(query, tokens); + for (auto const & token : tokens) + { + vector curr; + + SearchTrieRequest request; + request.m_names.emplace_back(token); + request.m_langs.Insert(StringUtf8Multilang::GetLangIndex(lang)); + MatchFeaturesInTrie(request, m_index.GetRootIterator(), + [](uint64_t /* id */) { return true; }, + [&curr](Id const & id) { curr.push_back(id); }); + my::SortUnique(curr); + if (full) + { + prev = curr; + full = false; + } + else + { + vector intersection; + set_intersection(prev.begin(), prev.end(), curr.begin(), curr.end(), + back_inserter(intersection)); + prev = intersection; + } + } + + return prev; + } + +protected: + MemSearchIndex m_index; +}; + +UNIT_CLASS_TEST(MemSearchIndexTest, Smoke) +{ + Id const kHamlet{31337}; + Id const kMacbeth{600613}; + Doc const hamlet{"To be or not to be: that is the question...", "en"}; + Doc const macbeth{"When shall we three meet again? In thunder, lightning, or in rain? ...", "en"}; + + Add(kHamlet, hamlet); + Add(kMacbeth, macbeth); + + TEST_EQUAL(StrictQuery("Thunder", "en"), vector({kMacbeth}), ()); + TEST_EQUAL(StrictQuery("Question", "en"), vector({kHamlet}), ()); + TEST_EQUAL(StrictQuery("or", "en"), vector({kHamlet, kMacbeth}), ()); + TEST_EQUAL(StrictQuery("thunder lightning rain", "en"), vector({kMacbeth}), ()); + + Erase(kMacbeth, macbeth); + + TEST_EQUAL(StrictQuery("Thunder", "en"), vector{}, ()); + TEST_EQUAL(StrictQuery("to be or not to be", "en"), vector({kHamlet}), ()); + + Erase(kHamlet, hamlet); + TEST_EQUAL(StrictQuery("question", "en"), vector{}, ()); +}