From 20cb964da9929641da191afbbabeb89f8d51405b Mon Sep 17 00:00:00 2001 From: Yuri Gorshenin Date: Tue, 12 Dec 2017 21:25:27 +0300 Subject: [PATCH] [search] TF-IDF for localities ranking. --- base/buffer_vector.hpp | 6 + search/CMakeLists.txt | 3 + search/doc_vec.hpp | 317 +++++++++++++++++++ search/geocoder.cpp | 4 +- search/geocoder_locality.cpp | 1 - search/geocoder_locality.hpp | 15 +- search/idf_map.cpp | 12 + search/idf_map.hpp | 21 ++ search/locality_scorer.cpp | 277 ++++++++-------- search/locality_scorer.hpp | 37 ++- search/search.pro | 3 + search/search_tests/locality_scorer_test.cpp | 68 ++-- 12 files changed, 572 insertions(+), 192 deletions(-) create mode 100644 search/doc_vec.hpp create mode 100644 search/idf_map.cpp create mode 100644 search/idf_map.hpp diff --git a/base/buffer_vector.hpp b/base/buffer_vector.hpp index 99c2925c6c..377a1f4777 100644 --- a/base/buffer_vector.hpp +++ b/base/buffer_vector.hpp @@ -457,3 +457,9 @@ inline bool operator<(buffer_vector const & v1, buffer_vector cons { return std::lexicographical_compare(v1.begin(), v1.end(), v2.begin(), v2.end()); } + +template +inline bool operator>(buffer_vector const & v1, buffer_vector const & v2) +{ + return v2 < v1; +} diff --git a/search/CMakeLists.txt b/search/CMakeLists.txt index 425c6ab313..53d6ec8483 100644 --- a/search/CMakeLists.txt +++ b/search/CMakeLists.txt @@ -20,6 +20,7 @@ set( common.hpp displayed_categories.cpp displayed_categories.hpp + doc_vec.hpp downloader_search_callback.cpp downloader_search_callback.hpp dummy_rank_table.cpp @@ -64,6 +65,8 @@ set( house_numbers_matcher.hpp house_to_street_table.cpp house_to_street_table.hpp + idf_map.cpp + idf_map.hpp intermediate_result.cpp intermediate_result.hpp intersection_result.cpp diff --git a/search/doc_vec.hpp b/search/doc_vec.hpp new file mode 100644 index 0000000000..e299dee44a --- /dev/null +++ b/search/doc_vec.hpp @@ -0,0 +1,317 @@ +#pragma once + +#include "base/assert.hpp" +#include "base/string_utils.hpp" + +#include 
+#include +#include +#include +#include +#include + +#include + +namespace search +{ +template +struct TokenWeightPair +{ + TokenWeightPair() = default; + + template + TokenWeightPair(T && token, double weight) : m_token(std::forward(token)), m_weight(weight) + { + } + + bool operator<(TokenWeightPair const & rhs) const + { + if (m_token != rhs.m_token) + return m_token < rhs.m_token; + return m_weight < rhs.m_weight; + } + + void Swap(TokenWeightPair & rhs) + { + m_token.swap(rhs.m_token); + std::swap(m_weight, rhs.m_weight); + } + + // Returns squared weight of the token-weight pair. + double Sqr() const { return m_weight * m_weight; } + + Token m_token; + double m_weight = 0; +}; + +template +std::string DebugPrint(TokenWeightPair const & tw) +{ + std::ostringstream os; + os << "TokenWeightPair [ " << DebugPrint(tw.m_token) << ", " << tw.m_weight << " ]"; + return os.str(); +} + +namespace impl +{ +// Accumulates weights of equal tokens in |tws|. Result is sorted by tokens. +template +void SortAndMerge(std::vector> & tws) +{ + std::sort(tws.begin(), tws.end()); + size_t n = 0; + for (size_t i = 0; i < tws.size(); ++i) + { + ASSERT_LESS_OR_EQUAL(n, i, ()); + if (n == 0 || tws[n - 1].m_token != tws[i].m_token) + { + tws[n].Swap(tws[i]); + ++n; + } + else + { + tws[n - 1].m_weight += tws[i].m_weight; + } + } + + ASSERT_LESS_OR_EQUAL(n, tws.size(), ()); + tws.erase(tws.begin() + n, tws.end()); +} + +// Computes squared L2 norm of vector of tokens. +template +double SqrL2(std::vector> const & tws) +{ + double sum = 0; + for (auto const & tw : tws) + sum += tw.Sqr(); + return sum; +} + +// Computes squared L2 norm of vector of tokens + prefix token. +template +double SqrL2(std::vector> const & tws, + boost::optional> const & prefix) +{ + double result = SqrL2(tws); + return result + (prefix ? prefix->Sqr() : 0); +} +} // namespace impl + +// This class represents a document in a vector space of tokens. 
+template +class DocVec +{ +public: + using TokenWeightPair = TokenWeightPair; + using TokenWeightPairs = std::vector; + + class Builder + { + public: + template + void Add(T && token, double weight) + { + m_tws.emplace_back(std::forward(token), weight); + } + + private: + friend class DocVec; + + std::vector m_tws; + }; + + DocVec() = default; + explicit DocVec(Builder && builder) : m_tws(std::move(builder.m_tws)) { Init(); } + explicit DocVec(Builder const & builder) : m_tws(builder.m_tws) { Init(); } + + TokenWeightPairs const & GetTokenWeightPairs() const { return m_tws; } + + bool Empty() const { return m_tws.empty(); } + +private: + template + friend std::string DebugPrint(DocVec const & dv) + { + return "DocVec " + DebugPrint(dv.m_tws); + } + + void Init() { impl::SortAndMerge(m_tws); } + + TokenWeightPairs m_tws; +}; + +// This class represents a search query in a vector space of tokens. +template +class QueryVec +{ +public: + using TokenWeightPair = TokenWeightPair; + + class Builder + { + public: + template + void AddFull(T && token, double weight) + { + m_tws.emplace_back(std::forward(token), weight); + } + + template + void SetPrefix(T && token, double weight) + { + m_prefix = TokenWeightPair(std::forward(token), weight); + } + + private: + friend class QueryVec; + + std::vector m_tws; + boost::optional m_prefix; + }; + + QueryVec() = default; + + explicit QueryVec(Builder && builder) + : m_tws(std::move(builder.m_tws)), m_prefix(std::move(builder.m_prefix)) + { + Init(); + } + + explicit QueryVec(Builder const & builder) : m_tws(builder.m_tws), m_prefix(builder.m_prefix) + { + Init(); + } + + // Computes cosine similarity between |*this| and |rhs|.
+ double Similarity(DocVec const & rhs) const + { + size_t kInvalidIndex = std::numeric_limits::max(); + + if (Empty() && rhs.Empty()) + return 1.0; + + if (Empty() || rhs.Empty()) + return 0.0; + + auto const & ls = m_tws; + auto const & rs = rhs.GetTokenWeightPairs(); + + ASSERT(std::is_sorted(ls.begin(), ls.end()), ()); + ASSERT(std::is_sorted(rs.begin(), rs.end()), ()); + + std::vector rsMatchTo(rs.size(), kInvalidIndex); + + size_t i = 0, j = 0; + double dot = 0; + while (i < ls.size() && j < rs.size()) + { + if (ls[i].m_token < rs[j].m_token) + { + ++i; + } + else if (ls[i].m_token > rs[j].m_token) + { + ++j; + } + else + { + dot += ls[i].m_weight * rs[j].m_weight; + rsMatchTo[j] = i; + ++i; + ++j; + } + } + + auto const ln = impl::SqrL2(ls, m_prefix); + auto const rn = impl::SqrL2(rs); + + // This similarity metric assumes that prefix is not matched in the document. + double const similarityNoPrefix = ln > 0 && rn > 0 ? dot / sqrt(ln) / sqrt(rn) : 0; + + if (!m_prefix) + return similarityNoPrefix; + + double similarityWithPrefix = 0; + + auto const & prefix = *m_prefix; + + // Let's try to match prefix token with all tokens in the + // document, and compute the best cosine similarity. + for (size_t j = 0; j < rs.size(); ++j) + { + auto const & tw = rs[j]; + if (!strings::StartsWith(tw.m_token.begin(), tw.m_token.end(), prefix.m_token.begin(), + prefix.m_token.end())) + { + continue; + } + + auto const i = rsMatchTo[j]; + + double nom = 0; + double denom = 0; + if (i == kInvalidIndex) + { + // If this document token is not matched with full tokens in a + // query, we need to update its weight in the cosine similarity + // - so we need to update correspondingly dot product and + // vector norms of query and doc. + + // This is the hacky moment: weight of query prefix token may + // be greater than the weight of the corresponding document + // token, because the weight of the document token may be + // unknown at the moment, and be set to some default value.
+ // But this heuristic works nicely in practice. + double const w = std::max(prefix.m_weight, tw.m_weight); + auto const sqrW = w * w; + double const l = std::max(0.0, ln - prefix.Sqr() + sqrW); + double const r = std::max(0.0, rn - tw.Sqr() + sqrW); + + nom = dot + sqrW; + denom = sqrt(l) * sqrt(r); + } + else + { + // If this document token is already matched with |i|-th full + // token in a query - we assume here that the completion of the prefix + // token is the |i|-th token. So we need to update + // correspondingly dot product and vector norm of the query. + double const l = ln + 2 * ls[i].m_weight * prefix.m_weight; + + nom = dot + prefix.m_weight * tw.m_weight; + denom = sqrt(l) * sqrt(rn); + } + + if (denom > 0) + similarityWithPrefix = std::max(similarityWithPrefix, nom / denom); + } + + return std::max(similarityWithPrefix, similarityNoPrefix); + } + + double Norm() const + { + double n = 0; + for (auto const & tw : m_tws) + n += tw.m_weight * tw.m_weight; + if (m_prefix) + n += m_prefix->m_weight * m_prefix->m_weight; + return sqrt(n); + } + + bool Empty() const { return m_tws.empty() && !m_prefix; } + +private: + template + friend std::string DebugPrint(QueryVec const & qv) + { + return "QueryVec " + DebugPrint(qv.m_tws); + } + + void Init() { impl::SortAndMerge(m_tws); } + + std::vector m_tws; + boost::optional m_prefix; +}; +} // namespace search diff --git a/search/geocoder.cpp b/search/geocoder.cpp index dcf4990ddc..c5d0162fe0 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -689,7 +689,7 @@ void Geocoder::FillLocalitiesTable(BaseContext const & ctx) #if defined(DEBUG) ft.GetName(StringUtf8Multilang::kDefaultCode, city.m_defaultName); - LOG(LINFO, ("City =", city.m_defaultName, "radius =", radius, "prob =", city.m_prob)); + LOG(LINFO, ("City =", city.m_defaultName, "radius =", radius)); #endif m_cities[city.m_tokenRange].push_back(city); @@ -738,7 +738,7 @@ void Geocoder::FillVillageLocalities(BaseContext const & ctx) #if defined(DEBUG)
ft.GetName(StringUtf8Multilang::kDefaultCode, village.m_defaultName); - LOG(LDEBUG, ("Village =", village.m_defaultName, "radius =", radius, "prob =", village.m_prob)); + LOG(LDEBUG, ("Village =", village.m_defaultName, "radius =", radius)); #endif m_cities[village.m_tokenRange].push_back(village); diff --git a/search/geocoder_locality.cpp b/search/geocoder_locality.cpp index ec3df798f4..2cf619eff0 100644 --- a/search/geocoder_locality.cpp +++ b/search/geocoder_locality.cpp @@ -22,7 +22,6 @@ std::string DebugPrint(Locality const & locality) os << "m_countryId=" << DebugPrint(locality.m_countryId) << ", "; os << "m_featureId=" << locality.m_featureId << ", "; os << "m_tokenRange=" << DebugPrint(locality.m_tokenRange) << ", "; - os << "m_prob=" << locality.m_prob; os << " ]"; return os.str(); } diff --git a/search/geocoder_locality.hpp b/search/geocoder_locality.hpp index 850a465d29..42fb86f887 100644 --- a/search/geocoder_locality.hpp +++ b/search/geocoder_locality.hpp @@ -1,5 +1,6 @@ #pragma once +#include "search/doc_vec.hpp" #include "search/model.hpp" #include "search/token_range.hpp" @@ -17,22 +18,22 @@ namespace search { struct Locality { + using QueryVec = QueryVec; + Locality() = default; Locality(MwmSet::MwmId const & countryId, uint32_t featureId, TokenRange const & tokenRange, - double prob) - : m_countryId(countryId), m_featureId(featureId), m_tokenRange(tokenRange), m_prob(prob) + QueryVec const & queryVec) + : m_countryId(countryId), m_featureId(featureId), m_tokenRange(tokenRange), m_queryVec(queryVec) { } + double QueryNorm() const { return m_queryVec.Norm(); } + MwmSet::MwmId m_countryId; uint32_t m_featureId = 0; TokenRange m_tokenRange; - - // Measures our belief in the fact that tokens in the range - // [m_startToken, m_endToken) indeed specify a locality. Currently - // it is set only for villages. - double m_prob = 0.0; + QueryVec m_queryVec; }; // This struct represents a country or US- or Canadian- state. 
It diff --git a/search/idf_map.cpp b/search/idf_map.cpp new file mode 100644 index 0000000000..28798d20bd --- /dev/null +++ b/search/idf_map.cpp @@ -0,0 +1,12 @@ +#include "search/idf_map.hpp" + +namespace search +{ +IdfMap::IdfMap(double unknownIdf): m_unknownIdf(unknownIdf) {} + +double IdfMap::Get(strings::UniString const & s) const +{ + auto const it = m_idfs.find(s); + return it == m_idfs.cend() ? m_unknownIdf : it->second; +} +} // namespace search diff --git a/search/idf_map.hpp b/search/idf_map.hpp new file mode 100644 index 0000000000..aa3709293f --- /dev/null +++ b/search/idf_map.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "base/string_utils.hpp" + +#include + +namespace search +{ +class IdfMap +{ +public: + explicit IdfMap(double unknownIdf); + + void Set(strings::UniString const & s, double idf) { m_idfs[s] = idf; } + double Get(strings::UniString const & s) const; + +private: + std::map m_idfs; + double m_unknownIdf; +}; +} // namespace search diff --git a/search/locality_scorer.cpp b/search/locality_scorer.cpp index 0cfe084f35..1046da4de6 100644 --- a/search/locality_scorer.cpp +++ b/search/locality_scorer.cpp @@ -2,37 +2,29 @@ #include "search/cbv.hpp" #include "search/geocoder_context.hpp" +#include "search/idf_map.hpp" #include "search/token_slice.hpp" +#include "indexer/search_string_utils.hpp" + #include "base/checked_cast.hpp" +#include "base/stl_helpers.hpp" #include #include #include +#include + +using namespace std; namespace search { // static size_t const LocalityScorer::kDefaultReadLimit = 100; -namespace -{ -bool IsAlmostFullMatch(NameScore score) -{ - return score == NAME_SCORE_PREFIX || score == NAME_SCORE_FULL_MATCH; -} -} // namespace - // LocalityScorer::ExLocality ---------------------------------------------------------------------- -LocalityScorer::ExLocality::ExLocality() : m_numTokens(0), m_rank(0), m_nameScore(NAME_SCORE_ZERO) -{ -} - -LocalityScorer::ExLocality::ExLocality(Locality const & locality) - : m_locality(locality) 
- , m_numTokens(locality.m_tokenRange.Size()) - , m_rank(0) - , m_nameScore(NAME_SCORE_ZERO) +LocalityScorer::ExLocality::ExLocality(Locality const & locality, double queryNorm, uint8_t rank) + : m_locality(locality), m_queryNorm(queryNorm), m_rank(rank) { } @@ -44,152 +36,189 @@ LocalityScorer::LocalityScorer(QueryParams const & params, Delegate const & dele void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseContext const & ctx, CBV const & filter, size_t limit, - std::vector & localities) + vector & localities) { CHECK_EQUAL(ctx.m_numTokens, m_params.GetNumTokens(), ()); localities.clear(); + vector intersections(ctx.m_numTokens); + for (size_t i = 0; i < ctx.m_numTokens; ++i) + intersections[i] = filter.Intersect(ctx.m_features[i]); + + IdfMap idfs(1.0 /* unknownIdf */); + for (size_t i = 0; i < ctx.m_numTokens; ++i) + { + auto const idf = 1.0 / static_cast(intersections[i].PopCount()); + // IDF should be the same for the token and its synonyms. + m_params.GetToken(i).ForEach([&idfs, &idf](strings::UniString const & s) { idfs.Set(s, idf); }); + } + for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken) { - CBV intersection = filter.Intersect(ctx.m_features[startToken]); - if (intersection.IsEmpty()) - continue; + auto intersection = intersections[startToken]; + QueryVec::Builder builder; - CBV unfilteredIntersection = ctx.m_features[startToken]; - - for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens; ++endToken) + for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens && !intersection.IsEmpty(); + ++endToken) { + auto const curToken = endToken - 1; + auto const & token = m_params.GetToken(curToken).m_original; + double const weight = idfs.Get(token); + if (m_params.IsPrefixToken(curToken)) + builder.SetPrefix(token, weight); + else + builder.AddFull(token, weight); + TokenRange const tokenRange(startToken, endToken); // Skip locality candidates that match only numbers. 
if (!m_params.IsNumberTokens(tokenRange)) { intersection.ForEach([&](uint64_t bit) { auto const featureId = base::asserted_cast(bit); - double const prob = static_cast(intersection.PopCount()) / - static_cast(unfilteredIntersection.PopCount()); - localities.emplace_back(countryId, featureId, tokenRange, prob); + localities.emplace_back(countryId, featureId, tokenRange, QueryVec(builder)); }); } if (endToken < ctx.m_numTokens) - { - intersection = intersection.Intersect(ctx.m_features[endToken]); - if (intersection.IsEmpty()) - break; - - unfilteredIntersection = unfilteredIntersection.Intersect(ctx.m_features[endToken]); - } + intersection = intersection.Intersect(intersections[endToken]); } } - LeaveTopLocalities(limit, localities); + LeaveTopLocalities(idfs, limit, localities); } -void LocalityScorer::LeaveTopLocalities(size_t limit, std::vector & localities) const +void LocalityScorer::LeaveTopLocalities(IdfMap const & idfs, size_t limit, + vector & localities) const { - std::vector ls; - ls.reserve(localities.size()); + vector els; + els.reserve(localities.size()); for (auto const & locality : localities) - ls.emplace_back(locality); + { + auto const queryNorm = locality.m_queryVec.Norm(); + auto const rank = m_delegate.GetRank(locality.m_featureId); + els.emplace_back(locality, queryNorm, rank); + } - RemoveDuplicates(ls); - LeaveTopByRankAndProb(std::max(limit, kDefaultReadLimit), ls); - SortByNameAndProb(ls); - if (ls.size() > limit) - ls.resize(limit); + // We don't want to read too many names for localities, so this is + // the best effort - select best features by available params - + // query norm and rank.
+ LeaveTopByNormAndRank(max(limit, kDefaultReadLimit) /* limitUniqueIds */, els); + + sort(els.begin(), els.end(), + [](ExLocality const & lhs, ExLocality const & rhs) { return lhs.GetId() < rhs.GetId(); }); + + size_t i = 0; + while (i < els.size()) + { + size_t j = i + 1; + while (j < els.size() && els[j].GetId() == els[i].GetId()) + ++j; + + vector dvs; + + // *NOTE* |idfs| is filled based on query tokens, not all + // localities tokens, because it's expensive to compute IDF map + // for all localities tokens. Therefore, for tokens not in the + // query, some default IDF value will be used. + GetDocVecs(idfs, els[i].GetId(), dvs); + for (; i < j; ++i) + els[i].m_similarity = GetSimilarity(els[i].m_locality.m_queryVec, dvs); + } + + LeaveTopBySimilarityAndRank(limit, els); + ASSERT_LESS_OR_EQUAL(els.size(), limit, ()); localities.clear(); - localities.reserve(ls.size()); - for (auto const & l : ls) - localities.push_back(l.m_locality); + localities.reserve(els.size()); + for (auto const & el : els) + localities.push_back(el.m_locality); + ASSERT_LESS_OR_EQUAL(localities.size(), limit, ()); } -void LocalityScorer::RemoveDuplicates(std::vector & ls) const +void LocalityScorer::LeaveTopByNormAndRank(size_t limitUniqueIds, vector & els) const { - std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) { - if (lhs.GetId() != rhs.GetId()) - return lhs.GetId() < rhs.GetId(); - return lhs.m_numTokens > rhs.m_numTokens; - }); - ls.erase(std::unique(ls.begin(), ls.end(), - [](ExLocality const & lhs, ExLocality const & rhs) { - return lhs.GetId() == rhs.GetId(); - }), - ls.end()); -} - -void LocalityScorer::LeaveTopByRankAndProb(size_t limit, std::vector & ls) const -{ - if (ls.size() <= limit) - return; - - for (auto & l : ls) - l.m_rank = m_delegate.GetRank(l.GetId()); - - std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) { - if (lhs.m_locality.m_prob != rhs.m_locality.m_prob) - return lhs.m_locality.m_prob > 
rhs.m_locality.m_prob; - if (lhs.m_rank != rhs.m_rank) - return lhs.m_rank > rhs.m_rank; - return lhs.m_numTokens > rhs.m_numTokens; - }); - ls.resize(limit); -} - -void LocalityScorer::SortByNameAndProb(std::vector & ls) const -{ - std::vector names; - for (auto & l : ls) - { - names.clear(); - m_delegate.GetNames(l.GetId(), names); - - auto score = NAME_SCORE_ZERO; - for (auto const & name : names) - score = max(score, GetNameScore(name, TokenSlice(m_params, l.m_locality.m_tokenRange))); - l.m_nameScore = score; - } - - std::sort(ls.begin(), ls.end(), [](ExLocality const & lhs, ExLocality const & rhs) { - // Probabilities form a stronger signal than name scores do. - if (lhs.m_locality.m_prob != rhs.m_locality.m_prob) - return lhs.m_locality.m_prob > rhs.m_locality.m_prob; - if (IsAlmostFullMatch(lhs.m_nameScore) && IsAlmostFullMatch(rhs.m_nameScore)) - { - // When both localities match well, e.g. full or full prefix - // match, the one with larger number of tokens is selected. In - // case of tie, the one with better score is selected. - if (lhs.m_numTokens != rhs.m_numTokens) - return lhs.m_numTokens > rhs.m_numTokens; - if (lhs.m_nameScore != rhs.m_nameScore) - return lhs.m_nameScore > rhs.m_nameScore; - } - else - { - // When name scores differ, the one with better name score is - // selected. In case of tie, the one with larger number of - // matched tokens is selected. - if (lhs.m_nameScore != rhs.m_nameScore) - return lhs.m_nameScore > rhs.m_nameScore; - if (lhs.m_numTokens != rhs.m_numTokens) - return lhs.m_numTokens > rhs.m_numTokens; - } - - // Okay, in case of tie we select the one with better rank. This - // is a quite arbitrary decision and definitely may be improved. 
+ sort(els.begin(), els.end(), [](ExLocality const & lhs, ExLocality const & rhs) { + auto const ln = lhs.m_queryNorm; + auto const rn = rhs.m_queryNorm; + if (ln != rn) + return ln > rn; return lhs.m_rank > rhs.m_rank; }); + + unordered_set seen; + for (size_t i = 0; i < els.size() && seen.size() < limitUniqueIds; ++i) + seen.insert(els[i].GetId()); + ASSERT_LESS_OR_EQUAL(seen.size(), limitUniqueIds, ()); + + my::EraseIf(els, [&](ExLocality const & el) { return seen.find(el.GetId()) == seen.cend(); }); } -string DebugPrint(LocalityScorer::ExLocality const & locality) +void LocalityScorer::LeaveTopBySimilarityAndRank(size_t limit, vector & els) const +{ + sort(els.begin(), els.end(), [](ExLocality const & lhs, ExLocality const & rhs) { + if (lhs.m_similarity != rhs.m_similarity) + return lhs.m_similarity > rhs.m_similarity; + if (lhs.m_rank != rhs.m_rank) + return lhs.m_rank > rhs.m_rank; + return lhs.m_locality.m_featureId < rhs.m_locality.m_featureId; + }); + + unordered_set seen; + + size_t n = 0; + for (size_t i = 0; i < els.size() && n < limit; ++i) + { + auto const id = els[i].GetId(); + if (seen.insert(id).second) + { + els[n] = els[i]; + ++n; + } + } + els.resize(n); +} + +void LocalityScorer::GetDocVecs(IdfMap const & idfs, uint32_t localityId, + vector & dvs) const +{ + vector names; + m_delegate.GetNames(localityId, names); + + for (auto const & name : names) + { + vector tokens; + NormalizeAndTokenizeString(name, tokens); + + DocVec::Builder builder; + for (auto const & token : tokens) + builder.Add(token, idfs.Get(token) /* weight */); + dvs.emplace_back(move(builder)); + } +} + +double LocalityScorer::GetSimilarity(QueryVec const & qv, vector const & dvc) const +{ + double const kScale = 1e6; + + double similarity = 0; + for (auto const & dv : dvc) + similarity = max(similarity, qv.Similarity(dv)); + + // We need scale here to prevent double artifacts, and to make + // sorting by similarity more robust, as 1e-6 is good enough for our + // purposes. 
+ return round(similarity * kScale); +} + +string DebugPrint(LocalityScorer::ExLocality const & el) { ostringstream os; os << "LocalityScorer::ExLocality [ "; - os << "m_locality=" << DebugPrint(locality.m_locality) << ", "; - os << "m_numTokens=" << locality.m_numTokens << ", "; - os << "m_rank=" << static_cast(locality.m_rank) << ", "; - os << "m_nameScore=" << DebugPrint(locality.m_nameScore); + os << "m_locality=" << DebugPrint(el.m_locality) << ", "; + os << "m_queryNorm=" << el.m_queryNorm << ", "; + os << "m_similarity=" << el.m_similarity << ", "; + os << "m_rank=" << static_cast(el.m_rank); os << " ]"; return os.str(); } diff --git a/search/locality_scorer.hpp b/search/locality_scorer.hpp index 6919b3dc06..d68a0d93d6 100644 --- a/search/locality_scorer.hpp +++ b/search/locality_scorer.hpp @@ -3,6 +3,9 @@ #include "search/geocoder_locality.hpp" #include "search/ranking_utils.hpp" +#include "base/string_utils.hpp" + +#include #include #include #include @@ -10,6 +13,7 @@ namespace search { class CBV; +class IdfMap; class QueryParams; struct BaseContext; @@ -35,31 +39,42 @@ public: CBV const & filter, size_t limit, std::vector & localities); private: + using DocVec = DocVec; + using QueryVec = Locality::QueryVec; + struct ExLocality { - ExLocality(); - explicit ExLocality(Locality const & locality); + ExLocality() = default; + ExLocality(Locality const & locality, double queryNorm, uint8_t rank); - inline uint32_t GetId() const { return m_locality.m_featureId; } + uint32_t GetId() const { return m_locality.m_featureId; } Locality m_locality; - size_t m_numTokens; - uint8_t m_rank; - NameScore m_nameScore; + double m_queryNorm = 0.0; + double m_similarity = 0.0; + uint8_t m_rank = 0; }; friend std::string DebugPrint(ExLocality const & locality); // Leaves at most |limit| elements of |localities|, ordered by some // combination of ranks and number of matched tokens. 
- void LeaveTopLocalities(size_t limit, std::vector & localities) const; + void LeaveTopLocalities(IdfMap const & idfs, size_t limit, + std::vector & localities) const; - void RemoveDuplicates(std::vector & ls) const; - void LeaveTopByRankAndProb(size_t limit, std::vector & ls) const; - void SortByNameAndProb(std::vector & ls) const; + // Selects at most |limitUniqueIds| best features by query norm and + // rank, and then leaves only localities corresponding to those + // features in |els|. + void LeaveTopByNormAndRank(size_t limitUniqueIds, std::vector & els) const; + + // Leaves at most |limit| best localities by similarity to the query + // and rank. Result doesn't contain duplicate features. + void LeaveTopBySimilarityAndRank(size_t limit, std::vector & els) const; + + void GetDocVecs(IdfMap const & idfs, uint32_t localityId, vector & dvs) const; + double GetSimilarity(QueryVec const & qv, std::vector const & dvs) const; QueryParams const & m_params; Delegate const & m_delegate; }; - } // namespace search diff --git a/search/search.pro b/search/search.pro index 103ab8ad79..94666b9bf1 100644 --- a/search/search.pro +++ b/search/search.pro @@ -21,6 +21,7 @@ HEADERS += \ city_finder.hpp \ common.hpp \ displayed_categories.hpp \ + doc_vec.hpp \ downloader_search_callback.hpp \ dummy_rank_table.hpp \ editor_delegate.hpp \ @@ -44,6 +45,7 @@ HEADERS += \ house_detector.hpp \ house_numbers_matcher.hpp \ house_to_street_table.hpp \ + idf_map.hpp \ intermediate_result.hpp \ intersection_result.hpp \ interval_set.hpp \ @@ -115,6 +117,7 @@ SOURCES += \ house_detector.cpp \ house_numbers_matcher.cpp \ house_to_street_table.cpp \ + idf_map.cpp \ intermediate_result.cpp \ intersection_result.cpp \ keyword_lang_matcher.cpp \ diff --git a/search/search_tests/locality_scorer_test.cpp b/search/search_tests/locality_scorer_test.cpp index 8d04504d66..f97ddd9f6b 100644 --- a/search/search_tests/locality_scorer_test.cpp +++ b/search/search_tests/locality_scorer_test.cpp @@ -28,6 
+28,8 @@ namespace class LocalityScorerTest : public LocalityScorer::Delegate { public: + using Ids = vector; + LocalityScorerTest() : m_scorer(m_params, static_cast(*this)) {} void InitParams(string const & query, bool lastTokenIsPrefix) @@ -61,7 +63,7 @@ public: m_names[featureId].push_back(name); } - void GetTopLocalities(size_t limit) + Ids GetTopLocalities(size_t limit) { BaseContext ctx; ctx.m_tokens.assign(m_params.GetNumTokens(), BaseContext::TOKEN_TYPE_COUNT); @@ -94,8 +96,14 @@ public: CBV filter; filter.SetFull(); - m_scorer.GetTopLocalities(MwmSet::MwmId(), ctx, filter, limit, m_localities); - sort(m_localities.begin(), m_localities.end(), my::LessBy(&Locality::m_featureId)); + vector localities; + m_scorer.GetTopLocalities(MwmSet::MwmId(), ctx, filter, limit, localities); + sort(localities.begin(), localities.end(), my::LessBy(&Locality::m_featureId)); + + Ids ids; + for (auto const & locality : localities) + ids.push_back(locality.m_featureId); + return ids; } // LocalityScorer::Delegate overrides: @@ -110,7 +118,6 @@ public: protected: QueryParams m_params; - vector m_localities; unordered_map> m_names; LocalityScorer m_scorer; @@ -133,16 +140,9 @@ UNIT_CLASS_TEST(LocalityScorerTest, Smoke) AddLocality("York", ID_YORK); AddLocality("New York", ID_NEW_YORK); - GetTopLocalities(100 /* limit */); - TEST_EQUAL(3, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_NEW_ORLEANS, ()); - TEST_EQUAL(m_localities[1].m_featureId, ID_YORK, ()); - TEST_EQUAL(m_localities[2].m_featureId, ID_NEW_YORK, ()); - - // New York is the best matching locality - GetTopLocalities(1 /* limit */); - TEST_EQUAL(1, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_NEW_YORK, ()); + TEST_EQUAL(GetTopLocalities(100 /* limit */), Ids({ID_NEW_ORLEANS, ID_YORK, ID_NEW_YORK}), ()); + TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_YORK, ID_NEW_YORK}), ()); + TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_YORK}), ()); } 
UNIT_CLASS_TEST(LocalityScorerTest, NumbersMatch) @@ -164,13 +164,8 @@ UNIT_CLASS_TEST(LocalityScorerTest, NumbersMatch) // Tver is the only matched locality as other localities were // matched only by number. - GetTopLocalities(100 /* limit */); - TEST_EQUAL(1, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_TVER, ()); - - GetTopLocalities(1 /* limit */); - TEST_EQUAL(1, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_TVER, ()); + TEST_EQUAL(GetTopLocalities(100 /* limit */), Ids({ID_TVER}), ()); + TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_TVER}), ()); } UNIT_CLASS_TEST(LocalityScorerTest, NumbersComplexMatch) @@ -186,12 +181,8 @@ UNIT_CLASS_TEST(LocalityScorerTest, NumbersComplexMatch) AddLocality("may 1", ID_MAY); AddLocality("saint petersburg", ID_SAINT_PETERSBURG); - // "May 1" contains a numeric token, but as it was matched by at - // least two tokens, there is no penalty for numeric token. And, as - // it has smaller featureId, it should be left. - GetTopLocalities(1 /* limit */); - TEST_EQUAL(1, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_MAY, ()); + TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_MAY, ID_SAINT_PETERSBURG}), ()); + TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_MAY}), ()); } UNIT_CLASS_TEST(LocalityScorerTest, PrefixMatch) @@ -212,24 +203,7 @@ UNIT_CLASS_TEST(LocalityScorerTest, PrefixMatch) AddLocality("Moscow", ID_MOSCOW); // All localities except Moscow match to the search query. - GetTopLocalities(100 /* limit */); - TEST_EQUAL(3, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_SAN_ANTONIO, ()); - TEST_EQUAL(m_localities[1].m_featureId, ID_NEW_YORK, ()); - TEST_EQUAL(m_localities[2].m_featureId, ID_YORK, ()); - - // New York and San Antonio are better than York, because they match - // by two tokens (second token is prefix for San Antonio), whereas - // York matches by only one token. 
- GetTopLocalities(2 /* limit */); - TEST_EQUAL(2, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_SAN_ANTONIO, ()); - TEST_EQUAL(m_localities[1].m_featureId, ID_NEW_YORK, ()); - - // New York is a better than San Antonio because it matches by two - // full tokens whereas San Antonio matches by one full token and by - // one prefix token. - GetTopLocalities(1 /* limit */); - TEST_EQUAL(1, m_localities.size(), ()); - TEST_EQUAL(m_localities[0].m_featureId, ID_NEW_YORK, ()); + TEST_EQUAL(GetTopLocalities(100 /* limit */), Ids({ID_SAN_ANTONIO, ID_NEW_YORK, ID_YORK}), ()); + TEST_EQUAL(GetTopLocalities(2 /* limit */), Ids({ID_SAN_ANTONIO, ID_NEW_YORK}), ()); + TEST_EQUAL(GetTopLocalities(1 /* limit */), Ids({ID_SAN_ANTONIO}), ()); }