[search] Pre-localities selection is moved to LocalityScorer.

This commit is contained in:
Yuri Gorshenin 2017-02-01 13:31:06 +03:00
parent 5017b013e9
commit ef5e5bace7
8 changed files with 230 additions and 196 deletions

View file

@ -146,5 +146,5 @@ private:
size_t m_numNodes = 1;
DISALLOW_COPY(MemTrie);
}; // class MemTrie
};
} // namespace my

View file

@ -641,48 +641,9 @@ void Geocoder::FillLocalityCandidates(BaseContext const & ctx, CBV const & filte
size_t const maxNumLocalities,
vector<Locality> & preLocalities)
{
preLocalities.clear();
for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken)
{
CBV intersection = filter.Intersect(ctx.m_features[startToken]);
if (intersection.IsEmpty())
continue;
CBV unfilteredIntersection = ctx.m_features[startToken];
for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens; ++endToken)
{
// Skip locality candidates that match only numbers.
if (!m_params.IsNumberTokens(startToken, endToken))
{
intersection.ForEach([&](uint32_t featureId)
{
Locality l;
l.m_countryId = m_context->GetId();
l.m_featureId = featureId;
l.m_startToken = startToken;
l.m_endToken = endToken;
l.m_prob = static_cast<double>(intersection.PopCount()) /
static_cast<double>(unfilteredIntersection.PopCount());
preLocalities.push_back(l);
});
}
if (endToken < ctx.m_numTokens)
{
intersection = intersection.Intersect(ctx.m_features[endToken]);
if (intersection.IsEmpty())
break;
unfilteredIntersection = unfilteredIntersection.Intersect(ctx.m_features[endToken]);
}
}
}
LocalityScorerDelegate delegate(*m_context, m_params);
LocalityScorer scorer(m_params, delegate);
scorer.GetTopLocalities(maxNumLocalities, preLocalities);
scorer.GetTopLocalities(m_context->GetId(), ctx, filter, maxNumLocalities, preLocalities);
}
void Geocoder::FillLocalitiesTable(BaseContext const & ctx)
@ -745,7 +706,7 @@ void Geocoder::FillLocalitiesTable(BaseContext const & ctx)
#if defined(DEBUG)
ft.GetName(StringUtf8Multilang::kDefaultCode, city.m_defaultName);
LOG(LDEBUG, ("City =", city.m_defaultName, radius));
LOG(LINFO, ("City =", city.m_defaultName, "radius =", radius, "prob = ", city.m_prob));
#endif
m_cities[{l.m_startToken, l.m_endToken}].push_back(city);
@ -794,7 +755,7 @@ void Geocoder::FillVillageLocalities(BaseContext const & ctx)
#if defined(DEBUG)
ft.GetName(StringUtf8Multilang::kDefaultCode, village.m_defaultName);
LOG(LDEBUG, ("Village =", village.m_defaultName, radius, "prob =", village.m_prob));
LOG(LDEBUG, ("Village =", village.m_defaultName, "radius =", radius, "prob =", village.m_prob));
#endif
m_cities[{l.m_startToken, l.m_endToken}].push_back(village);
@ -1421,8 +1382,12 @@ bool Geocoder::GetSearchTypeInGeocoding(BaseContext const & ctx, uint32_t featur
string DebugPrint(Geocoder::Locality const & locality)
{
ostringstream os;
os << "Locality [" << DebugPrint(locality.m_countryId) << ", featureId=" << locality.m_featureId
<< ", startToken=" << locality.m_startToken << ", endToken=" << locality.m_endToken << "]";
os << "Locality [ ";
os << "m_countryId=" << DebugPrint(locality.m_countryId) << ", ";
os << "m_featureId=" << locality.m_featureId << ", ";
os << "token range=[" << locality.m_startToken << ", " << locality.m_endToken << "), ";
os << "m_prob=" << locality.m_prob;
os << " ]";
return os.str();
}
} // namespace search

View file

@ -94,23 +94,27 @@ public:
struct Locality
{
Locality() : m_featureId(0), m_startToken(0), m_endToken(0), m_prob(0.0) {}
Locality() = default;
Locality(uint32_t featureId, size_t startToken, size_t endToken)
: m_featureId(featureId), m_startToken(startToken), m_endToken(endToken), m_prob(0.0)
Locality(MwmSet::MwmId const & countryId, uint32_t featureId, size_t startToken,
size_t endToken, double prob)
: m_countryId(countryId)
, m_featureId(featureId)
, m_startToken(startToken)
, m_endToken(endToken)
, m_prob(prob)
{
}
MwmSet::MwmId m_countryId;
uint32_t m_featureId;
size_t m_startToken;
size_t m_endToken;
uint32_t m_featureId = 0;
size_t m_startToken = 0;
size_t m_endToken = 0;
// Measures our belief in the fact that tokens in the range [m_startToken, m_endToken)
// indeed specify a locality. Currently it is set only for villages.
double m_prob;
string m_name;
// Measures our belief in the fact that tokens in the range
// [m_startToken, m_endToken) indeed specify a locality. Currently
// it is set only for villages.
double m_prob = 0.0;
};
// This struct represents a country or US- or Canadian- state. It
@ -193,8 +197,9 @@ private:
void InitLayer(SearchModel::SearchType type, size_t startToken, size_t endToken,
FeaturesLayer & layer);
void FillLocalityCandidates(BaseContext const & ctx, CBV const & filter,
size_t const maxNumLocalities, vector<Locality> & preLocalities);
void FillLocalityCandidates(BaseContext const & ctx,
CBV const & filter, size_t const maxNumLocalities,
vector<Locality> & preLocalities);
void FillLocalitiesTable(BaseContext const & ctx);

View file

@ -1,8 +1,12 @@
#include "search/locality_scorer.hpp"
#include "search/cbv.hpp"
#include "search/geocoder_context.hpp"
#include "search/token_slice.hpp"
#include "std/algorithm.hpp"
#include <algorithm>
#include <sstream>
#include <unordered_set>
namespace search
{
@ -36,9 +40,52 @@ LocalityScorer::LocalityScorer(QueryParams const & params, Delegate const & dele
{
}
void LocalityScorer::GetTopLocalities(size_t limit, vector<Geocoder::Locality> & localities) const
void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseContext const & ctx,
CBV const & filter, size_t limit,
std::vector<Geocoder::Locality> & localities)
{
vector<ExLocality> ls;
CHECK_EQUAL(ctx.m_numTokens, m_params.GetNumTokens(), ());
localities.clear();
for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken)
{
CBV intersection = filter.Intersect(ctx.m_features[startToken]);
if (intersection.IsEmpty())
continue;
CBV unfilteredIntersection = ctx.m_features[startToken];
for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens; ++endToken)
{
// Skip locality candidates that match only numbers.
if (!m_params.IsNumberTokens(startToken, endToken))
{
intersection.ForEach([&](uint32_t featureId) {
double const prob = static_cast<double>(intersection.PopCount()) /
static_cast<double>(unfilteredIntersection.PopCount());
localities.emplace_back(countryId, featureId, startToken, endToken, prob);
});
}
if (endToken < ctx.m_numTokens)
{
intersection = intersection.Intersect(ctx.m_features[endToken]);
if (intersection.IsEmpty())
break;
unfilteredIntersection = unfilteredIntersection.Intersect(ctx.m_features[endToken]);
}
}
}
LeaveTopLocalities(limit, localities);
}
void LocalityScorer::LeaveTopLocalities(size_t limit,
std::vector<Geocoder::Locality> & localities) const
{
std::vector<ExLocality> ls;
ls.reserve(localities.size());
for (auto const & locality : localities)
ls.emplace_back(locality);
@ -55,23 +102,21 @@ void LocalityScorer::GetTopLocalities(size_t limit, vector<Geocoder::Locality> &
localities.push_back(l.m_locality);
}
void LocalityScorer::RemoveDuplicates(vector<ExLocality> & ls) const
void LocalityScorer::RemoveDuplicates(std::vector<ExLocality> & ls) const
{
sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs)
{
if (lhs.GetId() != rhs.GetId())
return lhs.GetId() < rhs.GetId();
return lhs.m_numTokens > rhs.m_numTokens;
});
ls.erase(unique(ls.begin(), ls.end(),
[](ExLocality const & lhs, ExLocality const & rhs)
{
return lhs.GetId() == rhs.GetId();
}),
std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) {
if (lhs.GetId() != rhs.GetId())
return lhs.GetId() < rhs.GetId();
return lhs.m_numTokens > rhs.m_numTokens;
});
ls.erase(std::unique(ls.begin(), ls.end(),
[](ExLocality const & lhs, ExLocality const & rhs) {
return lhs.GetId() == rhs.GetId();
}),
ls.end());
}
void LocalityScorer::LeaveTopByRankAndProb(size_t limit, vector<ExLocality> & ls) const
void LocalityScorer::LeaveTopByRankAndProb(size_t limit, std::vector<ExLocality> & ls) const
{
if (ls.size() <= limit)
return;
@ -79,20 +124,19 @@ void LocalityScorer::LeaveTopByRankAndProb(size_t limit, vector<ExLocality> & ls
for (auto & l : ls)
l.m_rank = m_delegate.GetRank(l.GetId());
sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs)
{
if (lhs.m_locality.m_prob != rhs.m_locality.m_prob)
return lhs.m_locality.m_prob > rhs.m_locality.m_prob;
if (lhs.m_rank != rhs.m_rank)
return lhs.m_rank > rhs.m_rank;
return lhs.m_numTokens > rhs.m_numTokens;
});
std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) {
if (lhs.m_locality.m_prob != rhs.m_locality.m_prob)
return lhs.m_locality.m_prob > rhs.m_locality.m_prob;
if (lhs.m_rank != rhs.m_rank)
return lhs.m_rank > rhs.m_rank;
return lhs.m_numTokens > rhs.m_numTokens;
});
ls.resize(limit);
}
void LocalityScorer::SortByNameAndProb(vector<ExLocality> & ls) const
void LocalityScorer::SortByNameAndProb(std::vector<ExLocality> & ls) const
{
vector<string> names;
std::vector<std::string> names;
for (auto & l : ls)
{
names.clear();
@ -105,38 +149,48 @@ void LocalityScorer::SortByNameAndProb(vector<ExLocality> & ls) const
l.m_locality.m_endToken)));
}
l.m_nameScore = score;
std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) {
// Probabilities form a stronger signal than name scores do.
if (lhs.m_locality.m_prob != rhs.m_locality.m_prob)
return lhs.m_locality.m_prob > rhs.m_locality.m_prob;
if (IsAlmostFullMatch(lhs.m_nameScore) && IsAlmostFullMatch(rhs.m_nameScore))
{
// When both localities match well, e.g. full or full prefix
// match, the one with larger number of tokens is selected. In
// case of tie, the one with better score is selected.
if (lhs.m_numTokens != rhs.m_numTokens)
return lhs.m_numTokens > rhs.m_numTokens;
if (lhs.m_nameScore != rhs.m_nameScore)
return lhs.m_nameScore > rhs.m_nameScore;
}
else
{
// When name scores differ, the one with better name score is
// selected. In case of tie, the one with larger number of
// matched tokens is selected.
if (lhs.m_nameScore != rhs.m_nameScore)
return lhs.m_nameScore > rhs.m_nameScore;
if (lhs.m_numTokens != rhs.m_numTokens)
return lhs.m_numTokens > rhs.m_numTokens;
}
// Okay, in case of tie we select the one with better rank. This
// is a quite arbitrary decision and definitely may be improved.
return lhs.m_rank > rhs.m_rank;
});
}
sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs)
{
// Probabilities form a stronger signal than name scores do.
if (lhs.m_locality.m_prob != rhs.m_locality.m_prob)
return lhs.m_locality.m_prob > rhs.m_locality.m_prob;
if (IsAlmostFullMatch(lhs.m_nameScore) && IsAlmostFullMatch(rhs.m_nameScore))
{
// When both localities match well, e.g. full or full prefix
// match, the one with larger number of tokens is selected. In
// case of tie, the one with better score is selected.
if (lhs.m_numTokens != rhs.m_numTokens)
return lhs.m_numTokens > rhs.m_numTokens;
if (lhs.m_nameScore != rhs.m_nameScore)
return lhs.m_nameScore > rhs.m_nameScore;
}
else
{
// When name scores differ, the one with better name score is
// selected. In case of tie, the one with larger number of
// matched tokens is selected.
if (lhs.m_nameScore != rhs.m_nameScore)
return lhs.m_nameScore > rhs.m_nameScore;
if (lhs.m_numTokens != rhs.m_numTokens)
return lhs.m_numTokens > rhs.m_numTokens;
}
// Okay, in case of tie we select the one with better rank. This
// is a quite arbitrary decision and definitely may be improved.
return lhs.m_rank > rhs.m_rank;
});
}
string DebugPrint(LocalityScorer::ExLocality const & locality)
{
ostringstream os;
os << "LocalityScorer::ExLocality [ ";
os << "m_locality=" << DebugPrint(locality.m_locality) << ", ";
os << "m_numTokens=" << locality.m_numTokens << ", ";
os << "m_rank=" << static_cast<uint32_t>(locality.m_rank) << ", ";
os << "m_nameScore=" << DebugPrint(locality.m_nameScore);
os << " ]";
return os.str();
}
} // namespace search

View file

@ -3,12 +3,15 @@
#include "search/geocoder.hpp"
#include "search/ranking_utils.hpp"
#include "std/string.hpp"
#include "std/vector.hpp"
#include <cstdint>
#include <string>
#include <vector>
namespace search
{
class CBV;
class QueryParams;
struct BaseContext;
class LocalityScorer
{
@ -20,15 +23,17 @@ public:
public:
virtual ~Delegate() = default;
virtual void GetNames(uint32_t featureId, vector<string> & names) const = 0;
virtual void GetNames(uint32_t featureId, std::vector<std::string> & names) const = 0;
virtual uint8_t GetRank(uint32_t featureId) const = 0;
};
LocalityScorer(QueryParams const & params, Delegate const & delegate);
// Leaves at most |limit| elements of |localities|, ordered by some
// combination of ranks and number of matched tokens.
void GetTopLocalities(size_t limit, vector<Geocoder::Locality> & localities) const;
// Leaves at most |limit| elements of |localities|, ordered by their
// features.
void GetTopLocalities(MwmSet::MwmId const & countryId, BaseContext const & ctx,
CBV const & filter, size_t limit,
std::vector<Geocoder::Locality> & localities);
private:
struct ExLocality
@ -44,9 +49,15 @@ private:
NameScore m_nameScore;
};
void RemoveDuplicates(vector<ExLocality> & ls) const;
void LeaveTopByRankAndProb(size_t limit, vector<ExLocality> & ls) const;
void SortByNameAndProb(vector<ExLocality> & ls) const;
friend std::string DebugPrint(ExLocality const & locality);
// Leaves at most |limit| elements of |localities|, ordered by some
// combination of ranks and number of matched tokens.
void LeaveTopLocalities(size_t limit, std::vector<Geocoder::Locality> & localities) const;
void RemoveDuplicates(std::vector<ExLocality> & ls) const;
void LeaveTopByRankAndProb(size_t limit, std::vector<ExLocality> & ls) const;
void SortByNameAndProb(std::vector<ExLocality> & ls) const;
QueryParams const & m_params;
Delegate const & m_delegate;

View file

@ -86,5 +86,4 @@ NameScore GetNameScore(vector<strings::UniString> const & tokens, TSlice const &
}
string DebugPrint(NameScore score);
} // namespace search

View file

@ -29,17 +29,18 @@ omim_link_libraries(
search
indexer
editor
oauthcpp
storage
platform
opening_hours
geometry
coding
base
protobuf
jansson
succinct
oauthcpp
opening_hours
protobuf
pugixml
stats_client
succinct
${LIBZ}
)

View file

@ -1,11 +1,16 @@
#include "testing/testing.hpp"
#include "search/cbv.hpp"
#include "search/geocoder_context.hpp"
#include "search/locality_scorer.hpp"
#include "indexer/search_delimiters.hpp"
#include "indexer/search_string_utils.hpp"
#include "coding/compressed_bit_vector.hpp"
#include "base/assert.hpp"
#include "base/mem_trie.hpp"
#include "base/stl_add.hpp"
#include "base/stl_helpers.hpp"
#include "base/string_utils.hpp"
@ -20,61 +25,6 @@ using namespace strings;
namespace
{
void InitParams(string const & query, bool lastTokenIsPrefix, QueryParams & params)
{
params.Clear();
vector<UniString> tokens;
Delimiters delims;
SplitUniString(NormalizeAndSimplifyString(query), MakeBackInsertFunctor(tokens), delims);
if (lastTokenIsPrefix)
{
CHECK(!tokens.empty(), ());
params.InitWithPrefix(tokens.begin(), tokens.end() - 1, tokens.back());
}
else
{
params.InitNoPrefix(tokens.begin(), tokens.end());
}
}
void AddLocality(string const & name, uint32_t featureId, QueryParams & params,
vector<Geocoder::Locality> & localities)
{
set<UniString> tokens;
Delimiters delims;
SplitUniString(NormalizeAndSimplifyString(name), MakeInsertFunctor(tokens), delims);
size_t const numTokens = params.GetNumTokens();
for (size_t startToken = 0; startToken != numTokens; ++startToken)
{
for (size_t endToken = startToken + 1; endToken <= numTokens; ++endToken)
{
bool matches = true;
for (size_t i = startToken; i != endToken && matches; ++i)
{
UniString const & queryToken = params.GetToken(i).m_original;
if (params.IsPrefixToken(i))
{
matches = any_of(tokens.begin(), tokens.end(), [&queryToken](UniString const & token)
{
return StartsWith(token, queryToken);
});
}
else
{
matches = (tokens.count(queryToken) != 0);
}
}
if (matches)
localities.emplace_back(featureId, startToken, endToken);
}
}
}
class LocalityScorerTest : public LocalityScorer::Delegate
{
public:
@ -82,18 +32,70 @@ public:
void InitParams(string const & query, bool lastTokenIsPrefix)
{
::InitParams(query, lastTokenIsPrefix, m_params);
m_params.Clear();
vector<UniString> tokens;
Delimiters delims;
SplitUniString(NormalizeAndSimplifyString(query), MakeBackInsertFunctor(tokens), delims);
if (lastTokenIsPrefix)
{
CHECK(!tokens.empty(), ());
m_params.InitWithPrefix(tokens.begin(), tokens.end() - 1, tokens.back());
}
else
{
m_params.InitNoPrefix(tokens.begin(), tokens.end());
}
}
void AddLocality(string const & name, uint32_t featureId)
{
::AddLocality(name, featureId, m_params, m_localities);
set<UniString> tokens;
Delimiters delims;
SplitUniString(NormalizeAndSimplifyString(name), MakeInsertFunctor(tokens), delims);
for (auto const & token : tokens)
m_searchIndex.Add(token, featureId);
m_names[featureId].push_back(name);
}
void GetTopLocalities(size_t limit)
{
m_scorer.GetTopLocalities(limit, m_localities);
BaseContext ctx;
ctx.m_usedTokens.assign(m_params.GetNumTokens(), false);
ctx.m_numTokens = m_params.GetNumTokens();
for (size_t i = 0; i < m_params.GetNumTokens(); ++i)
{
auto const & token = m_params.GetToken(i);
bool const isPrefixToken = m_params.IsPrefixToken(i);
vector<uint64_t> ids;
token.ForEach([&](UniString const & name) {
if (isPrefixToken)
{
m_searchIndex.ForEachInSubtree(name,
[&](UniString const & /* prefix */, uint32_t featureId) {
ids.push_back(featureId);
});
}
else
{
m_searchIndex.ForEachInNode(name, [&](UniString const & /* prefix */,
uint32_t featureId) { ids.push_back(featureId); });
}
});
my::SortUnique(ids);
ctx.m_features.emplace_back(coding::CompressedBitVectorBuilder::FromBitPositions(ids));
}
CBV filter;
filter.SetFull();
m_scorer.GetTopLocalities(MwmSet::MwmId(), ctx, filter, limit, m_localities);
sort(m_localities.begin(), m_localities.end(), my::LessBy(&Geocoder::Locality::m_featureId));
}
@ -112,6 +114,8 @@ protected:
vector<Geocoder::Locality> m_localities;
unordered_map<uint32_t, vector<string>> m_names;
LocalityScorer m_scorer;
my::MemTrie<UniString, uint32_t> m_searchIndex;
};
} // namespace
@ -159,15 +163,12 @@ UNIT_CLASS_TEST(LocalityScorerTest, NumbersMatch)
AddLocality("поселок 1 мая", ID_MAY);
AddLocality("тверь", ID_TVER);
// Tver is the only matched locality as other localities were
// matched only by number.
GetTopLocalities(100 /* limit */);
TEST_EQUAL(4, m_localities.size(), ());
TEST_EQUAL(m_localities[0].m_featureId, ID_MARCH, ());
TEST_EQUAL(m_localities[1].m_featureId, ID_APRIL, ());
TEST_EQUAL(m_localities[2].m_featureId, ID_MAY, ());
TEST_EQUAL(m_localities[3].m_featureId, ID_TVER, ());
TEST_EQUAL(1, m_localities.size(), ());
TEST_EQUAL(m_localities[0].m_featureId, ID_TVER, ());
// Tver is the best matching locality, as other localities were
// matched by number.
GetTopLocalities(1 /* limit */);
TEST_EQUAL(1, m_localities.size(), ());
TEST_EQUAL(m_localities[0].m_featureId, ID_TVER, ());
@ -204,10 +205,8 @@ UNIT_CLASS_TEST(LocalityScorerTest, PrefixMatch)
ID_MOSCOW
};
// QueryParams params;
InitParams("New York San Anto", true /* lastTokenIsPrefix */);
// vector<Geocoder::Locality> localities;
AddLocality("San Antonio", ID_SAN_ANTONIO);
AddLocality("New York", ID_NEW_YORK);
AddLocality("York", ID_YORK);