[search] Honest IDF calculations.

This commit is contained in:
Yuri Gorshenin 2017-12-14 17:47:13 +03:00 committed by mpimenov
parent 16d172f2f5
commit 0e63410a65
9 changed files with 174 additions and 64 deletions

View file

@ -38,7 +38,7 @@ struct TokenWeightPair
}
// Returns squared weight of the token-weight pair.
double Sqr() const { return m_weight * m_weight; }
double SqrWeight() const { return m_weight * m_weight; }
Token m_token;
double m_weight = 0;
@ -54,23 +54,30 @@ std::string DebugPrint(TokenWeightPair<Token> const & tw)
namespace impl
{
// Accumulates weights of equal tokens in |tws|. Result is sorted by tokens.
// Accumulates weights of equal tokens in |tws|. Result is sorted by
// tokens. Also, maximum weight from a group of equal tokens will be
// stored in the corresponding |maxWeights| element.
template <typename Token>
void SortAndMerge(std::vector<TokenWeightPair<Token>> & tws)
void SortAndMerge(std::vector<TokenWeightPair<Token>> & tws, std::vector<double> & maxWeights)
{
std::sort(tws.begin(), tws.end());
size_t n = 0;
maxWeights.clear();
for (size_t i = 0; i < tws.size(); ++i)
{
ASSERT_LESS_OR_EQUAL(n, i, ());
ASSERT_EQUAL(n, maxWeights.size(), ());
if (n == 0 || tws[n - 1].m_token != tws[i].m_token)
{
tws[n].Swap(tws[i]);
maxWeights.push_back(tws[n].m_weight);
++n;
}
else
{
tws[n - 1].m_weight += tws[i].m_weight;
maxWeights[n - 1] = std::max(maxWeights[n - 1], tws[i].m_weight);
}
}
@ -84,7 +91,7 @@ double SqrL2(std::vector<TokenWeightPair<Token>> const & tws)
{
double sum = 0;
for (auto const & tw : tws)
sum += tw.Sqr();
sum += tw.SqrWeight();
return sum;
}
@ -94,7 +101,7 @@ double SqrL2(std::vector<TokenWeightPair<Token>> const & tws,
boost::optional<TokenWeightPair<Token>> const & prefix)
{
double result = SqrL2(tws);
return result + (prefix ? prefix->Sqr() : 0);
return result + (prefix ? prefix->SqrWeight() : 0);
}
} // namespace impl
@ -126,6 +133,7 @@ public:
explicit DocVec(Builder const & builder) : m_tws(builder.m_tws) { Init(); }
TokenWeightPairs const & GetTokenWeightPairs() const { return m_tws; }
std::vector<double> const & GetMaxWeights() const { return m_maxWeights; }
bool Empty() const { return m_tws.empty(); }
@ -136,9 +144,10 @@ private:
return "DocVec " + DebugPrint(dv.m_tws);
}
void Init() { impl::SortAndMerge(m_tws); }
void Init() { impl::SortAndMerge(m_tws, m_maxWeights); }
TokenWeightPairs m_tws;
std::vector<double> m_maxWeights;
};
// This class represents a search query in a vector space of tokens.
@ -197,6 +206,7 @@ public:
auto const & ls = m_tws;
auto const & rs = rhs.GetTokenWeightPairs();
auto const & maxWeights = rhs.GetMaxWeights();
ASSERT(std::is_sorted(ls.begin(), ls.end()), ());
ASSERT(std::is_sorted(rs.begin(), rs.end()), ());
@ -258,29 +268,22 @@ public:
// query, we need to update its weight in the cosine distance
// - so we need to update correspondingly dot product and
// vector norms of query and doc.
auto const w = maxWeights[j];
auto const l = std::max(0.0, ln - prefix.SqrWeight() + w * w);
// This is the hacky moment: weight of query prefix token may
// be greater than the weight of the corresponding document
// token, because the weight of the document token may be
// unknown at the moment, and be set to some default value.
// But this heuristic works nicely in practice.
double const w = std::max(prefix.m_weight, tw.m_weight);
auto const sqrW = w * w;
double const l = std::max(0.0, ln - prefix.Sqr() + sqrW);
double const r = std::max(0.0, rn - tw.Sqr() + sqrW);
nom = dot + sqrW;
denom = sqrt(l) * sqrt(r);
nom = dot + w * tw.m_weight;
denom = sqrt(l) * sqrt(rn);
}
else
{
// If this document token is already matched with |i|-th full
// token in a query - we here that completion of the prefix
// token is the |i|-th token. So we need to update
// token in a query - we know that completion of the prefix
// token is the |i|-th query token. So we need to update
// correspondingly dot product and vector norm of the query.
double const l = ln + 2 * ls[i].m_weight * prefix.m_weight;
auto const w = ls[i].m_weight + m_maxWeights[i];
auto const l = ln - ls[i].SqrWeight() - prefix.SqrWeight() + w * w;
nom = dot + prefix.m_weight * tw.m_weight;
nom = dot + (w - ls[i].m_weight) * tw.m_weight;
denom = sqrt(l) * sqrt(rn);
}
@ -310,9 +313,10 @@ private:
return "QueryVec " + DebugPrint(qv.m_tws);
}
void Init() { impl::SortAndMerge(m_tws); }
void Init() { impl::SortAndMerge(m_tws, m_maxWeights); }
std::vector<TokenWeightPair> m_tws;
std::vector<double> m_maxWeights;
boost::optional<TokenWeightPair> m_prefix;
};
} // namespace search

View file

@ -157,8 +157,13 @@ private:
class LocalityScorerDelegate : public LocalityScorer::Delegate
{
public:
LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params)
: m_context(context), m_params(params), m_ranks(m_context.m_value)
LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params,
my::Cancellable const & cancellable)
: m_context(context)
, m_params(params)
, m_cancellable(cancellable)
, m_retrieval(m_context, m_cancellable)
, m_ranks(m_context.m_value)
{
}
@ -178,9 +183,22 @@ public:
uint8_t GetRank(uint32_t featureId) const override { return m_ranks.Get(featureId); }
// LocalityScorer::Delegate override.
// Builds a SearchTrieRequest over UniStringDFA, adds |token| to its
// names, sets the request languages from |m_params| and wraps the
// result of Retrieval::RetrieveAddressFeatures() in a CBV.
CBV GetMatchedFeatures(strings::UniString const & token) const override
{
SearchTrieRequest<strings::UniStringDFA> request;
request.m_names.emplace_back(token);
request.SetLangs(m_params.GetLangs());
// RetrieveAddressFeatures() returns a unique_ptr to a compressed bit
// vector; the CBV is constructed directly from it.
return CBV{m_retrieval.RetrieveAddressFeatures(request)};
}
private:
MwmContext const & m_context;
Geocoder::Params const & m_params;
my::Cancellable const & m_cancellable;
Retrieval m_retrieval;
LazyRankTable m_ranks;
};
@ -624,7 +642,7 @@ void Geocoder::FillLocalityCandidates(BaseContext const & ctx, CBV const & filte
return;
}
LocalityScorerDelegate delegate(*m_context, m_params);
LocalityScorerDelegate delegate(*m_context, m_params, m_cancellable);
LocalityScorer scorer(m_params, delegate);
scorer.GetTopLocalities(m_context->GetId(), ctx, filter, maxNumLocalities, preLocalities);
}

View file

@ -2,11 +2,21 @@
namespace search
{
IdfMap::IdfMap(double unknownIdf): m_unknownIdf(unknownIdf) {}
IdfMap::IdfMap(Delegate & delegate, double unknownIdf)
: m_delegate(delegate), m_unknownIdf(unknownIdf)
{
}
double IdfMap::Get(strings::UniString const & s) const
double IdfMap::Get(strings::UniString const & s)
{
auto const it = m_idfs.find(s);
return it == m_idfs.cend() ? m_unknownIdf : it->second;
if (it != m_idfs.cend())
return it->second;
auto const df = static_cast<double>(m_delegate.GetNumDocs(s));
auto const idf = df == 0 ? m_unknownIdf : 1.0 / df;
m_idfs[s] = idf;
return idf;
}
} // namespace search

View file

@ -2,6 +2,7 @@
#include "base/string_utils.hpp"
#include <cstdint>
#include <map>
namespace search
@ -9,13 +10,22 @@ namespace search
class IdfMap
{
public:
explicit IdfMap(double unknownIdf);
struct Delegate
{
virtual ~Delegate() = default;
virtual uint64_t GetNumDocs(strings::UniString const & token) const = 0;
};
IdfMap(Delegate & delegate, double unknownIdf);
void Set(strings::UniString const & s, double idf) { m_idfs[s] = idf; }
double Get(strings::UniString const & s) const;
double Get(strings::UniString const & s);
private:
std::map<strings::UniString, double> m_idfs;
Delegate & m_delegate;
double m_unknownIdf;
};
} // namespace search

View file

@ -19,6 +19,27 @@ using namespace std;
namespace search
{
namespace
{
// Adapter that answers IdfMap's document-frequency queries: the
// number of docs for a token is the number of features that are
// matched by the token (as reported by the locality scorer delegate)
// and are also present in the filter.
struct IdfMapDelegate : public IdfMap::Delegate
{
  IdfMapDelegate(LocalityScorer::Delegate const & delegate, CBV const & filter)
    : m_delegate(delegate)
    , m_filter(filter)
  {
  }

  ~IdfMapDelegate() override = default;

  // IdfMap::Delegate overrides:
  uint64_t GetNumDocs(strings::UniString const & token) const override
  {
    auto const matched = m_delegate.GetMatchedFeatures(token);
    return m_filter.Intersect(matched).PopCount();
  }

  // Both members are stored as references, so the referenced objects
  // must stay alive for as long as this delegate is used.
  LocalityScorer::Delegate const & m_delegate;
  CBV const & m_filter;
};
} // namespace
// static
size_t const LocalityScorer::kDefaultReadLimit = 100;
@ -46,12 +67,25 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
for (size_t i = 0; i < ctx.m_numTokens; ++i)
intersections[i] = filter.Intersect(ctx.m_features[i]);
IdfMap idfs(1.0 /* unknownIdf */);
IdfMapDelegate delegate(m_delegate, filter);
IdfMap idfs(delegate, 1.0 /* unknownIdf */);
double prefixIdf = 1.0;
for (size_t i = 0; i < ctx.m_numTokens; ++i)
{
auto const idf = 1.0 / static_cast<double>(intersections[i].PopCount());
// IDF should be the same for the token and its synonyms.
m_params.GetToken(i).ForEach([&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); });
auto const numDocs = intersections[i].PopCount();
double idf = 1.0;
if (numDocs > 0)
idf = 1.0 / static_cast<double>(numDocs);
if (m_params.IsPrefixToken(i))
{
prefixIdf = idf;
}
else
{
m_params.GetToken(i).ForEach(
[&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); });
}
}
for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken)
@ -64,11 +98,10 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
{
auto const curToken = endToken - 1;
auto const & token = m_params.GetToken(curToken).m_original;
double const weight = idfs.Get(token);
if (m_params.IsPrefixToken(curToken))
builder.SetPrefix(token, weight);
builder.SetPrefix(token, prefixIdf);
else
builder.AddFull(token, weight);
builder.AddFull(token, idfs.Get(token));
TokenRange const tokenRange(startToken, endToken);
// Skip locality candidates that match only numbers.
@ -88,7 +121,7 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
LeaveTopLocalities(idfs, limit, localities);
}
void LocalityScorer::LeaveTopLocalities(IdfMap const & idfs, size_t limit,
void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit,
vector<Locality> & localities) const
{
vector<ExLocality> els;
@ -179,8 +212,7 @@ void LocalityScorer::LeaveTopBySimilarityAndRank(size_t limit, vector<ExLocality
els.resize(n);
}
void LocalityScorer::GetDocVecs(IdfMap const & idfs, uint32_t localityId,
vector<DocVec> & dvs) const
void LocalityScorer::GetDocVecs(IdfMap & idfs, uint32_t localityId, vector<DocVec> & dvs) const
{
vector<string> names;
m_delegate.GetNames(localityId, names);

View file

@ -1,5 +1,6 @@
#pragma once
#include "search/cbv.hpp"
#include "search/geocoder_locality.hpp"
#include "search/ranking_utils.hpp"
@ -29,6 +30,7 @@ public:
virtual void GetNames(uint32_t featureId, std::vector<std::string> & names) const = 0;
virtual uint8_t GetRank(uint32_t featureId) const = 0;
virtual CBV GetMatchedFeatures(strings::UniString const & token) const = 0;
};
LocalityScorer(QueryParams const & params, Delegate const & delegate);
@ -59,8 +61,7 @@ private:
// Leaves at most |limit| elements of |localities|, ordered by some
// combination of ranks and number of matched tokens.
void LeaveTopLocalities(IdfMap const & idfs, size_t limit,
std::vector<Locality> & localities) const;
void LeaveTopLocalities(IdfMap & idfs, size_t limit, std::vector<Locality> & localities) const;
// Selects at most |limitUniqueIds| best features by query norm and
// rank, and then leaves only localities corresponding to those
@ -71,7 +72,7 @@ private:
// and rank. Result doesn't contain duplicate features.
void LeaveTopBySimilarityAndRank(size_t limit, std::vector<ExLocality> & els) const;
void GetDocVecs(IdfMap const & idfs, uint32_t localityId, vector<DocVec> & dvs) const;
void GetDocVecs(IdfMap & idfs, uint32_t localityId, std::vector<DocVec> & dvs) const;
double GetSimilarity(QueryVec const & qv, std::vector<DocVec> const & dvs) const;
QueryParams const & m_params;

View file

@ -319,43 +319,43 @@ Retrieval::Retrieval(MwmContext const & context, my::Cancellable const & cancell
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
SearchTrieRequest<UniStringDFA> const & request)
SearchTrieRequest<UniStringDFA> const & request) const
{
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
SearchTrieRequest<PrefixDFAModifier<UniStringDFA>> const & request)
SearchTrieRequest<PrefixDFAModifier<UniStringDFA>> const & request) const
{
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
SearchTrieRequest<LevenshteinDFA> const & request)
SearchTrieRequest<LevenshteinDFA> const & request) const
{
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
SearchTrieRequest<PrefixDFAModifier<LevenshteinDFA>> const & request)
SearchTrieRequest<PrefixDFAModifier<LevenshteinDFA>> const & request) const
{
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrievePostcodeFeatures(
TokenSlice const & slice)
TokenSlice const & slice) const
{
return Retrieve<RetrievePostcodeFeaturesAdaptor>(slice);
}
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveGeometryFeatures(m2::RectD const & rect,
int scale)
int scale) const
{
return RetrieveGeometryFeaturesImpl(m_context, m_cancellable, rect, scale);
}
template <template <typename> class R, typename... Args>
unique_ptr<coding::CompressedBitVector> Retrieval::Retrieve(Args &&... args)
unique_ptr<coding::CompressedBitVector> Retrieval::Retrieve(Args &&... args) const
{
switch (m_format)
{

View file

@ -38,28 +38,28 @@ public:
// Following functions retrieve from the search index corresponding to
// |value| all features matching to |request|.
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
SearchTrieRequest<strings::UniStringDFA> const & request);
SearchTrieRequest<strings::UniStringDFA> const & request) const;
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
SearchTrieRequest<strings::PrefixDFAModifier<strings::UniStringDFA>> const & request);
SearchTrieRequest<strings::PrefixDFAModifier<strings::UniStringDFA>> const & request) const;
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
SearchTrieRequest<strings::LevenshteinDFA> const & request);
SearchTrieRequest<strings::LevenshteinDFA> const & request) const;
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
SearchTrieRequest<strings::PrefixDFAModifier<strings::LevenshteinDFA>> const & request);
SearchTrieRequest<strings::PrefixDFAModifier<strings::LevenshteinDFA>> const & request) const;
// Retrieves from the search index corresponding to |value| all
// postcodes matching to |slice|.
unique_ptr<coding::CompressedBitVector> RetrievePostcodeFeatures(TokenSlice const & slice);
unique_ptr<coding::CompressedBitVector> RetrievePostcodeFeatures(TokenSlice const & slice) const;
// Retrieves from the geometry index corresponding to |value| all features belonging to |rect|.
unique_ptr<coding::CompressedBitVector> RetrieveGeometryFeatures(m2::RectD const & rect,
int scale);
int scale) const;
private:
template <template <typename> class R, typename... Args>
unique_ptr<coding::CompressedBitVector> Retrieve(Args &&... args);
unique_ptr<coding::CompressedBitVector> Retrieve(Args &&... args) const;
MwmContext const & m_context;
my::Cancellable const & m_cancellable;

View file

@ -15,12 +15,13 @@
#include "base/stl_helpers.hpp"
#include "base/string_utils.hpp"
#include "std/algorithm.hpp"
#include "std/set.hpp"
#include "std/unordered_map.hpp"
#include "std/vector.hpp"
#include <algorithm>
#include <set>
#include <unordered_map>
#include <vector>
using namespace search;
using namespace std;
using namespace strings;
namespace
@ -51,7 +52,7 @@ public:
}
}
void AddLocality(string const & name, uint32_t featureId)
void AddLocality(string const & name, uint32_t featureId, uint8_t rank = 0)
{
set<UniString> tokens;
Delimiters delims;
@ -61,6 +62,7 @@ public:
m_searchIndex.Add(token, featureId);
m_names[featureId].push_back(name);
m_ranks[featureId] = rank;
}
Ids GetTopLocalities(size_t limit)
@ -114,11 +116,24 @@ public:
names.insert(names.end(), it->second.begin(), it->second.end());
}
uint8_t GetRank(uint32_t /* featureId */) const override { return 0; }
// Returns the rank stored in |m_ranks| for |featureId|, or zero when
// no rank is stored for this feature.
uint8_t GetRank(uint32_t featureId) const override
{
  auto const rankIt = m_ranks.find(featureId);
  if (rankIt == m_ranks.end())
    return 0;
  return rankIt->second;
}
// Gathers the feature ids that |m_searchIndex| reports via
// ForEachInNode() for |token|, sorts and deduplicates them, and
// packs them into a CBV.
CBV GetMatchedFeatures(strings::UniString const & token) const override
{
  vector<uint64_t> positions;
  auto const collect = [&positions](uint32_t featureId) { positions.push_back(featureId); };
  m_searchIndex.ForEachInNode(token, collect);

  // Sort and deduplicate the ids before building the bit vector.
  my::SortUnique(positions);
  return CBV{coding::CompressedBitVectorBuilder::FromBitPositions(move(positions))};
}
protected:
QueryParams m_params;
unordered_map<uint32_t, vector<string>> m_names;
unordered_map<uint32_t, uint8_t> m_ranks;
LocalityScorer m_scorer;
base::MemTrie<UniString, base::VectorValues<uint32_t>> m_searchIndex;
@ -207,3 +222,23 @@ UNIT_CLASS_TEST(LocalityScorerTest, PrefixMatch)
TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_SAN_ANTONIO, ID_NEW_YORK}), ());
TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_SAN_ANTONIO}), ());
}
// Checks GetTopLocalities() for three localities whose names all
// contain the token "San" but whose ranks differ, using a full token
// query and progressively smaller limits.
UNIT_CLASS_TEST(LocalityScorerTest, Ranks)
{
enum
{
ID_SAN_MARINO,
ID_SAN_ANTONIO,
ID_SAN_FRANCISCO
};
AddLocality("San Marino", ID_SAN_MARINO, 10 /* rank */);
AddLocality("Citta di San Antonio", ID_SAN_ANTONIO, 20 /* rank */);
AddLocality("San Francisco", ID_SAN_FRANCISCO, 30 /* rank */);
InitParams("San", false /* lastTokenIsPrefix */);
// With a generous limit all three localities are returned.
TEST_EQUAL(GetTopLocalities(100 /* limit */),
Ids({ID_SAN_MARINO, ID_SAN_ANTONIO, ID_SAN_FRANCISCO}), ());
// As the limit shrinks, San Antonio is dropped first and then San
// Marino, leaving San Francisco, which has the highest rank.
// NOTE(review): presumably San Antonio goes first despite its rank
// because its longer name is less similar to the query - confirm
// against the similarity/rank selection in LocalityScorer.
TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_SAN_MARINO, ID_SAN_FRANCISCO}), ());
TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_SAN_FRANCISCO}), ());
}