[search] MemSearchIndex prototype.

This commit is contained in:
Yuri Gorshenin 2017-11-28 18:29:14 +03:00 committed by Vladimir Byko-Ianko
parent bc7e14b34c
commit 15fc8505fe
13 changed files with 368 additions and 69 deletions

View file

@ -137,6 +137,8 @@ public:
m_node.m_values.ForEach(std::forward<ToDo>(toDo));
}
ValuesHolder const & GetValues() const { return m_node.m_values; }
private:
MemTrie::Node const & m_node;
};

View file

@ -214,7 +214,7 @@ struct ValueBuilder<FeatureWithRankAndCenter>
void MakeValue(FeatureType const & ft, uint32_t index, FeatureWithRankAndCenter & v) const
{
v.m_featureId = index;
v.m_id = index;
// get BEST geometry rect of feature
v.m_pt = feature::GetCenter(ft);
@ -229,7 +229,7 @@ struct ValueBuilder<FeatureIndexValue>
void MakeValue(FeatureType const & /* f */, uint32_t index, FeatureIndexValue & value) const
{
value.m_featureId = index;
value.m_id = index;
}
};

View file

@ -19,7 +19,8 @@ uint32_t constexpr kDefaultChar = 0;
template <typename ValueList>
struct Iterator
{
using Value = typename ValueList::Value;
using List = ValueList;
using Value = typename List::Value;
struct Edge
{
@ -33,7 +34,7 @@ struct Iterator
virtual std::unique_ptr<Iterator<ValueList>> GoToEdge(size_t i) const = 0;
buffer_vector<Edge, 8> m_edges;
ValueList m_values;
List m_values;
};
template <typename ValueList, typename ToDo, typename String>

View file

@ -5,6 +5,8 @@ set(
algos.hpp
approximate_string_match.cpp
approximate_string_match.hpp
base/inverted_list.hpp
base/mem_search_index.hpp
cancel_exception.hpp
categories_cache.cpp
categories_cache.hpp

View file

@ -0,0 +1,55 @@
#pragma once
#include "base/assert.hpp"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace search
{
namespace base
{
// A duplicate-free list of ids kept in ascending order (as defined by
// operator<). Lookups use binary search; the ids are stored in a
// contiguous vector.
//
// Requirements on |Id|: copyable, operator< and operator==.
template <typename Id>
class InvertedList
{
public:
  using value_type = Id;
  using Value = Id;

  // Inserts |id| keeping the list sorted. Returns true when |id| was
  // inserted and false when an equal id was already present.
  bool Add(Id const & id)
  {
    auto it = std::lower_bound(m_ids.begin(), m_ids.end(), id);
    if (it != m_ids.end() && *it == id)
      return false;
    m_ids.insert(it, id);
    return true;
  }

  // Removes |id| from the list. Returns true when |id| was present and
  // false otherwise.
  bool Erase(Id const & id)
  {
    auto it = std::lower_bound(m_ids.begin(), m_ids.end(), id);
    // Expressed through operator== (exactly like Add()) so that |Id| is
    // not additionally required to provide operator!=.
    if (it == m_ids.end() || !(*it == id))
      return false;
    m_ids.erase(it);
    return true;
  }

  // Calls |toDo| for every stored id, in ascending order.
  template <typename ToDo>
  void ForEach(ToDo && toDo) const
  {
    for (auto const & id : m_ids)
      toDo(id);
  }

  size_t Size() const { return m_ids.size(); }

  bool Empty() const { return Size() == 0; }

  void Clear() { m_ids.clear(); }

private:
  // Sorted and free of duplicates: Add() is the only way to insert.
  std::vector<Id> m_ids;
};
} // namespace base
} // namespace search

View file

@ -0,0 +1,89 @@
#pragma once
#include "search/base/inverted_list.hpp"
#include "indexer/trie.hpp"
#include "base/assert.hpp"
#include "base/mem_trie.hpp"
#include "base/string_utils.hpp"
#include <cstdint>
#include <vector>
namespace search
{
namespace base
{
template <typename Id, typename Doc>
class MemSearchIndex
{
public:
using Token = strings::UniString;
using Char = Token::value_type;
using List = InvertedList<Id>;
using Trie = my::MemTrie<Token, List>;
class Iterator : public trie::Iterator<List>
{
public:
using Base = trie::Iterator<List>;
using InnerIterator = typename Trie::Iterator;
explicit Iterator(InnerIterator const & inIt)
{
Base::m_values = inIt.GetValues();
inIt.ForEachMove([&](Char c, InnerIterator it)
{
Base::m_edges.emplace_back();
Base::m_edges.back().m_label.push_back(c);
m_moves.push_back(it);
});
}
~Iterator() override = default;
// trie::Iterator<List> overrides:
std::unique_ptr<Base> Clone() const override { return my::make_unique<Iterator>(*this); }
std::unique_ptr<Base> GoToEdge(size_t i) const override
{
ASSERT_LESS(i, m_moves.size(), ());
return my::make_unique<Iterator>(m_moves[i]);
}
private:
std::vector<InnerIterator> m_moves;
};
void Add(Id const & id, Doc const & doc)
{
ForEachToken(id, doc, [&](Token const & token) { m_trie.Add(token, id); });
}
void Erase(Id const & id, Doc const & doc)
{
ForEachToken(id, doc, [&](Token const & token) { m_trie.Erase(token, id); });
}
Iterator GetRootIterator() const { return Iterator(m_trie.GetRootIterator()); }
private:
template <typename Fn>
void ForEachToken(Id const & id, Doc const & doc, Fn && fn)
{
doc.ForEachToken([&](int8_t lang, Token const & token) {
if (lang < 0)
return;
Token t;
t.push_back(static_cast<Char>(lang));
t.insert(t.end(), token.begin(), token.end());
fn(t);
});
}
Trie m_trie;
};
} // namespace base
} // namespace search

View file

@ -14,7 +14,7 @@ namespace search
using QueryTokens = buffer_vector<strings::UniString, 32>;
using Locales =
base::SafeSmallSet<static_cast<uint64_t>(CategoriesHolder::kMaxSupportedLocaleIndex) + 1>;
::base::SafeSmallSet<static_cast<uint64_t>(CategoriesHolder::kMaxSupportedLocaleIndex) + 1>;
/// Upper bound for max count of tokens for indexing and scoring.
int constexpr MAX_TOKENS = 32;

View file

@ -29,8 +29,8 @@ namespace impl
{
namespace
{
template <typename Value>
bool FindLangIndex(trie::Iterator<ValueList<Value>> const & trieRoot, uint8_t lang, uint32_t & langIx)
template <typename ValueList>
bool FindLangIndex(trie::Iterator<ValueList> const & trieRoot, uint8_t lang, uint32_t & langIx)
{
ASSERT_LESS(trieRoot.m_edges.size(), numeric_limits<uint32_t>::max(), ());
@ -49,12 +49,12 @@ bool FindLangIndex(trie::Iterator<ValueList<Value>> const & trieRoot, uint8_t la
}
} // namespace
template <typename Value, typename DFA, typename ToDo>
bool MatchInTrie(trie::Iterator<ValueList<Value>> const & trieRoot,
template <typename ValueList, typename DFA, typename ToDo>
bool MatchInTrie(trie::Iterator<ValueList> const & trieRoot,
strings::UniChar const * rootPrefix, size_t rootPrefixSize, DFA const & dfa,
ToDo && toDo)
{
using TrieDFAIt = shared_ptr<trie::Iterator<ValueList<Value>>>;
using TrieDFAIt = shared_ptr<trie::Iterator<ValueList>>;
using DFAIt = typename DFA::Iterator;
using State = pair<TrieDFAIt, DFAIt>;
@ -102,20 +102,7 @@ bool MatchInTrie(trie::Iterator<ValueList<Value>> const & trieRoot,
template <typename Filter, typename Value>
class OffsetIntersector
{
struct Hash
{
size_t operator()(Value const & v) const { return v.m_featureId; }
};
struct Equal
{
bool operator()(Value const & v1, Value const & v2) const
{
return v1.m_featureId == v2.m_featureId;
}
};
using Set = unordered_set<Value, Hash, Equal>;
using Set = unordered_set<Value>;
Filter const & m_filter;
unique_ptr<Set> m_prevSet;
@ -129,7 +116,7 @@ public:
if (m_prevSet && !m_prevSet->count(v))
return;
if (!m_filter(v.m_featureId))
if (!m_filter(v.m_id))
return;
m_set->insert(v);
@ -155,10 +142,11 @@ public:
};
} // namespace impl
template <typename Value>
template <typename ValueList>
struct TrieRootPrefix
{
using Iterator = trie::Iterator<ValueList<Value>>;
using Value = typename ValueList::Value;
using Iterator = trie::Iterator<ValueList>;
Iterator const & m_root;
strings::UniChar const * m_prefix;
@ -188,7 +176,7 @@ public:
void operator()(Value const & v)
{
if (m_filter(v.m_featureId))
if (m_filter(v.m_id))
m_values.push_back(v);
}
@ -224,8 +212,8 @@ struct SearchTrieRequest
// Calls |toDo| for each feature accepted by at least one DFA.
//
// *NOTE* |toDo| may be called several times for the same feature.
template <typename DFA, typename Value, typename ToDo>
void MatchInTrie(vector<DFA> const & dfas, TrieRootPrefix<Value> const & trieRoot, ToDo && toDo)
template <typename DFA, typename ValueList, typename ToDo>
void MatchInTrie(vector<DFA> const & dfas, TrieRootPrefix<ValueList> const & trieRoot, ToDo && toDo)
{
for (auto const & dfa : dfas)
impl::MatchInTrie(trieRoot.m_root, trieRoot.m_prefix, trieRoot.m_prefixSize, dfa, toDo);
@ -234,9 +222,9 @@ void MatchInTrie(vector<DFA> const & dfas, TrieRootPrefix<Value> const & trieRoo
// Calls |toDo| for each feature in the categories branch that matches |request|.
//
// *NOTE* |toDo| may be called several times for the same feature.
template <typename DFA, typename Value, typename ToDo>
template <typename DFA, typename ValueList, typename ToDo>
bool MatchCategoriesInTrie(SearchTrieRequest<DFA> const & request,
trie::Iterator<ValueList<Value>> const & trieRoot, ToDo && toDo)
trie::Iterator<ValueList> const & trieRoot, ToDo && toDo)
{
uint32_t langIx = 0;
if (!impl::FindLangIndex(trieRoot, search::kCategoriesLang, langIx))
@ -246,16 +234,16 @@ bool MatchCategoriesInTrie(SearchTrieRequest<DFA> const & request,
ASSERT_GREATER_OR_EQUAL(edge.size(), 1, ());
auto const catRoot = trieRoot.GoToEdge(langIx);
MatchInTrie(request.m_categories, TrieRootPrefix<Value>(*catRoot, edge), toDo);
MatchInTrie(request.m_categories, TrieRootPrefix<ValueList>(*catRoot, edge), toDo);
return true;
}
// Calls |toDo| with trie root prefix and language code on each
// language allowed by |request|.
template <typename DFA, typename Value, typename ToDo>
template <typename DFA, typename ValueList, typename ToDo>
void ForEachLangPrefix(SearchTrieRequest<DFA> const & request,
trie::Iterator<ValueList<Value>> const & trieRoot, ToDo && toDo)
trie::Iterator<ValueList> const & trieRoot, ToDo && toDo)
{
ASSERT_LESS(trieRoot.m_edges.size(), numeric_limits<uint32_t>::max(), ());
@ -268,7 +256,7 @@ void ForEachLangPrefix(SearchTrieRequest<DFA> const & request,
if (edge[0] < search::kCategoriesLang && request.IsLangExist(lang))
{
auto const langRoot = trieRoot.GoToEdge(langIx);
TrieRootPrefix<Value> langPrefix(*langRoot, edge);
TrieRootPrefix<ValueList> langPrefix(*langRoot, edge);
toDo(langPrefix, lang);
}
}
@ -276,20 +264,23 @@ void ForEachLangPrefix(SearchTrieRequest<DFA> const & request,
// Calls |toDo| for each feature whose description matches
// |request|. Each feature will be passed to |toDo| only once.
template <typename DFA, typename Value, typename Filter, typename ToDo>
template <typename DFA, typename ValueList, typename Filter, typename ToDo>
void MatchFeaturesInTrie(SearchTrieRequest<DFA> const & request,
trie::Iterator<ValueList<Value>> const & trieRoot, Filter const & filter,
trie::Iterator<ValueList> const & trieRoot, Filter const & filter,
ToDo && toDo)
{
using Value = typename ValueList::Value;
TrieValuesHolder<Filter, Value> categoriesHolder(filter);
bool const categoriesMatched = MatchCategoriesInTrie(request, trieRoot, categoriesHolder);
impl::OffsetIntersector<Filter, Value> intersector(filter);
ForEachLangPrefix(request, trieRoot,
[&request, &intersector](TrieRootPrefix<Value> & langRoot, int8_t /* lang */) {
MatchInTrie(request.m_names, langRoot, intersector);
});
ForEachLangPrefix(
request, trieRoot,
[&request, &intersector](TrieRootPrefix<ValueList> & langRoot, int8_t /* lang */) {
MatchInTrie(request.m_names, langRoot, intersector);
});
if (categoriesMatched)
categoriesHolder.ForEachValue(intersector);
@ -298,12 +289,12 @@ void MatchFeaturesInTrie(SearchTrieRequest<DFA> const & request,
intersector.ForEachResult(forward<ToDo>(toDo));
}
template <typename Value, typename Filter, typename ToDo>
void MatchPostcodesInTrie(TokenSlice const & slice,
trie::Iterator<ValueList<Value>> const & trieRoot,
template <typename ValueList, typename Filter, typename ToDo>
void MatchPostcodesInTrie(TokenSlice const & slice, trie::Iterator<ValueList> const & trieRoot,
Filter const & filter, ToDo && toDo)
{
using namespace strings;
using Value = typename ValueList::Value;
uint32_t langIx = 0;
if (!impl::FindLangIndex(trieRoot, search::kPostcodesLang, langIx))
@ -319,13 +310,13 @@ void MatchPostcodesInTrie(TokenSlice const & slice,
{
vector<PrefixDFAModifier<UniStringDFA>> dfas;
slice.Get(i).ForEach([&dfas](UniString const & s) { dfas.emplace_back(UniStringDFA(s)); });
MatchInTrie(dfas, TrieRootPrefix<Value>(*postcodesRoot, edge), intersector);
MatchInTrie(dfas, TrieRootPrefix<ValueList>(*postcodesRoot, edge), intersector);
}
else
{
vector<UniStringDFA> dfas;
slice.Get(i).ForEach([&dfas](UniString const & s) { dfas.emplace_back(s); });
MatchInTrie(dfas, TrieRootPrefix<Value>(*postcodesRoot, edge), intersector);
MatchInTrie(dfas, TrieRootPrefix<ValueList>(*postcodesRoot, edge), intersector);
}
intersector.NextStep();

View file

@ -23,7 +23,7 @@ class QueryParams
public:
using String = strings::UniString;
using TypeIndices = vector<uint32_t>;
using Langs = base::SafeSmallSet<StringUtf8Multilang::kMaxSupportedLanguages>;
using Langs = ::base::SafeSmallSet<StringUtf8Multilang::kMaxSupportedLanguages>;
struct Token
{

View file

@ -49,7 +49,7 @@ public:
{
if ((++m_counter & 0xFF) == 0)
BailIfCancelled(m_cancellable);
m_features.push_back(value.m_featureId);
m_features.push_back(value.m_id);
}
inline void operator()(uint32_t feature) { m_features.push_back(feature); }

View file

@ -25,44 +25,70 @@
// A wrapper around feature index.
struct FeatureIndexValue
{
FeatureIndexValue() : m_featureId(0) {}
FeatureIndexValue() : m_id(0) {}
FeatureIndexValue(uint64_t featureId) : m_featureId(featureId) {}
FeatureIndexValue(uint64_t id) : m_id(id) {}
bool operator<(FeatureIndexValue const & o) const { return m_featureId < o.m_featureId; }
bool operator<(FeatureIndexValue const & o) const { return m_id < o.m_id; }
bool operator==(FeatureIndexValue const & o) const { return m_featureId == o.m_featureId; }
bool operator==(FeatureIndexValue const & o) const { return m_id == o.m_id; }
void Swap(FeatureIndexValue & o) { swap(m_featureId, o.m_featureId); }
void Swap(FeatureIndexValue & o) { swap(m_id, o.m_id); }
uint64_t m_featureId;
uint64_t m_id;
};
namespace std
{
// Hashes a FeatureIndexValue by its id only, the same field that
// FeatureIndexValue::operator== compares.
template <>
struct hash<FeatureIndexValue>
{
  size_t operator()(FeatureIndexValue const & value) const
  {
    std::hash<uint64_t> hasher;
    return hasher(value.m_id);
  }
};
}  // namespace std
struct FeatureWithRankAndCenter
{
FeatureWithRankAndCenter() : m_pt(m2::PointD()), m_featureId(0), m_rank(0) {}
FeatureWithRankAndCenter() : m_pt(m2::PointD()), m_id(0), m_rank(0) {}
FeatureWithRankAndCenter(m2::PointD pt, uint32_t featureId, uint8_t rank)
: m_pt(pt), m_featureId(featureId), m_rank(rank)
FeatureWithRankAndCenter(m2::PointD pt, uint32_t id, uint8_t rank)
: m_pt(pt), m_id(id), m_rank(rank)
{
}
bool operator<(FeatureWithRankAndCenter const & o) const { return m_featureId < o.m_featureId; }
bool operator<(FeatureWithRankAndCenter const & o) const { return m_id < o.m_id; }
bool operator==(FeatureWithRankAndCenter const & o) const { return m_featureId == o.m_featureId; }
bool operator==(FeatureWithRankAndCenter const & o) const { return m_id == o.m_id; }
void Swap(FeatureWithRankAndCenter & o)
{
swap(m_pt, o.m_pt);
swap(m_featureId, o.m_featureId);
swap(m_id, o.m_id);
swap(m_rank, o.m_rank);
}
m2::PointD m_pt; // Center point of the feature.
uint32_t m_featureId; // Feature identifier.
uint8_t m_rank; // Rank of the feature.
m2::PointD m_pt; // Center point of the feature.
uint32_t m_id; // Feature identifier.
uint8_t m_rank; // Rank of the feature.
};
namespace std
{
// Hashes a FeatureWithRankAndCenter by its id only, the same field that
// FeatureWithRankAndCenter::operator== compares; the center point and
// the rank do not take part in the hash.
template <>
struct hash<FeatureWithRankAndCenter>
{
  size_t operator()(FeatureWithRankAndCenter const & value) const
  {
    // The 32-bit id is widened and hashed as a 64-bit integer.
    std::hash<uint64_t> hasher;
    return hasher(value.m_id);
  }
};
}  // namespace std
template <typename Value>
class SingleValueSerializer;
@ -78,7 +104,7 @@ public:
void Serialize(Sink & sink, Value const & v) const
{
serial::SavePoint(sink, v.m_pt, m_codingParams);
WriteToSink(sink, v.m_featureId);
WriteToSink(sink, v.m_id);
WriteToSink(sink, v.m_rank);
}
@ -93,7 +119,7 @@ public:
void DeserializeFromSource(Source & source, Value & v) const
{
v.m_pt = serial::LoadPoint(source, m_codingParams);
v.m_featureId = ReadPrimitiveFromSource<uint32_t>(source);
v.m_id = ReadPrimitiveFromSource<uint32_t>(source);
v.m_rank = ReadPrimitiveFromSource<uint8_t>(source);
}
@ -117,7 +143,7 @@ public:
template <typename Sink>
void Serialize(Sink & sink, Value const & v) const
{
WriteToSink(sink, v.m_featureId);
WriteToSink(sink, v.m_id);
}
template <typename Reader>
@ -130,7 +156,7 @@ public:
template <typename Source>
void DeserializeFromSource(Source & source, Value & v) const
{
v.m_featureId = ReadPrimitiveFromSource<uint64_t>(source);
v.m_id = ReadPrimitiveFromSource<uint64_t>(source);
}
};
@ -159,7 +185,7 @@ public:
{
std::vector<uint64_t> ids(values.size());
for (size_t i = 0; i < ids.size(); ++i)
ids[i] = values[i].m_featureId;
ids[i] = values[i].m_id;
m_cbv = coding::CompressedBitVectorBuilder::FromBitPositions(move(ids));
}

View file

@ -15,6 +15,7 @@ set(
locality_finder_test.cpp
locality_scorer_test.cpp
locality_selector_test.cpp
mem_search_index_tests.cpp
point_rect_matcher_tests.cpp
query_saver_tests.cpp
ranking_tests.cpp

View file

@ -0,0 +1,132 @@
#include "testing/testing.hpp"
#include "search/base/mem_search_index.hpp"
#include "search/feature_offset_match.hpp"
#include "indexer/search_string_utils.hpp"
#include "coding/multilang_utf8_string.hpp"
#include "base/string_utils.hpp"
#include "base/uni_string_dfa.hpp"
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <vector>
using namespace search::base;
using namespace search;
using namespace std;
using namespace strings;
// Document identifier used by the tests: a thin wrapper around a 64-bit
// integer that is ordered and compared by that integer.
struct Id
{
  explicit Id(uint64_t id) : m_id(id) {}

  bool operator<(Id const & rhs) const { return m_id < rhs.m_id; }

  bool operator==(Id const & rhs) const { return m_id == rhs.m_id; }

  bool operator!=(Id const & rhs) const { return m_id != rhs.m_id; }

  uint64_t m_id;
};
// Debug representation of an Id: whatever DebugPrint() produces for the
// wrapped integer.
string DebugPrint(Id const & id)
{
  return DebugPrint(id.m_id);
}
template<>
class hash<Id>
{
public:
size_t operator()(Id const & id) const { return std::hash<uint64_t>{}(id.m_id); }
};
class Doc
{
public:
Doc(string const & text, string const & lang) : m_lang(StringUtf8Multilang::GetLangIndex(lang))
{
NormalizeAndTokenizeString(text, m_tokens);
}
template <typename ToDo>
void ForEachToken(ToDo && toDo) const
{
for (auto const & token : m_tokens)
toDo(m_lang, token);
}
private:
vector<strings::UniString> m_tokens;
int8_t m_lang;
};
// Test fixture owning a MemSearchIndex over the test Id and Doc types.
class MemSearchIndexTest
{
public:
  void Add(Id const & id, Doc const & doc) { m_index.Add(id, doc); }

  void Erase(Id const & id, Doc const & doc) { m_index.Erase(id, doc); }

  // Returns the ids that are matched by every token of |query| in the
  // language |lang|. Each token is matched with an exact-match DFA; the
  // per-token results are passed through my::SortUnique() and then
  // intersected. A query without tokens yields an empty result.
  vector<Id> StrictQuery(string const & query, string const & lang) const
  {
    vector<UniString> tokens;
    NormalizeAndTokenizeString(query, tokens);

    vector<Id> result;
    bool isFirstToken = true;
    for (auto const & token : tokens)
    {
      SearchTrieRequest<UniStringDFA> request;
      request.m_names.emplace_back(token);
      request.m_langs.Insert(StringUtf8Multilang::GetLangIndex(lang));

      vector<Id> matched;
      MatchFeaturesInTrie(request, m_index.GetRootIterator(),
                          [](uint64_t /* id */) { return true; },
                          [&matched](Id const & id) { matched.push_back(id); });
      my::SortUnique(matched);

      if (isFirstToken)
      {
        // Nothing to intersect with yet: the first token's matches seed
        // the result.
        isFirstToken = false;
        result.swap(matched);
        continue;
      }

      vector<Id> common;
      set_intersection(result.begin(), result.end(), matched.begin(), matched.end(),
                       back_inserter(common));
      result.swap(common);
    }
    return result;
  }

protected:
  MemSearchIndex<Id, Doc> m_index;
};
// Smoke test: indexes two short documents, checks single- and
// multi-token strict queries, then erases the documents one by one and
// checks that they are no longer found.
UNIT_CLASS_TEST(MemSearchIndexTest, Smoke)
{
  Id const hamletId(31337);
  Id const macbethId(600613);

  Doc const hamletDoc("To be or not to be: that is the question...", "en");
  Doc const macbethDoc("When shall we three meet again? In thunder, lightning, or in rain? ...", "en");

  Add(hamletId, hamletDoc);
  Add(macbethId, macbethDoc);

  // Both documents are indexed.
  TEST_EQUAL(StrictQuery("Thunder", "en"), vector<Id>({macbethId}), ());
  TEST_EQUAL(StrictQuery("Question", "en"), vector<Id>({hamletId}), ());
  TEST_EQUAL(StrictQuery("or", "en"), vector<Id>({hamletId, macbethId}), ());
  TEST_EQUAL(StrictQuery("thunder lightning rain", "en"), vector<Id>({macbethId}), ());

  // Only the first document remains.
  Erase(macbethId, macbethDoc);
  TEST_EQUAL(StrictQuery("Thunder", "en"), vector<Id>{}, ());
  TEST_EQUAL(StrictQuery("to be or not to be", "en"), vector<Id>({hamletId}), ());

  // The index is empty again.
  Erase(hamletId, hamletDoc);
  TEST_EQUAL(StrictQuery("question", "en"), vector<Id>{}, ());
}