forked from organicmaps/organicmaps
[search] Honest IDF calculations.
This commit is contained in:
parent
16d172f2f5
commit
0e63410a65
9 changed files with 174 additions and 64 deletions
|
@ -38,7 +38,7 @@ struct TokenWeightPair
|
|||
}
|
||||
|
||||
// Returns squared weight of the token-weight pair.
|
||||
double Sqr() const { return m_weight * m_weight; }
|
||||
double SqrWeight() const { return m_weight * m_weight; }
|
||||
|
||||
Token m_token;
|
||||
double m_weight = 0;
|
||||
|
@ -54,23 +54,30 @@ std::string DebugPrint(TokenWeightPair<Token> const & tw)
|
|||
|
||||
namespace impl
|
||||
{
|
||||
// Accumulates weights of equal tokens in |tws|. Result is sorted by tokens.
|
||||
// Accumulates weights of equal tokens in |tws|. Result is sorted by
|
||||
// tokens. Also, maximum weight from a group of equal tokens will be
|
||||
// stored in the corresponding |maxWeight| elem.
|
||||
template <typename Token>
|
||||
void SortAndMerge(std::vector<TokenWeightPair<Token>> & tws)
|
||||
void SortAndMerge(std::vector<TokenWeightPair<Token>> & tws, std::vector<double> & maxWeights)
|
||||
{
|
||||
std::sort(tws.begin(), tws.end());
|
||||
size_t n = 0;
|
||||
maxWeights.clear();
|
||||
for (size_t i = 0; i < tws.size(); ++i)
|
||||
{
|
||||
ASSERT_LESS_OR_EQUAL(n, i, ());
|
||||
ASSERT_EQUAL(n, maxWeights.size(), ());
|
||||
|
||||
if (n == 0 || tws[n - 1].m_token != tws[i].m_token)
|
||||
{
|
||||
tws[n].Swap(tws[i]);
|
||||
maxWeights.push_back(tws[n].m_weight);
|
||||
++n;
|
||||
}
|
||||
else
|
||||
{
|
||||
tws[n - 1].m_weight += tws[i].m_weight;
|
||||
maxWeights[n - 1] = std::max(maxWeights[n - 1], tws[i].m_weight);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,7 +91,7 @@ double SqrL2(std::vector<TokenWeightPair<Token>> const & tws)
|
|||
{
|
||||
double sum = 0;
|
||||
for (auto const & tw : tws)
|
||||
sum += tw.Sqr();
|
||||
sum += tw.SqrWeight();
|
||||
return sum;
|
||||
}
|
||||
|
||||
|
@ -94,7 +101,7 @@ double SqrL2(std::vector<TokenWeightPair<Token>> const & tws,
|
|||
boost::optional<TokenWeightPair<Token>> const & prefix)
|
||||
{
|
||||
double result = SqrL2(tws);
|
||||
return result + (prefix ? prefix->Sqr() : 0);
|
||||
return result + (prefix ? prefix->SqrWeight() : 0);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
|
@ -126,6 +133,7 @@ public:
|
|||
explicit DocVec(Builder const & builder) : m_tws(builder.m_tws) { Init(); }
|
||||
|
||||
TokenWeightPairs const & GetTokenWeightPairs() const { return m_tws; }
|
||||
std::vector<double> const & GetMaxWeights() const { return m_maxWeights; }
|
||||
|
||||
bool Empty() const { return m_tws.empty(); }
|
||||
|
||||
|
@ -136,9 +144,10 @@ private:
|
|||
return "DocVec " + DebugPrint(dv.m_tws);
|
||||
}
|
||||
|
||||
void Init() { impl::SortAndMerge(m_tws); }
|
||||
void Init() { impl::SortAndMerge(m_tws, m_maxWeights); }
|
||||
|
||||
TokenWeightPairs m_tws;
|
||||
std::vector<double> m_maxWeights;
|
||||
};
|
||||
|
||||
// This class represents a search query in a vector space of tokens.
|
||||
|
@ -197,6 +206,7 @@ public:
|
|||
|
||||
auto const & ls = m_tws;
|
||||
auto const & rs = rhs.GetTokenWeightPairs();
|
||||
auto const & maxWeights = rhs.GetMaxWeights();
|
||||
|
||||
ASSERT(std::is_sorted(ls.begin(), ls.end()), ());
|
||||
ASSERT(std::is_sorted(rs.begin(), rs.end()), ());
|
||||
|
@ -258,29 +268,22 @@ public:
|
|||
// query, we need to update it's weight in the cosine distance
|
||||
// - so we need to update correspondingly dot product and
|
||||
// vector norms of query and doc.
|
||||
auto const w = maxWeights[j];
|
||||
auto const l = std::max(0.0, ln - prefix.SqrWeight() + w * w);
|
||||
|
||||
// This is the hacky moment: weight of query prefix token may
|
||||
// be greater than the weight of the corresponding document
|
||||
// token, because the weight of the document token may be
|
||||
// unknown at the moment, and be set to some default value.
|
||||
// But this heuristic works nicely in practice.
|
||||
double const w = std::max(prefix.m_weight, tw.m_weight);
|
||||
auto const sqrW = w * w;
|
||||
double const l = std::max(0.0, ln - prefix.Sqr() + sqrW);
|
||||
double const r = std::max(0.0, rn - tw.Sqr() + sqrW);
|
||||
|
||||
nom = dot + sqrW;
|
||||
denom = sqrt(l) * sqrt(r);
|
||||
nom = dot + w * tw.m_weight;
|
||||
denom = sqrt(l) * sqrt(rn);
|
||||
}
|
||||
else
|
||||
{
|
||||
// If this document token is already matched with |i|-th full
|
||||
// token in a query - we here that completion of the prefix
|
||||
// token is the |i|-th token. So we need to update
|
||||
// token in a query - we know that completion of the prefix
|
||||
// token is the |i|-th query token. So we need to update
|
||||
// correspondingly dot product and vector norm of the query.
|
||||
double const l = ln + 2 * ls[i].m_weight * prefix.m_weight;
|
||||
auto const w = ls[i].m_weight + m_maxWeights[i];
|
||||
auto const l = ln - ls[i].SqrWeight() - prefix.SqrWeight() + w * w;
|
||||
|
||||
nom = dot + prefix.m_weight * tw.m_weight;
|
||||
nom = dot + (w - ls[i].m_weight) * tw.m_weight;
|
||||
denom = sqrt(l) * sqrt(rn);
|
||||
}
|
||||
|
||||
|
@ -310,9 +313,10 @@ private:
|
|||
return "QueryVec " + DebugPrint(qv.m_tws);
|
||||
}
|
||||
|
||||
void Init() { impl::SortAndMerge(m_tws); }
|
||||
void Init() { impl::SortAndMerge(m_tws, m_maxWeights); }
|
||||
|
||||
std::vector<TokenWeightPair> m_tws;
|
||||
std::vector<double> m_maxWeights;
|
||||
boost::optional<TokenWeightPair> m_prefix;
|
||||
};
|
||||
} // namespace search
|
||||
|
|
|
@ -157,8 +157,13 @@ private:
|
|||
class LocalityScorerDelegate : public LocalityScorer::Delegate
|
||||
{
|
||||
public:
|
||||
LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params)
|
||||
: m_context(context), m_params(params), m_ranks(m_context.m_value)
|
||||
LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params,
|
||||
my::Cancellable const & cancellable)
|
||||
: m_context(context)
|
||||
, m_params(params)
|
||||
, m_cancellable(cancellable)
|
||||
, m_retrieval(m_context, m_cancellable)
|
||||
, m_ranks(m_context.m_value)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -178,9 +183,22 @@ public:
|
|||
|
||||
uint8_t GetRank(uint32_t featureId) const override { return m_ranks.Get(featureId); }
|
||||
|
||||
CBV GetMatchedFeatures(strings::UniString const & token) const override
|
||||
{
|
||||
SearchTrieRequest<strings::UniStringDFA> request;
|
||||
request.m_names.emplace_back(token);
|
||||
request.SetLangs(m_params.GetLangs());
|
||||
|
||||
return CBV{m_retrieval.RetrieveAddressFeatures(request)};
|
||||
}
|
||||
|
||||
private:
|
||||
MwmContext const & m_context;
|
||||
Geocoder::Params const & m_params;
|
||||
my::Cancellable const & m_cancellable;
|
||||
|
||||
Retrieval m_retrieval;
|
||||
|
||||
LazyRankTable m_ranks;
|
||||
};
|
||||
|
||||
|
@ -624,7 +642,7 @@ void Geocoder::FillLocalityCandidates(BaseContext const & ctx, CBV const & filte
|
|||
return;
|
||||
}
|
||||
|
||||
LocalityScorerDelegate delegate(*m_context, m_params);
|
||||
LocalityScorerDelegate delegate(*m_context, m_params, m_cancellable);
|
||||
LocalityScorer scorer(m_params, delegate);
|
||||
scorer.GetTopLocalities(m_context->GetId(), ctx, filter, maxNumLocalities, preLocalities);
|
||||
}
|
||||
|
|
|
@ -2,11 +2,21 @@
|
|||
|
||||
namespace search
|
||||
{
|
||||
IdfMap::IdfMap(double unknownIdf): m_unknownIdf(unknownIdf) {}
|
||||
IdfMap::IdfMap(Delegate & delegate, double unknownIdf)
|
||||
: m_delegate(delegate), m_unknownIdf(unknownIdf)
|
||||
{
|
||||
}
|
||||
|
||||
double IdfMap::Get(strings::UniString const & s) const
|
||||
double IdfMap::Get(strings::UniString const & s)
|
||||
{
|
||||
auto const it = m_idfs.find(s);
|
||||
return it == m_idfs.cend() ? m_unknownIdf : it->second;
|
||||
if (it != m_idfs.cend())
|
||||
return it->second;
|
||||
|
||||
auto const df = static_cast<double>(m_delegate.GetNumDocs(s));
|
||||
auto const idf = df == 0 ? m_unknownIdf : 1.0 / df;
|
||||
m_idfs[s] = idf;
|
||||
|
||||
return idf;
|
||||
}
|
||||
} // namespace search
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "base/string_utils.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
|
||||
namespace search
|
||||
|
@ -9,13 +10,22 @@ namespace search
|
|||
class IdfMap
|
||||
{
|
||||
public:
|
||||
explicit IdfMap(double unknownIdf);
|
||||
struct Delegate
|
||||
{
|
||||
virtual ~Delegate() = default;
|
||||
|
||||
virtual uint64_t GetNumDocs(strings::UniString const & token) const = 0;
|
||||
};
|
||||
|
||||
IdfMap(Delegate & delegate, double unknownIdf);
|
||||
|
||||
void Set(strings::UniString const & s, double idf) { m_idfs[s] = idf; }
|
||||
double Get(strings::UniString const & s) const;
|
||||
double Get(strings::UniString const & s);
|
||||
|
||||
private:
|
||||
std::map<strings::UniString, double> m_idfs;
|
||||
|
||||
Delegate & m_delegate;
|
||||
double m_unknownIdf;
|
||||
};
|
||||
} // namespace search
|
||||
|
|
|
@ -19,6 +19,27 @@ using namespace std;
|
|||
|
||||
namespace search
|
||||
{
|
||||
namespace
|
||||
{
|
||||
struct IdfMapDelegate : public IdfMap::Delegate
|
||||
{
|
||||
IdfMapDelegate(LocalityScorer::Delegate const & delegate, CBV const & filter)
|
||||
: m_delegate(delegate), m_filter(filter)
|
||||
{
|
||||
}
|
||||
|
||||
~IdfMapDelegate() override = default;
|
||||
|
||||
uint64_t GetNumDocs(strings::UniString const & token) const override
|
||||
{
|
||||
return m_filter.Intersect(m_delegate.GetMatchedFeatures(token)).PopCount();
|
||||
}
|
||||
|
||||
LocalityScorer::Delegate const & m_delegate;
|
||||
CBV const & m_filter;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
size_t const LocalityScorer::kDefaultReadLimit = 100;
|
||||
|
||||
|
@ -46,12 +67,25 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
|
|||
for (size_t i = 0; i < ctx.m_numTokens; ++i)
|
||||
intersections[i] = filter.Intersect(ctx.m_features[i]);
|
||||
|
||||
IdfMap idfs(1.0 /* unknownIdf */);
|
||||
IdfMapDelegate delegate(m_delegate, filter);
|
||||
IdfMap idfs(delegate, 1.0 /* unknownIdf */);
|
||||
double prefixIdf = 1.0;
|
||||
for (size_t i = 0; i < ctx.m_numTokens; ++i)
|
||||
{
|
||||
auto const idf = 1.0 / static_cast<double>(intersections[i].PopCount());
|
||||
// IDF should be the same for the token and its synonyms.
|
||||
m_params.GetToken(i).ForEach([&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); });
|
||||
auto const numDocs = intersections[i].PopCount();
|
||||
double idf = 1.0;
|
||||
if (numDocs > 0)
|
||||
idf = 1.0 / static_cast<double>(numDocs);
|
||||
|
||||
if (m_params.IsPrefixToken(i))
|
||||
{
|
||||
prefixIdf = idf;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_params.GetToken(i).ForEach(
|
||||
[&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); });
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken)
|
||||
|
@ -64,11 +98,10 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
|
|||
{
|
||||
auto const curToken = endToken - 1;
|
||||
auto const & token = m_params.GetToken(curToken).m_original;
|
||||
double const weight = idfs.Get(token);
|
||||
if (m_params.IsPrefixToken(curToken))
|
||||
builder.SetPrefix(token, weight);
|
||||
builder.SetPrefix(token, prefixIdf);
|
||||
else
|
||||
builder.AddFull(token, weight);
|
||||
builder.AddFull(token, idfs.Get(token));
|
||||
|
||||
TokenRange const tokenRange(startToken, endToken);
|
||||
// Skip locality candidates that match only numbers.
|
||||
|
@ -88,7 +121,7 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte
|
|||
LeaveTopLocalities(idfs, limit, localities);
|
||||
}
|
||||
|
||||
void LocalityScorer::LeaveTopLocalities(IdfMap const & idfs, size_t limit,
|
||||
void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit,
|
||||
vector<Locality> & localities) const
|
||||
{
|
||||
vector<ExLocality> els;
|
||||
|
@ -179,8 +212,7 @@ void LocalityScorer::LeaveTopBySimilarityAndRank(size_t limit, vector<ExLocality
|
|||
els.resize(n);
|
||||
}
|
||||
|
||||
void LocalityScorer::GetDocVecs(IdfMap const & idfs, uint32_t localityId,
|
||||
vector<DocVec> & dvs) const
|
||||
void LocalityScorer::GetDocVecs(IdfMap & idfs, uint32_t localityId, vector<DocVec> & dvs) const
|
||||
{
|
||||
vector<string> names;
|
||||
m_delegate.GetNames(localityId, names);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "search/cbv.hpp"
|
||||
#include "search/geocoder_locality.hpp"
|
||||
#include "search/ranking_utils.hpp"
|
||||
|
||||
|
@ -29,6 +30,7 @@ public:
|
|||
|
||||
virtual void GetNames(uint32_t featureId, std::vector<std::string> & names) const = 0;
|
||||
virtual uint8_t GetRank(uint32_t featureId) const = 0;
|
||||
virtual CBV GetMatchedFeatures(strings::UniString const & token) const = 0;
|
||||
};
|
||||
|
||||
LocalityScorer(QueryParams const & params, Delegate const & delegate);
|
||||
|
@ -59,8 +61,7 @@ private:
|
|||
|
||||
// Leaves at most |limit| elements of |localities|, ordered by some
|
||||
// combination of ranks and number of matched tokens.
|
||||
void LeaveTopLocalities(IdfMap const & idfs, size_t limit,
|
||||
std::vector<Locality> & localities) const;
|
||||
void LeaveTopLocalities(IdfMap & idfs, size_t limit, std::vector<Locality> & localities) const;
|
||||
|
||||
// Selects at most |limitUniqueIds| best features by query norm and
|
||||
// rank, and then leaves only localities corresponding to those
|
||||
|
@ -71,7 +72,7 @@ private:
|
|||
// and rank. Result doesn't contain duplicate features.
|
||||
void LeaveTopBySimilarityAndRank(size_t limit, std::vector<ExLocality> & els) const;
|
||||
|
||||
void GetDocVecs(IdfMap const & idfs, uint32_t localityId, vector<DocVec> & dvs) const;
|
||||
void GetDocVecs(IdfMap & idfs, uint32_t localityId, std::vector<DocVec> & dvs) const;
|
||||
double GetSimilarity(QueryVec const & qv, std::vector<DocVec> const & dvs) const;
|
||||
|
||||
QueryParams const & m_params;
|
||||
|
|
|
@ -319,43 +319,43 @@ Retrieval::Retrieval(MwmContext const & context, my::Cancellable const & cancell
|
|||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
|
||||
SearchTrieRequest<UniStringDFA> const & request)
|
||||
SearchTrieRequest<UniStringDFA> const & request) const
|
||||
{
|
||||
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
|
||||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
|
||||
SearchTrieRequest<PrefixDFAModifier<UniStringDFA>> const & request)
|
||||
SearchTrieRequest<PrefixDFAModifier<UniStringDFA>> const & request) const
|
||||
{
|
||||
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
|
||||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
|
||||
SearchTrieRequest<LevenshteinDFA> const & request)
|
||||
SearchTrieRequest<LevenshteinDFA> const & request) const
|
||||
{
|
||||
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
|
||||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveAddressFeatures(
|
||||
SearchTrieRequest<PrefixDFAModifier<LevenshteinDFA>> const & request)
|
||||
SearchTrieRequest<PrefixDFAModifier<LevenshteinDFA>> const & request) const
|
||||
{
|
||||
return Retrieve<RetrieveAddressFeaturesAdaptor>(request);
|
||||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrievePostcodeFeatures(
|
||||
TokenSlice const & slice)
|
||||
TokenSlice const & slice) const
|
||||
{
|
||||
return Retrieve<RetrievePostcodeFeaturesAdaptor>(slice);
|
||||
}
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::RetrieveGeometryFeatures(m2::RectD const & rect,
|
||||
int scale)
|
||||
int scale) const
|
||||
{
|
||||
return RetrieveGeometryFeaturesImpl(m_context, m_cancellable, rect, scale);
|
||||
}
|
||||
|
||||
template <template <typename> class R, typename... Args>
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::Retrieve(Args &&... args)
|
||||
unique_ptr<coding::CompressedBitVector> Retrieval::Retrieve(Args &&... args) const
|
||||
{
|
||||
switch (m_format)
|
||||
{
|
||||
|
|
|
@ -38,28 +38,28 @@ public:
|
|||
// Following functions retrieve from the search index corresponding to
|
||||
// |value| all features matching to |request|.
|
||||
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
|
||||
SearchTrieRequest<strings::UniStringDFA> const & request);
|
||||
SearchTrieRequest<strings::UniStringDFA> const & request) const;
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
|
||||
SearchTrieRequest<strings::PrefixDFAModifier<strings::UniStringDFA>> const & request);
|
||||
SearchTrieRequest<strings::PrefixDFAModifier<strings::UniStringDFA>> const & request) const;
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
|
||||
SearchTrieRequest<strings::LevenshteinDFA> const & request);
|
||||
SearchTrieRequest<strings::LevenshteinDFA> const & request) const;
|
||||
|
||||
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(
|
||||
SearchTrieRequest<strings::PrefixDFAModifier<strings::LevenshteinDFA>> const & request);
|
||||
SearchTrieRequest<strings::PrefixDFAModifier<strings::LevenshteinDFA>> const & request) const;
|
||||
|
||||
// Retrieves from the search index corresponding to |value| all
|
||||
// postcodes matching to |slice|.
|
||||
unique_ptr<coding::CompressedBitVector> RetrievePostcodeFeatures(TokenSlice const & slice);
|
||||
unique_ptr<coding::CompressedBitVector> RetrievePostcodeFeatures(TokenSlice const & slice) const;
|
||||
|
||||
// Retrieves from the geometry index corresponding to |value| all features belonging to |rect|.
|
||||
unique_ptr<coding::CompressedBitVector> RetrieveGeometryFeatures(m2::RectD const & rect,
|
||||
int scale);
|
||||
int scale) const;
|
||||
|
||||
private:
|
||||
template <template <typename> class R, typename... Args>
|
||||
unique_ptr<coding::CompressedBitVector> Retrieve(Args &&... args);
|
||||
unique_ptr<coding::CompressedBitVector> Retrieve(Args &&... args) const;
|
||||
|
||||
MwmContext const & m_context;
|
||||
my::Cancellable const & m_cancellable;
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
#include "base/stl_helpers.hpp"
|
||||
#include "base/string_utils.hpp"
|
||||
|
||||
#include "std/algorithm.hpp"
|
||||
#include "std/set.hpp"
|
||||
#include "std/unordered_map.hpp"
|
||||
#include "std/vector.hpp"
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
using namespace search;
|
||||
using namespace std;
|
||||
using namespace strings;
|
||||
|
||||
namespace
|
||||
|
@ -51,7 +52,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
void AddLocality(string const & name, uint32_t featureId)
|
||||
void AddLocality(string const & name, uint32_t featureId, uint8_t rank = 0)
|
||||
{
|
||||
set<UniString> tokens;
|
||||
Delimiters delims;
|
||||
|
@ -61,6 +62,7 @@ public:
|
|||
m_searchIndex.Add(token, featureId);
|
||||
|
||||
m_names[featureId].push_back(name);
|
||||
m_ranks[featureId] = rank;
|
||||
}
|
||||
|
||||
Ids GetTopLocalities(size_t limit)
|
||||
|
@ -114,11 +116,24 @@ public:
|
|||
names.insert(names.end(), it->second.begin(), it->second.end());
|
||||
}
|
||||
|
||||
uint8_t GetRank(uint32_t /* featureId */) const override { return 0; }
|
||||
uint8_t GetRank(uint32_t featureId) const override
|
||||
{
|
||||
auto it = m_ranks.find(featureId);
|
||||
return it == m_ranks.end() ? 0 : it->second;
|
||||
}
|
||||
|
||||
CBV GetMatchedFeatures(strings::UniString const & token) const override
|
||||
{
|
||||
vector<uint64_t> ids;
|
||||
m_searchIndex.ForEachInNode(token, [&ids](uint32_t id) { ids.push_back(id); });
|
||||
my::SortUnique(ids);
|
||||
return CBV{coding::CompressedBitVectorBuilder::FromBitPositions(move(ids))};
|
||||
}
|
||||
|
||||
protected:
|
||||
QueryParams m_params;
|
||||
unordered_map<uint32_t, vector<string>> m_names;
|
||||
unordered_map<uint32_t, uint8_t> m_ranks;
|
||||
LocalityScorer m_scorer;
|
||||
|
||||
base::MemTrie<UniString, base::VectorValues<uint32_t>> m_searchIndex;
|
||||
|
@ -207,3 +222,23 @@ UNIT_CLASS_TEST(LocalityScorerTest, PrefixMatch)
|
|||
TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_SAN_ANTONIO, ID_NEW_YORK}), ());
|
||||
TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_SAN_ANTONIO}), ());
|
||||
}
|
||||
|
||||
UNIT_CLASS_TEST(LocalityScorerTest, Ranks)
|
||||
{
|
||||
enum
|
||||
{
|
||||
ID_SAN_MARINO,
|
||||
ID_SAN_ANTONIO,
|
||||
ID_SAN_FRANCISCO
|
||||
};
|
||||
|
||||
AddLocality("San Marino", ID_SAN_MARINO, 10 /* rank */);
|
||||
AddLocality("Citta di San Antonio", ID_SAN_ANTONIO, 20 /* rank */);
|
||||
AddLocality("San Francisco", ID_SAN_FRANCISCO, 30 /* rank */);
|
||||
|
||||
InitParams("San", false /* lastTokenIsPrefix */);
|
||||
TEST_EQUAL(GetTopLocalities(100 /* limit */),
|
||||
Ids({ID_SAN_MARINO, ID_SAN_ANTONIO, ID_SAN_FRANCISCO}), ());
|
||||
TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_SAN_MARINO, ID_SAN_FRANCISCO}), ());
|
||||
TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_SAN_FRANCISCO}), ());
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue