diff --git a/search/CMakeLists.txt b/search/CMakeLists.txt index 53d6ec8483..6e919c22b0 100644 --- a/search/CMakeLists.txt +++ b/search/CMakeLists.txt @@ -20,6 +20,7 @@ set( common.hpp displayed_categories.cpp displayed_categories.hpp + doc_vec.cpp doc_vec.hpp downloader_search_callback.cpp downloader_search_callback.hpp diff --git a/search/doc_vec.cpp b/search/doc_vec.cpp new file mode 100644 index 0000000000..a19cd2a4f9 --- /dev/null +++ b/search/doc_vec.cpp @@ -0,0 +1,234 @@ +#include "search/doc_vec.hpp" + +#include "base/logging.hpp" + +#include + +using namespace std; + +namespace search +{ +namespace +{ +// Accumulates frequencies of equal tokens in |tfs|. Result is sorted +// by tokens. +void SortAndMerge(vector tokens, vector & tfs) +{ + ASSERT(tfs.empty(), ()); + sort(tokens.begin(), tokens.end()); + for (size_t i = 0; i < tokens.size(); ++i) + { + if (tfs.empty() || tfs.back().m_token != tokens[i]) + tfs.emplace_back(tokens[i], 1 /* frequency */); + else + ++tfs.back().m_frequency; + } +} + +double GetTfIdf(double tf, double idf) { return tf * idf; } + +double GetWeightImpl(IdfMap & idfs, TokenFrequencyPair const & tf, bool isPrefix) +{ + return GetTfIdf(tf.m_frequency, idfs.Get(tf.m_token, isPrefix)); +} + +double GetSqrWeightImpl(IdfMap & idfs, TokenFrequencyPair const & tf, bool isPrefix) +{ + auto const w = GetWeightImpl(idfs, tf, isPrefix); + return w * w; +} + +// Computes squared L2 norm of vector of tokens. +double SqrL2(IdfMap & idfs, vector const & tfs) +{ + double sum = 0; + for (auto const & tf : tfs) + sum += GetSqrWeightImpl(idfs, tf, false /* isPrefix */); + return sum; +} + +// Computes squared L2 norm of vector of tokens + prefix token. +double SqrL2(IdfMap & idfs, vector const & tfs, + boost::optional const & prefix) +{ + auto result = SqrL2(idfs, tfs); + if (prefix) + { + result += + GetSqrWeightImpl(idfs, TokenFrequencyPair(*prefix, 1 /* frequency */), true /* isPrefix */); + } + return result; +} +} // namespace + +// TokenFrequencyPair ------------------------------------------------------------------------------ +bool TokenFrequencyPair::operator<(TokenFrequencyPair const & rhs) const +{ + if (m_token != rhs.m_token) + return m_token < rhs.m_token; + return m_frequency < rhs.m_frequency; +} + +void TokenFrequencyPair::Swap(TokenFrequencyPair & rhs) +{ + m_token.swap(rhs.m_token); + swap(m_frequency, rhs.m_frequency); +} + +string DebugPrint(TokenFrequencyPair const & tf) +{ + ostringstream os; + os << "TokenFrequencyPair [" << DebugPrint(tf.m_token) << ", " << tf.m_frequency << "]"; + return os.str(); +} + +// DocVec ------------------------------------------------------------------------------------------ +DocVec::DocVec(IdfMap & idfs, Builder const & builder) : m_idfs(&idfs) +{ + SortAndMerge(builder.m_tokens, m_tfs); +} + +double DocVec::Norm() { return SqrL2(*m_idfs, m_tfs); } + +strings::UniString const & DocVec::GetToken(size_t i) const +{ + ASSERT_LESS(i, m_tfs.size(), ()); + return m_tfs[i].m_token; +} + +double DocVec::GetIdf(size_t i) +{ + ASSERT_LESS(i, m_tfs.size(), ()); + return m_idfs->Get(m_tfs[i].m_token, false /* isPrefix */); +} + +double DocVec::GetWeight(size_t i) +{ + ASSERT_LESS(i, m_tfs.size(), ()); + return GetWeightImpl(*m_idfs, m_tfs[i], false /* isPrefix */); +} + +// QueryVec ---------------------------------------------------------------------------------------- +QueryVec::QueryVec(IdfMap & idfs, Builder const & builder) + : m_idfs(&idfs), m_prefix(builder.m_prefix) +{ + SortAndMerge(builder.m_tokens, m_tfs); +} + +double QueryVec::Similarity(DocVec & rhs) +{ + size_t kInvalidIndex = numeric_limits::max(); + + if (Empty() && rhs.Empty()) + return 1.0; + + if (Empty() || rhs.Empty()) + return 0.0; + + vector rsMatchTo(rhs.GetNumTokens(), kInvalidIndex); + + double dot = 0; + { + size_t i = 0, j = 0; + + while (i < m_tfs.size() && j < rhs.GetNumTokens()) + { + auto const & lt = m_tfs[i].m_token; + auto const & rt = rhs.GetToken(j); + + if (lt < rt) + { + ++i; + } + else if (lt > rt) + { + ++j; + } + else + { + dot += GetFullTokenWeight(i) * rhs.GetWeight(j); + rsMatchTo[j] = i; + ++i; + ++j; + } + } + } + + auto const ln = Norm(); + auto const rn = rhs.Norm(); + + // This similarity metric assumes that prefix is not matched in the document. + double const similarityNoPrefix = ln > 0 && rn > 0 ? dot / sqrt(ln) / sqrt(rn) : 0; + + if (!m_prefix) + return similarityNoPrefix; + + double similarityWithPrefix = 0; + auto const & prefix = *m_prefix; + + // Let's try to match prefix token with all tokens in the + // document, and compute the best cosine distance. + for (size_t j = 0; j < rhs.GetNumTokens(); ++j) + { + auto const & t = rhs.GetToken(j); + if (!strings::StartsWith(t.begin(), t.end(), prefix.begin(), prefix.end())) + continue; + + auto const i = rsMatchTo[j]; + + double num = 0; + double denom = 0; + if (i == kInvalidIndex) + { + // If this document token is not matched with full tokens in a + // query, we need to update its weight in the cosine distance + // - so we need to update correspondingly dot product and + // vector norms of query and doc. + auto const oldW = GetPrefixTokenWeight(); + auto const newW = GetTfIdf(1 /* frequency */, rhs.GetIdf(j)); + auto const l = max(0.0, ln - oldW * oldW + newW * newW); + + num = dot + newW * rhs.GetWeight(j); + denom = sqrt(l) * sqrt(rn); + } + else + { + // If this document token is already matched with |i|-th full + // token in a query - we know that completion of the prefix + // token is the |i|-th query token. So we need to update + // correspondingly dot product and vector norm of the query. + auto const oldFW = GetFullTokenWeight(i); + auto const oldPW = GetPrefixTokenWeight(); + + auto const tf = m_tfs[i].m_frequency + 1; + auto const idf = m_idfs->Get(m_tfs[i].m_token, false /* isPrefix */); + auto const newW = GetTfIdf(tf, idf); + + auto const l = ln - oldFW * oldFW - oldPW * oldPW + newW * newW; + + num = dot + (newW - oldFW) * rhs.GetWeight(j); + denom = sqrt(l) * sqrt(rn); + } + + if (denom > 0) + similarityWithPrefix = max(similarityWithPrefix, num / denom); + } + + return max(similarityWithPrefix, similarityNoPrefix); +} + +double QueryVec::Norm() { return SqrL2(*m_idfs, m_tfs, m_prefix); } + +double QueryVec::GetFullTokenWeight(size_t i) +{ + ASSERT_LESS(i, m_tfs.size(), ()); + return GetWeightImpl(*m_idfs, m_tfs[i], false /* isPrefix */); +} + +double QueryVec::GetPrefixTokenWeight() +{ + ASSERT(m_prefix, ()); + return GetWeightImpl(*m_idfs, TokenFrequencyPair(*m_prefix, 1 /* frequency */), + true /* isPrefix */); +} +} // namespace search diff --git a/search/doc_vec.hpp b/search/doc_vec.hpp index cb99a3ffa4..c4b54492c2 100644 --- a/search/doc_vec.hpp +++ b/search/doc_vec.hpp @@ -1,11 +1,14 @@ #pragma once +#include "search/idf_map.hpp" + #include "base/assert.hpp" #include "base/string_utils.hpp" #include #include #include +#include #include #include #include @@ -14,309 +17,125 @@ namespace search { -template -struct TokenWeightPair +class IdfMap; + +struct TokenFrequencyPair { - TokenWeightPair() = default; + TokenFrequencyPair() = default; - template - TokenWeightPair(T && token, double weight) : m_token(std::forward(token)), m_weight(weight) + template + TokenFrequencyPair(Token && token, uint64_t frequency) + : m_token(std::forward(token)), m_frequency(frequency) { } - bool operator<(TokenWeightPair const & rhs) const - { - if (m_token != rhs.m_token) - return m_token < rhs.m_token; - return m_weight < rhs.m_weight; - } + bool operator<(TokenFrequencyPair const & rhs) const; - void Swap(TokenWeightPair & rhs) - { - m_token.swap(rhs.m_token); - std::swap(m_weight, rhs.m_weight); - } + void Swap(TokenFrequencyPair & rhs); - // Returns squared weight of the token-weight pair. - double SqrWeight() const { return m_weight * m_weight; } - - Token m_token; - double m_weight = 0; + strings::UniString m_token; + uint64_t m_frequency = 0; }; -template -std::string DebugPrint(TokenWeightPair const & tw) -{ - std::ostringstream os; - os << "TokenWeightPair [ " << DebugPrint(tw.m_token) << ", " << tw.m_weight << " ]"; - return os.str(); -} - -namespace impl -{ -// Accumulates weights of equal tokens in |tws|. Result is sorted by -// tokens. Also, maximum weight from a group of equal tokens will be -// stored in the corresponding |maxWeight| elem. -template -void SortAndMerge(std::vector> & tws, std::vector & maxWeights) -{ - std::sort(tws.begin(), tws.end()); - size_t n = 0; - maxWeights.clear(); - for (size_t i = 0; i < tws.size(); ++i) - { - ASSERT_LESS_OR_EQUAL(n, i, ()); - ASSERT_EQUAL(n, maxWeights.size(), ()); - - if (n == 0 || tws[n - 1].m_token != tws[i].m_token) - { - tws[n].Swap(tws[i]); - maxWeights.push_back(tws[n].m_weight); - ++n; - } - else - { - tws[n - 1].m_weight += tws[i].m_weight; - maxWeights[n - 1] = std::max(maxWeights[n - 1], tws[i].m_weight); - } - } - - ASSERT_LESS_OR_EQUAL(n, tws.size(), ()); - tws.erase(tws.begin() + n, tws.end()); -} - -// Computes squared L2 norm of vector of tokens. -template -double SqrL2(std::vector> const & tws) -{ - double sum = 0; - for (auto const & tw : tws) - sum += tw.SqrWeight(); - return sum; -} - -// Computes squared L2 norm of vector of tokens + prefix token. -template -double SqrL2(std::vector> const & tws, - boost::optional> const & prefix) -{ - double result = SqrL2(tws); - return result + (prefix ? prefix->SqrWeight() : 0); -} -} // namespace impl +std::string DebugPrint(TokenFrequencyPair const & tf); // This class represents a document in a vector space of tokens. -template class DocVec { public: - using TokenWeightPair = TokenWeightPair; - using TokenWeightPairs = std::vector; - class Builder { public: - template - void Add(T && token, double weight) + template + void Add(Token && token) { - m_tws.emplace_back(std::forward(token), weight); + m_tokens.emplace_back(std::forward(token)); } private: friend class DocVec; - TokenWeightPairs m_tws; + std::vector m_tokens; }; - DocVec() = default; - explicit DocVec(Builder && builder) : m_tws(std::move(builder.m_tws)) { Init(); } - explicit DocVec(Builder const & builder) : m_tws(builder.m_tws) { Init(); } + explicit DocVec(IdfMap & idfs) : m_idfs(&idfs) {} - TokenWeightPairs const & GetTokenWeightPairs() const { return m_tws; } - std::vector const & GetMaxWeights() const { return m_maxWeights; } + DocVec(IdfMap & idfs, Builder const & builder); - bool Empty() const { return m_tws.empty(); } + // Computes vector norm of the doc. + double Norm(); + + size_t GetNumTokens() const { return m_tfs.size(); } + + strings::UniString const & GetToken(size_t i) const; + double GetIdf(size_t i); + double GetWeight(size_t i); + + bool Empty() const { return m_tfs.empty(); } private: - template - friend std::string DebugPrint(DocVec const & dv) + friend std::string DebugPrint(DocVec const & dv) { - return "DocVec " + DebugPrint(dv.m_tws); + return "DocVec " + ::DebugPrint(dv.m_tfs); } - void Init() { impl::SortAndMerge(m_tws, m_maxWeights); } - - TokenWeightPairs m_tws; - std::vector m_maxWeights; + IdfMap * m_idfs; + std::vector m_tfs; }; // This class represents a search query in a vector space of tokens. -template class QueryVec { public: - using TokenWeightPair = TokenWeightPair; - using TokenWeightPairs = std::vector; - class Builder { public: - template - void AddFull(T && token, double weight) + template + void AddFull(Token && token) { - m_tws.emplace_back(std::forward(token), weight); + m_tokens.emplace_back(std::forward(token)); } - template - void SetPrefix(T && token, double weight) + template + void SetPrefix(Token && token) { - m_prefix = TokenWeightPair(std::forward(token), weight); + m_prefix = std::forward(token); } private: friend class QueryVec; - TokenWeightPairs m_tws; - boost::optional m_prefix; + std::vector m_tokens; + boost::optional m_prefix; }; - QueryVec() = default; + explicit QueryVec(IdfMap & idfs) : m_idfs(&idfs) {} - explicit QueryVec(Builder && builder) - : m_tws(std::move(builder.m_tws)), m_prefix(std::move(builder.m_prefix)) - { - Init(); - } + QueryVec(IdfMap & idfs, Builder const & builder); - explicit QueryVec(Builder const & builder) : m_tws(builder.m_tws), m_prefix(builder.m_prefix) - { - Init(); - } + // Computes cosine similarity between |*this| and |rhs|. + double Similarity(DocVec & rhs); - // Computes cosine distance between |*this| and |rhs|. - double Similarity(DocVec const & rhs) const - { - size_t kInvalidIndex = std::numeric_limits::max(); + // Computes vector norm of the query. + double Norm(); - if (Empty() && rhs.Empty()) - return 1.0; - - if (Empty() || rhs.Empty()) - return 0.0; - - auto const & ls = m_tws; - auto const & rs = rhs.GetTokenWeightPairs(); - auto const & maxWeights = rhs.GetMaxWeights(); - - ASSERT(std::is_sorted(ls.begin(), ls.end()), ()); - ASSERT(std::is_sorted(rs.begin(), rs.end()), ()); - - std::vector rsMatchTo(rs.size(), kInvalidIndex); - - size_t i = 0, j = 0; - double dot = 0; - while (i < ls.size() && j < rs.size()) - { - if (ls[i].m_token < rs[j].m_token) - { - ++i; - } - else if (ls[i].m_token > rs[j].m_token) - { - ++j; - } - else - { - dot += ls[i].m_weight * rs[j].m_weight; - rsMatchTo[j] = i; - ++i; - ++j; - } - } - - auto const ln = impl::SqrL2(ls, m_prefix); - auto const rn = impl::SqrL2(rs); - - // This similarity metric assumes that prefix is not matched in the document. - double const similarityNoPrefix = ln > 0 && rn > 0 ? dot / sqrt(ln) / sqrt(rn) : 0; - - if (!m_prefix) - return similarityNoPrefix; - - double similarityWithPrefix = 0; - - auto const & prefix = *m_prefix; - - // Let's try to match prefix token with all tokens in the - // document, and compute the best cosine distance. - for (size_t j = 0; j < rs.size(); ++j) - { - auto const & tw = rs[j]; - if (!strings::StartsWith(tw.m_token.begin(), tw.m_token.end(), prefix.m_token.begin(), - prefix.m_token.end())) - { - continue; - } - - auto const i = rsMatchTo[j]; - - double nom = 0; - double denom = 0; - if (i == kInvalidIndex) - { - // If this document token is not matched with full tokens in a - // query, we need to update it's weight in the cosine distance - // - so we need to update correspondingly dot product and - // vector norms of query and doc. - auto const w = maxWeights[j]; - auto const l = std::max(0.0, ln - prefix.SqrWeight() + w * w); - - nom = dot + w * tw.m_weight; - denom = sqrt(l) * sqrt(rn); - } - else - { - // If this document token is already matched with |i|-th full - // token in a query - we know that completion of the prefix - // token is the |i|-th query token. So we need to update - // correspondingly dot product and vector norm of the query. - auto const w = ls[i].m_weight + m_maxWeights[i]; - auto const l = ln - ls[i].SqrWeight() - prefix.SqrWeight() + w * w; - - nom = dot + (w - ls[i].m_weight) * tw.m_weight; - denom = sqrt(l) * sqrt(rn); - } - - if (denom > 0) - similarityWithPrefix = std::max(similarityWithPrefix, nom / denom); - } - - return std::max(similarityWithPrefix, similarityNoPrefix); - } - - double Norm() const - { - double n = 0; - for (auto const & tw : m_tws) - n += tw.m_weight * tw.m_weight; - if (m_prefix) - n += m_prefix->m_weight * m_prefix->m_weight; - return sqrt(n); - } - - bool Empty() const { return m_tws.empty() && !m_prefix; } + bool Empty() const { return m_tfs.empty() && !m_prefix; } private: - template - friend std::string DebugPrint(QueryVec const & qv) + double GetFullTokenWeight(size_t i); + double GetPrefixTokenWeight(); + + friend std::string DebugPrint(QueryVec const & qv) { - return "QueryVec " + DebugPrint(qv.m_tws); + std::ostringstream os; + os << "QueryVec " + ::DebugPrint(qv.m_tfs); + if (qv.m_prefix) + os << " " << DebugPrint(*qv.m_prefix); + return os.str(); } - void Init() { impl::SortAndMerge(m_tws, m_maxWeights); } - - std::vector m_tws; - std::vector m_maxWeights; - boost::optional m_prefix; + IdfMap * m_idfs; + std::vector m_tfs; + boost::optional m_prefix; }; } // namespace search diff --git a/search/geocoder.cpp b/search/geocoder.cpp index 6b6bc3f268..01cb38cb4e 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -183,13 +183,22 @@ public: uint8_t GetRank(uint32_t featureId) const override { return m_ranks.Get(featureId); } - CBV GetMatchedFeatures(strings::UniString const & token) const override + CBV GetMatchedFeatures(strings::UniString const & token, bool isPrefix) const override { - SearchTrieRequest request; - request.m_names.emplace_back(token); - request.SetLangs(m_params.GetLangs()); - - return CBV{m_retrieval.RetrieveAddressFeatures(request)}; + if (isPrefix) + { + SearchTrieRequest> request; + request.m_names.emplace_back(strings::UniStringDFA(token)); + request.SetLangs(m_params.GetLangs()); + return CBV{m_retrieval.RetrieveAddressFeatures(request)}; + } + else + { + SearchTrieRequest request; + request.m_names.emplace_back(token); + request.SetLangs(m_params.GetLangs()); + return CBV{m_retrieval.RetrieveAddressFeatures(request)}; + } } private: diff --git a/search/geocoder_locality.hpp b/search/geocoder_locality.hpp index 42fb86f887..9a3ba52147 100644 --- a/search/geocoder_locality.hpp +++ b/search/geocoder_locality.hpp @@ -16,19 +16,17 @@ namespace search { +class IdfMap; + struct Locality { - using QueryVec = QueryVec; - - Locality() = default; - Locality(MwmSet::MwmId const & countryId, uint32_t featureId, TokenRange const & tokenRange, QueryVec const & queryVec) : m_countryId(countryId), m_featureId(featureId), m_tokenRange(tokenRange), m_queryVec(queryVec) { } - double QueryNorm() const { return m_queryVec.Norm(); } + double QueryNorm() { return m_queryVec.Norm(); } MwmSet::MwmId m_countryId; uint32_t m_featureId = 0; diff --git a/search/idf_map.cpp b/search/idf_map.cpp index bb88659e42..bbc1d4254f 100644 --- a/search/idf_map.cpp +++ b/search/idf_map.cpp @@ -1,21 +1,24 @@ #include "search/idf_map.hpp" +#include "base/assert.hpp" + namespace search { IdfMap::IdfMap(Delegate & delegate, double unknownIdf) : m_delegate(delegate), m_unknownIdf(unknownIdf) { + ASSERT_GREATER(m_unknownIdf, 0.0, ()); } -double IdfMap::Get(strings::UniString const & s) +double IdfMap::GetImpl(Map & idfs, strings::UniString const & s, bool isPrefix) { - auto const it = m_idfs.find(s); - if (it != m_idfs.cend()) + auto const it = idfs.find(s); + if (it != idfs.cend()) return it->second; - auto const df = static_cast(m_delegate.GetNumDocs(s)); + auto const df = static_cast(m_delegate.GetNumDocs(s, isPrefix)); auto const idf = df == 0 ? m_unknownIdf : 1.0 / df; - m_idfs[s] = idf; + idfs[s] = idf; return idf; } diff --git a/search/idf_map.hpp b/search/idf_map.hpp index ad967b895b..2e902d0c18 100644 --- a/search/idf_map.hpp +++ b/search/idf_map.hpp @@ -14,16 +14,29 @@ public: { virtual ~Delegate() = default; - virtual uint64_t GetNumDocs(strings::UniString const & token) const = 0; + virtual uint64_t GetNumDocs(strings::UniString const & token, bool isPrefix) const = 0; }; IdfMap(Delegate & delegate, double unknownIdf); - void Set(strings::UniString const & s, double idf) { m_idfs[s] = idf; } - double Get(strings::UniString const & s); + void Set(strings::UniString const & s, bool isPrefix, double idf) + { + SetImpl(isPrefix ? m_prefixIdfs : m_fullIdfs, s, idf); + } + + double Get(strings::UniString const & s, bool isPrefix) + { + return GetImpl(isPrefix ? m_prefixIdfs : m_fullIdfs, s, isPrefix); + } private: - std::map m_idfs; + using Map = std::map; + + void SetImpl(Map & idfs, strings::UniString const & s, double idf) { idfs[s] = idf; } + double GetImpl(Map & idfs, strings::UniString const & s, bool isPrefix); + + Map m_fullIdfs; + Map m_prefixIdfs; Delegate & m_delegate; double m_unknownIdf; diff --git a/search/locality_scorer.cpp b/search/locality_scorer.cpp index f62bd002c5..eab4ebf911 100644 --- a/search/locality_scorer.cpp +++ b/search/locality_scorer.cpp @@ -30,9 +30,9 @@ struct IdfMapDelegate : public IdfMap::Delegate ~IdfMapDelegate() override = default; - uint64_t GetNumDocs(strings::UniString const & token) const override + uint64_t GetNumDocs(strings::UniString const & token, bool isPrefix) const override { - return m_filter.Intersect(m_delegate.GetMatchedFeatures(token)).PopCount(); + return m_filter.Intersect(m_delegate.GetMatchedFeatures(token, isPrefix)).PopCount(); } LocalityScorer::Delegate const & m_delegate; @@ -59,6 +59,8 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte CBV const & filter, size_t limit, vector & localities) { + double const kUnknownIdf = 1.0; + CHECK_EQUAL(ctx.m_numTokens, m_params.GetNumTokens(), ()); localities.clear(); @@ -68,24 +70,16 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte intersections[i] = filter.Intersect(ctx.m_features[i]); IdfMapDelegate delegate(m_delegate, filter); - IdfMap idfs(delegate, 1.0 /* unknownIdf */); - double prefixIdf = 1.0; - for (size_t i = 0; i < ctx.m_numTokens; ++i) + IdfMap idfs(delegate, kUnknownIdf); + if (ctx.m_numTokens > 0 && m_params.LastTokenIsPrefix()) { - auto const numDocs = intersections[i].PopCount(); - double idf = 1.0; + auto const numDocs = intersections.back().PopCount(); + double idf = kUnknownIdf; if (numDocs > 0) idf = 1.0 / static_cast(numDocs); - - if (m_params.IsPrefixToken(i)) - { - prefixIdf = idf; - } - else - { - m_params.GetToken(i).ForEach( - [&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); }); - } + m_params.GetToken(ctx.m_numTokens - 1).ForEach([&idfs, &idf](strings::UniString const & s) { + idfs.Set(s, true /* isPrefix */, idf); + }); } for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken) @@ -99,9 +93,9 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte auto const curToken = endToken - 1; auto const & token = m_params.GetToken(curToken).m_original; if (m_params.IsPrefixToken(curToken)) - builder.SetPrefix(token, prefixIdf); + builder.SetPrefix(token); else - builder.AddFull(token, idfs.Get(token)); + builder.AddFull(token); TokenRange const tokenRange(startToken, endToken); // Skip locality candidates that match only numbers. @@ -109,7 +103,7 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte { intersection.ForEach([&](uint64_t bit) { auto const featureId = base::asserted_cast(bit); - localities.emplace_back(countryId, featureId, tokenRange, QueryVec(builder)); + localities.emplace_back(countryId, featureId, tokenRange, QueryVec(idfs, builder)); }); } @@ -126,15 +120,15 @@ void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit, { vector els; els.reserve(localities.size()); - for (auto const & locality : localities) + for (auto & locality : localities) { auto const queryNorm = locality.m_queryVec.Norm(); auto const rank = m_delegate.GetRank(locality.m_featureId); els.emplace_back(locality, queryNorm, rank); } - // We don't want to read too much names for localities, to this is - // the best effort - select best features by available params - + // We don't want to read too many names for localities, so this is + // the best effort - select the best features by available params - // query norm and rank. LeaveTopByNormAndRank(max(limit, kDefaultReadLimit) /* limitUniqueIds */, els); @@ -209,7 +203,7 @@ void LocalityScorer::LeaveTopBySimilarityAndRank(size_t limit, vector & dvs) const @@ -224,22 +218,22 @@ void LocalityScorer::GetDocVecs(IdfMap & idfs, uint32_t localityId, vector const & dvc) const +double LocalityScorer::GetSimilarity(QueryVec & qv, vector & dvc) const { double const kScale = 1e6; double similarity = 0; - for (auto const & dv : dvc) + for (auto & dv : dvc) similarity = max(similarity, qv.Similarity(dv)); - // We need scale here to prevent double artifacts, and to make - // sorting by similarity more robust, as 1e-6 is good enough for our - // purposes. + // We need to scale similarity here to prevent floating-point + // artifacts, and to make sorting by similarity more robust, as 1e-6 + // is good enough for our purposes. return round(similarity * kScale); } diff --git a/search/locality_scorer.hpp b/search/locality_scorer.hpp index 6a1c6201ca..8a4e0a58e0 100644 --- a/search/locality_scorer.hpp +++ b/search/locality_scorer.hpp @@ -30,7 +30,7 @@ public: virtual void GetNames(uint32_t featureId, std::vector & names) const = 0; virtual uint8_t GetRank(uint32_t featureId) const = 0; - virtual CBV GetMatchedFeatures(strings::UniString const & token) const = 0; + virtual CBV GetMatchedFeatures(strings::UniString const & token, bool isPrefix) const = 0; }; LocalityScorer(QueryParams const & params, Delegate const & delegate); @@ -41,9 +41,6 @@ public: CBV const & filter, size_t limit, std::vector & localities); private: - using DocVec = DocVec; - using QueryVec = Locality::QueryVec; - struct ExLocality { ExLocality() = default; @@ -68,12 +65,12 @@ private: // features in |els|. void LeaveTopByNormAndRank(size_t limitUniqueIds, std::vector & els) const; - // Leaves at most |limit| best localities by similarity to the query - // and rank. Result doesn't contain duplicate features. + // Leaves at most |limit| unique best localities by similarity to + // the query and rank. void LeaveTopBySimilarityAndRank(size_t limit, std::vector & els) const; void GetDocVecs(IdfMap & idfs, uint32_t localityId, std::vector & dvs) const; - double GetSimilarity(QueryVec const & qv, std::vector const & dvs) const; + double GetSimilarity(QueryVec & qv, std::vector & dvs) const; QueryParams const & m_params; Delegate const & m_delegate; diff --git a/search/search.pro b/search/search.pro index 94666b9bf1..81fbb2cdc5 100644 --- a/search/search.pro +++ b/search/search.pro @@ -96,6 +96,7 @@ SOURCES += \ cities_boundaries_table.cpp \ city_finder.cpp \ displayed_categories.cpp \ + doc_vec.cpp \ downloader_search_callback.cpp \ dummy_rank_table.cpp \ editor_delegate.cpp \ diff --git a/search/search_tests/locality_scorer_test.cpp b/search/search_tests/locality_scorer_test.cpp index 67789b5bdc..aa655212d4 100644 --- a/search/search_tests/locality_scorer_test.cpp +++ b/search/search_tests/locality_scorer_test.cpp @@ -122,10 +122,20 @@ public: return it == m_ranks.end() ? 0 : it->second; } - CBV GetMatchedFeatures(strings::UniString const & token) const override + CBV GetMatchedFeatures(strings::UniString const & token, bool isPrefix) const override { vector ids; - m_searchIndex.ForEachInNode(token, [&ids](uint32_t id) { ids.push_back(id); }); + + if (isPrefix) + { + m_searchIndex.ForEachInSubtree(token, [&ids](strings::UniString const & /* prefix */, + uint32_t id) { ids.push_back(id); }); + } + else + { + m_searchIndex.ForEachInNode(token, [&ids](uint32_t id) { ids.push_back(id); }); + } + my::SortUnique(ids); return CBV{coding::CompressedBitVectorBuilder::FromBitPositions(move(ids))}; }