diff --git a/search/doc_vec.hpp b/search/doc_vec.hpp index c1f7eb19be..cb99a3ffa4 100644 --- a/search/doc_vec.hpp +++ b/search/doc_vec.hpp @@ -38,7 +38,7 @@ struct TokenWeightPair } // Returns squared weight of the token-weight pair. - double Sqr() const { return m_weight * m_weight; } + double SqrWeight() const { return m_weight * m_weight; } Token m_token; double m_weight = 0; @@ -54,23 +54,30 @@ std::string DebugPrint(TokenWeightPair const & tw) namespace impl { -// Accumulates weights of equal tokens in |tws|. Result is sorted by tokens. +// Accumulates weights of equal tokens in |tws|. Result is sorted by +// tokens. Also, maximum weight from a group of equal tokens will be +// stored in the corresponding |maxWeight| elem. template -void SortAndMerge(std::vector> & tws) +void SortAndMerge(std::vector> & tws, std::vector & maxWeights) { std::sort(tws.begin(), tws.end()); size_t n = 0; + maxWeights.clear(); for (size_t i = 0; i < tws.size(); ++i) { ASSERT_LESS_OR_EQUAL(n, i, ()); + ASSERT_EQUAL(n, maxWeights.size(), ()); + if (n == 0 || tws[n - 1].m_token != tws[i].m_token) { tws[n].Swap(tws[i]); + maxWeights.push_back(tws[n].m_weight); ++n; } else { tws[n - 1].m_weight += tws[i].m_weight; + maxWeights[n - 1] = std::max(maxWeights[n - 1], tws[i].m_weight); } } @@ -84,7 +91,7 @@ double SqrL2(std::vector> const & tws) { double sum = 0; for (auto const & tw : tws) - sum += tw.Sqr(); + sum += tw.SqrWeight(); return sum; } @@ -94,7 +101,7 @@ double SqrL2(std::vector> const & tws, boost::optional> const & prefix) { double result = SqrL2(tws); - return result + (prefix ? prefix->Sqr() : 0); + return result + (prefix ? prefix->SqrWeight() : 0); } } // namespace impl @@ -126,6 +133,7 @@ public: explicit DocVec(Builder const & builder) : m_tws(builder.m_tws) { Init(); } TokenWeightPairs const & GetTokenWeightPairs() const { return m_tws; } + std::vector const & GetMaxWeights() const { return m_maxWeights; } bool Empty() const { return m_tws.empty(); } @@ -136,9 +144,10 @@ private: return "DocVec " + DebugPrint(dv.m_tws); } - void Init() { impl::SortAndMerge(m_tws); } + void Init() { impl::SortAndMerge(m_tws, m_maxWeights); } TokenWeightPairs m_tws; + std::vector m_maxWeights; }; // This class represents a search query in a vector space of tokens. @@ -197,6 +206,7 @@ public: auto const & ls = m_tws; auto const & rs = rhs.GetTokenWeightPairs(); + auto const & maxWeights = rhs.GetMaxWeights(); ASSERT(std::is_sorted(ls.begin(), ls.end()), ()); ASSERT(std::is_sorted(rs.begin(), rs.end()), ()); @@ -258,29 +268,22 @@ public: // query, we need to update it's weight in the cosine distance // - so we need to update correspondingly dot product and // vector norms of query and doc. + auto const w = maxWeights[j]; + auto const l = std::max(0.0, ln - prefix.SqrWeight() + w * w); - // This is the hacky moment: weight of query prefix token may - // be greater than the weight of the corresponding document - // token, because the weight of the document token may be - // unknown at the moment, and be set to some default value. - // But this heuristic works nicely in practice. - double const w = std::max(prefix.m_weight, tw.m_weight); - auto const sqrW = w * w; - double const l = std::max(0.0, ln - prefix.Sqr() + sqrW); - double const r = std::max(0.0, rn - tw.Sqr() + sqrW); - - nom = dot + sqrW; - denom = sqrt(l) * sqrt(r); + nom = dot + w * tw.m_weight; + denom = sqrt(l) * sqrt(rn); } else { // If this document token is already matched with |i|-th full - // token in a query - we here that completion of the prefix - // token is the |i|-th token. So we need to update + // token in a query - we know that completion of the prefix + // token is the |i|-th query token. So we need to update // correspondingly dot product and vector norm of the query. - double const l = ln + 2 * ls[i].m_weight * prefix.m_weight; + auto const w = ls[i].m_weight + m_maxWeights[i]; + auto const l = ln - ls[i].SqrWeight() - prefix.SqrWeight() + w * w; - nom = dot + prefix.m_weight * tw.m_weight; + nom = dot + (w - ls[i].m_weight) * tw.m_weight; denom = sqrt(l) * sqrt(rn); } @@ -310,9 +313,10 @@ private: return "QueryVec " + DebugPrint(qv.m_tws); } - void Init() { impl::SortAndMerge(m_tws); } + void Init() { impl::SortAndMerge(m_tws, m_maxWeights); } std::vector m_tws; + std::vector m_maxWeights; boost::optional m_prefix; }; } // namespace search diff --git a/search/geocoder.cpp b/search/geocoder.cpp index c5d0162fe0..6b6bc3f268 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -157,8 +157,13 @@ private: class LocalityScorerDelegate : public LocalityScorer::Delegate { public: - LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params) - : m_context(context), m_params(params), m_ranks(m_context.m_value) + LocalityScorerDelegate(MwmContext const & context, Geocoder::Params const & params, + my::Cancellable const & cancellable) + : m_context(context) + , m_params(params) + , m_cancellable(cancellable) + , m_retrieval(m_context, m_cancellable) + , m_ranks(m_context.m_value) { } @@ -178,9 +183,22 @@ public: uint8_t GetRank(uint32_t featureId) const override { return m_ranks.Get(featureId); } + CBV GetMatchedFeatures(strings::UniString const & token) const override + { + SearchTrieRequest request; + request.m_names.emplace_back(token); + request.SetLangs(m_params.GetLangs()); + + return CBV{m_retrieval.RetrieveAddressFeatures(request)}; + } + private: MwmContext const & m_context; Geocoder::Params const & m_params; + my::Cancellable const & m_cancellable; + + Retrieval m_retrieval; + LazyRankTable m_ranks; }; @@ -624,7 +642,7 @@ void Geocoder::FillLocalityCandidates(BaseContext const & ctx, CBV const & filte return; } - LocalityScorerDelegate delegate(*m_context, m_params); + LocalityScorerDelegate delegate(*m_context, m_params, m_cancellable); LocalityScorer scorer(m_params, delegate); scorer.GetTopLocalities(m_context->GetId(), ctx, filter, maxNumLocalities, preLocalities); } diff --git a/search/idf_map.cpp b/search/idf_map.cpp index 28798d20bd..bb88659e42 100644 --- a/search/idf_map.cpp +++ b/search/idf_map.cpp @@ -2,11 +2,21 @@ namespace search { -IdfMap::IdfMap(double unknownIdf): m_unknownIdf(unknownIdf) {} +IdfMap::IdfMap(Delegate & delegate, double unknownIdf) + : m_delegate(delegate), m_unknownIdf(unknownIdf) +{ +} -double IdfMap::Get(strings::UniString const & s) const +double IdfMap::Get(strings::UniString const & s) { auto const it = m_idfs.find(s); - return it == m_idfs.cend() ? m_unknownIdf : it->second; + if (it != m_idfs.cend()) + return it->second; + + auto const df = static_cast(m_delegate.GetNumDocs(s)); + auto const idf = df == 0 ? m_unknownIdf : 1.0 / df; + m_idfs[s] = idf; + + return idf; } } // namespace search diff --git a/search/idf_map.hpp b/search/idf_map.hpp index aa3709293f..ad967b895b 100644 --- a/search/idf_map.hpp +++ b/search/idf_map.hpp @@ -2,6 +2,7 @@ #include "base/string_utils.hpp" +#include #include namespace search @@ -9,13 +10,22 @@ namespace search class IdfMap { public: - explicit IdfMap(double unknownIdf); + struct Delegate + { + virtual ~Delegate() = default; + + virtual uint64_t GetNumDocs(strings::UniString const & token) const = 0; + }; + + IdfMap(Delegate & delegate, double unknownIdf); void Set(strings::UniString const & s, double idf) { m_idfs[s] = idf; } - double Get(strings::UniString const & s) const; + double Get(strings::UniString const & s); private: std::map m_idfs; + + Delegate & m_delegate; double m_unknownIdf; }; } // namespace search diff --git a/search/locality_scorer.cpp b/search/locality_scorer.cpp index 1046da4de6..f62bd002c5 100644 --- a/search/locality_scorer.cpp +++ b/search/locality_scorer.cpp @@ -19,6 +19,27 @@ using namespace std; namespace search { +namespace +{ +struct IdfMapDelegate : public IdfMap::Delegate +{ + IdfMapDelegate(LocalityScorer::Delegate const & delegate, CBV const & filter) + : m_delegate(delegate), m_filter(filter) + { + } + + ~IdfMapDelegate() override = default; + + uint64_t GetNumDocs(strings::UniString const & token) const override + { + return m_filter.Intersect(m_delegate.GetMatchedFeatures(token)).PopCount(); + } + + LocalityScorer::Delegate const & m_delegate; + CBV const & m_filter; +}; +} // namespace + // static size_t const LocalityScorer::kDefaultReadLimit = 100; @@ -46,12 +67,25 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte for (size_t i = 0; i < ctx.m_numTokens; ++i) intersections[i] = filter.Intersect(ctx.m_features[i]); - IdfMap idfs(1.0 /* unknownIdf */); + IdfMapDelegate delegate(m_delegate, filter); + IdfMap idfs(delegate, 1.0 /* unknownIdf */); + double prefixIdf = 1.0; for (size_t i = 0; i < ctx.m_numTokens; ++i) { - auto const idf = 1.0 / static_cast(intersections[i].PopCount()); - // IDF should be the same for the token and its synonyms. - m_params.GetToken(i).ForEach([&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); }); + auto const numDocs = intersections[i].PopCount(); + double idf = 1.0; + if (numDocs > 0) + idf = 1.0 / static_cast(numDocs); + + if (m_params.IsPrefixToken(i)) + { + prefixIdf = idf; + } + else + { + m_params.GetToken(i).ForEach( + [&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); }); + } } for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken) @@ -64,11 +98,10 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte { auto const curToken = endToken - 1; auto const & token = m_params.GetToken(curToken).m_original; - double const weight = idfs.Get(token); if (m_params.IsPrefixToken(curToken)) - builder.SetPrefix(token, weight); + builder.SetPrefix(token, prefixIdf); else - builder.AddFull(token, weight); + builder.AddFull(token, idfs.Get(token)); TokenRange const tokenRange(startToken, endToken); // Skip locality candidates that match only numbers. @@ -88,7 +121,7 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte LeaveTopLocalities(idfs, limit, localities); } -void LocalityScorer::LeaveTopLocalities(IdfMap const & idfs, size_t limit, +void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit, vector & localities) const { vector els; @@ -179,8 +212,7 @@ void LocalityScorer::LeaveTopBySimilarityAndRank(size_t limit, vector & dvs) const +void LocalityScorer::GetDocVecs(IdfMap & idfs, uint32_t localityId, vector & dvs) const { vector names; m_delegate.GetNames(localityId, names); diff --git a/search/locality_scorer.hpp b/search/locality_scorer.hpp index d68a0d93d6..6a1c6201ca 100644 --- a/search/locality_scorer.hpp +++ b/search/locality_scorer.hpp @@ -1,5 +1,6 @@ #pragma once +#include "search/cbv.hpp" #include "search/geocoder_locality.hpp" #include "search/ranking_utils.hpp" @@ -29,6 +30,7 @@ public: virtual void GetNames(uint32_t featureId, std::vector & names) const = 0; virtual uint8_t GetRank(uint32_t featureId) const = 0; + virtual CBV GetMatchedFeatures(strings::UniString const & token) const = 0; }; LocalityScorer(QueryParams const & params, Delegate const & delegate); @@ -59,8 +61,7 @@ private: // Leaves at most |limit| elements of |localities|, ordered by some // combination of ranks and number of matched tokens. - void LeaveTopLocalities(IdfMap const & idfs, size_t limit, - std::vector & localities) const; + void LeaveTopLocalities(IdfMap & idfs, size_t limit, std::vector & localities) const; // Selects at most |limitUniqueIds| best features by query norm and // rank, and then leaves only localities corresponding to those @@ -71,7 +72,7 @@ private: // and rank. Result doesn't contain duplicate features. void LeaveTopBySimilarityAndRank(size_t limit, std::vector & els) const; - void GetDocVecs(IdfMap const & idfs, uint32_t localityId, vector & dvs) const; + void GetDocVecs(IdfMap & idfs, uint32_t localityId, std::vector & dvs) const; double GetSimilarity(QueryVec const & qv, std::vector const & dvs) const; QueryParams const & m_params; diff --git a/search/retrieval.cpp b/search/retrieval.cpp index 7af36d0622..3656da1df2 100644 --- a/search/retrieval.cpp +++ b/search/retrieval.cpp @@ -319,43 +319,43 @@ Retrieval::Retrieval(MwmContext const & context, my::Cancellable const & cancell } unique_ptr Retrieval::RetrieveAddressFeatures( - SearchTrieRequest const & request) + SearchTrieRequest const & request) const { return Retrieve(request); } unique_ptr Retrieval::RetrieveAddressFeatures( - SearchTrieRequest> const & request) + SearchTrieRequest> const & request) const { return Retrieve(request); } unique_ptr Retrieval::RetrieveAddressFeatures( - SearchTrieRequest const & request) + SearchTrieRequest const & request) const { return Retrieve(request); } unique_ptr Retrieval::RetrieveAddressFeatures( - SearchTrieRequest> const & request) + SearchTrieRequest> const & request) const { return Retrieve(request); } unique_ptr Retrieval::RetrievePostcodeFeatures( - TokenSlice const & slice) + TokenSlice const & slice) const { return Retrieve(slice); } unique_ptr Retrieval::RetrieveGeometryFeatures(m2::RectD const & rect, - int scale) + int scale) const { return RetrieveGeometryFeaturesImpl(m_context, m_cancellable, rect, scale); } template