diff --git a/search/categories_cache.cpp b/search/categories_cache.cpp index ab5d7a945e..7d0f4c52dc 100644 --- a/search/categories_cache.cpp +++ b/search/categories_cache.cpp @@ -52,7 +52,7 @@ CBV CategoriesCache::Load(MwmContext const & context) const }); Retrieval retrieval(context, m_cancellable); - return CBV(retrieval.RetrieveAddressFeatures(request).m_features); + return retrieval.RetrieveAddressFeatures(request).m_features; } // StreetsCache ------------------------------------------------------------------------------------ diff --git a/search/geocoder.cpp b/search/geocoder.cpp index 2363547e01..79a62af416 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -579,15 +579,15 @@ void Geocoder::InitBaseContext(BaseContext & ctx) // Implementation-wise, the simplest way to match a feature by // its category bypassing the matching by name is by using a CategoriesCache. CategoriesCache cache(m_params.m_preferredTypes, m_cancellable); - ctx.m_features[i] = cache.Get(*m_context); + ctx.m_features[i] = Retrieval::ExtendedFeatures(cache.Get(*m_context)); } else if (m_params.IsPrefixToken(i)) { - ctx.m_features[i] = retrieval.RetrieveAddressFeatures(m_prefixTokenRequest).m_features; + ctx.m_features[i] = retrieval.RetrieveAddressFeatures(m_prefixTokenRequest); } else { - ctx.m_features[i] = retrieval.RetrieveAddressFeatures(m_tokenRequests[i]).m_features; + ctx.m_features[i] = retrieval.RetrieveAddressFeatures(m_tokenRequests[i]); } } @@ -799,17 +799,18 @@ void Geocoder::MatchCategories(BaseContext & ctx, bool aroundPivot) { auto const pivotFeatures = RetrieveGeometryFeatures(*m_context, m_params.m_pivot, RECT_ID_PIVOT); ViewportFilter filter(pivotFeatures, m_preRanker.Limit() /* threshold */); - features = filter.Filter(features); + features.m_features = filter.Filter(features.m_features); + features.m_exactMatchingFeatures = + features.m_exactMatchingFeatures.Intersect(features.m_features); } - auto emit = [&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); + auto emit = [&](uint32_t featureId, bool exactMatch) { Model::Type type; if (!GetTypeInGeocoding(ctx, featureId, type)) return; - EmitResult(ctx, m_context->GetId(), featureId, type, TokenRange(0, ctx.m_numTokens), nullptr /* geoParts */, - true /* allTokensUsed */); + EmitResult(ctx, m_context->GetId(), featureId, type, TokenRange(0, ctx.m_numTokens), + nullptr /* geoParts */, true /* allTokensUsed */, exactMatch); }; // Features have been retrieved from the search index @@ -879,8 +880,15 @@ void Geocoder::MatchRegions(BaseContext & ctx, Region::Type type) if (ctx.AllTokensUsed()) { + bool exactMatch = true; + for (auto const & region : ctx.m_regions) + { + if (!region->m_exactMatch) + exactMatch = false; + } + // Region matches to search query, we need to emit it as is. - EmitResult(ctx, region, tokenRange, true /* allTokensUsed */); + EmitResult(ctx, region, tokenRange, true /* allTokensUsed */, exactMatch); continue; } @@ -924,7 +932,7 @@ void Geocoder::MatchCities(BaseContext & ctx) if (ctx.AllTokensUsed()) { // City matches to search query, we need to emit it as is. - EmitResult(ctx, city, tokenRange, true /* allTokensUsed */); + EmitResult(ctx, city, tokenRange, true /* allTokensUsed */, city.m_exactMatch); continue; } @@ -1085,7 +1093,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) if (GetTypeInGeocoding(ctx, featureId, type)) { EmitResult(ctx, m_context->GetId(), featureId, type, m_postcodes.m_tokenRange, - nullptr /* geoParts */, true /* allTokensUsed */); + nullptr /* geoParts */, true /* allTokensUsed */, true /* exactMatch */); } }); return; @@ -1107,7 +1115,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) if (!m_postcodes.m_features.HasBit(id)) continue; EmitResult(ctx, m_context->GetId(), id, Model::TYPE_STREET, layers.back().m_tokenRange, - nullptr /* geoParts */, true /* allTokensUsed */); + nullptr /* geoParts */, true /* allTokensUsed */, true /* exactMatch */); } } @@ -1153,7 +1161,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) } }; - CBV features; + Retrieval::ExtendedFeatures features; features.SetFull(); // Try to consume [curToken, m_numTokens) tokens range. @@ -1172,9 +1180,9 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) features = features.Intersect(ctx.m_features[curToken + n - 1]); - CBV filtered = features; - if (m_filter->NeedToFilter(features)) - filtered = m_filter->Filter(features); + CBV filtered = features.m_features; + if (m_filter->NeedToFilter(features.m_features)) + filtered = m_filter->Filter(features.m_features); bool const looksLikeHouseNumber = house_numbers::LooksLikeHouseNumber( layers.back().m_subQuery, layers.back().m_lastTokenIsPrefix); @@ -1307,12 +1315,49 @@ void Geocoder::FindPaths(BaseContext & ctx) else m_matcher->SetPostcodes(nullptr); - m_finder.ForEachReachableVertex( - *m_matcher, sortedLayers, [this, &ctx, &innermostLayer](IntersectionResult const & result) { - ASSERT(result.IsValid(), ()); - EmitResult(ctx, m_context->GetId(), result.InnermostResult(), innermostLayer.m_type, - innermostLayer.m_tokenRange, &result, ctx.AllTokensUsed()); - }); + auto isExactMatch = [](BaseContext const & context, IntersectionResult const & result) { + bool regionsChecked = false; + for (size_t i = 0; i < context.m_tokens.size(); ++i) + { + auto const tokenType = context.m_tokens[i]; + auto id = IntersectionResult::kInvalidId; + + if (tokenType == BaseContext::TokenType::TOKEN_TYPE_POI) + id = result.m_poi; + if (tokenType == BaseContext::TokenType::TOKEN_TYPE_STREET) + id = result.m_street; + + if (id != IntersectionResult::kInvalidId && context.m_features[i].m_features.HasBit(id) && + !context.m_features[i].m_exactMatchingFeatures.HasBit(id)) + { + return false; + } + + auto const isCityOrVillage = tokenType == BaseContext::TokenType::TOKEN_TYPE_CITY || + tokenType == BaseContext::TokenType::TOKEN_TYPE_VILLAGE; + if (isCityOrVillage && context.m_city && !context.m_city->m_exactMatch) + return false; + + auto const isRegion = tokenType == BaseContext::TokenType::TOKEN_TYPE_STATE || + tokenType == BaseContext::TokenType::TOKEN_TYPE_COUNTRY; + if (isRegion && !regionsChecked) + { + for (auto const & region : context.m_regions) + { + if (!region->m_exactMatch) + return false; + } + } + } + return true; + }; + + m_finder.ForEachReachableVertex(*m_matcher, sortedLayers, [&](IntersectionResult const & result) { + ASSERT(result.IsValid(), ()); + EmitResult(ctx, m_context->GetId(), result.InnermostResult(), innermostLayer.m_type, + innermostLayer.m_tokenRange, &result, ctx.AllTokensUsed(), + isExactMatch(ctx, result)); + }); } void Geocoder::TraceResult(Tracer & tracer, BaseContext const & ctx, MwmSet::MwmId const & mwmId, @@ -1340,7 +1385,7 @@ void Geocoder::TraceResult(Tracer & tracer, BaseContext const & ctx, MwmSet::Mwm void Geocoder::EmitResult(BaseContext & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId, Model::Type type, TokenRange const & tokenRange, - IntersectionResult const * geoParts, bool allTokensUsed) + IntersectionResult const * geoParts, bool allTokensUsed, bool exactMatch) { FeatureID id(mwmId, ftId); @@ -1380,6 +1425,7 @@ void Geocoder::EmitResult(BaseContext & ctx, MwmSet::MwmId const & mwmId, uint32 info.m_geoParts = *geoParts; info.m_allTokensUsed = allTokensUsed; + info.m_exactMatch = exactMatch; m_preRanker.Emplace(id, info, m_resultTracer.GetProvenance()); @@ -1387,18 +1433,18 @@ void Geocoder::EmitResult(BaseContext & ctx, MwmSet::MwmId const & mwmId, uint32 } void Geocoder::EmitResult(BaseContext & ctx, Region const & region, TokenRange const & tokenRange, - bool allTokensUsed) + bool allTokensUsed, bool exactMatch) { auto const type = Region::ToModelType(region.m_type); EmitResult(ctx, region.m_countryId, region.m_featureId, type, tokenRange, nullptr /* geoParts */, - allTokensUsed); + allTokensUsed, exactMatch); } void Geocoder::EmitResult(BaseContext & ctx, City const & city, TokenRange const & tokenRange, - bool allTokensUsed) + bool allTokensUsed, bool exactMatch) { EmitResult(ctx, city.m_countryId, city.m_featureId, city.m_type, tokenRange, - nullptr /* geoParts */, allTokensUsed); + nullptr /* geoParts */, allTokensUsed, exactMatch); } void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) @@ -1417,7 +1463,7 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) if (ctx.NumUnusedTokenGroups() != 1) return; - CBV allFeatures; + Retrieval::ExtendedFeatures allFeatures; allFeatures.SetFull(); auto startToken = curToken; @@ -1427,19 +1473,20 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) allFeatures = allFeatures.Intersect(ctx.m_features[curToken]); } - if (m_filter->NeedToFilter(allFeatures)) - allFeatures = m_filter->Filter(allFeatures); - - auto emitUnclassified = [&](uint64_t bit) + if (m_filter->NeedToFilter(allFeatures.m_features)) { - auto const featureId = base::asserted_cast(bit); + allFeatures.m_features = m_filter->Filter(allFeatures.m_features); + allFeatures.m_exactMatchingFeatures = m_filter->Filter(allFeatures.m_exactMatchingFeatures); + } + + auto emitUnclassified = [&](uint32_t featureId, bool exactMatch) { Model::Type type; if (!GetTypeInGeocoding(ctx, featureId, type)) return; if (type == Model::TYPE_UNCLASSIFIED) { EmitResult(ctx, m_context->GetId(), featureId, type, TokenRange(startToken, curToken), - nullptr /* geoParts */, true /* allTokensUsed */); + nullptr /* geoParts */, true /* allTokensUsed */, exactMatch); } }; allFeatures.ForEach(emitUnclassified); diff --git a/search/geocoder.hpp b/search/geocoder.hpp index 6cb0db9a68..6c866702c4 100644 --- a/search/geocoder.hpp +++ b/search/geocoder.hpp @@ -225,11 +225,11 @@ private: // Forms result and feeds it to |m_preRanker|. void EmitResult(BaseContext & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId, Model::Type type, TokenRange const & tokenRange, IntersectionResult const * geoParts, - bool allTokensUsed); + bool allTokensUsed, bool exactMatch); void EmitResult(BaseContext & ctx, Region const & region, TokenRange const & tokenRange, - bool allTokensUsed); + bool allTokensUsed, bool exactMatch); void EmitResult(BaseContext & ctx, City const & city, TokenRange const & tokenRange, - bool allTokensUsed); + bool allTokensUsed, bool exactMatch); // Tries to match unclassified objects from lower layers, like // parks, forests, lakes, rivers, etc. This method finds all diff --git a/search/geocoder_context.hpp b/search/geocoder_context.hpp index a72d896ece..401ebcf866 100644 --- a/search/geocoder_context.hpp +++ b/search/geocoder_context.hpp @@ -6,6 +6,7 @@ #include "search/geocoder_locality.hpp" #include "search/hotels_filter.hpp" #include "search/model.hpp" +#include "search/retrieval.hpp" #include #include @@ -55,7 +56,7 @@ struct BaseContext // List of bit-vectors of features, where i-th element of the list // corresponds to the i-th token in the search query. - std::vector m_features; + std::vector m_features; CBV m_villages; CBV m_streets; diff --git a/search/geocoder_locality.hpp b/search/geocoder_locality.hpp index 7e8a6dd190..10c8401b96 100644 --- a/search/geocoder_locality.hpp +++ b/search/geocoder_locality.hpp @@ -21,8 +21,12 @@ class IdfMap; struct Locality { Locality(MwmSet::MwmId const & countryId, uint32_t featureId, TokenRange const & tokenRange, - QueryVec const & queryVec) - : m_countryId(countryId), m_featureId(featureId), m_tokenRange(tokenRange), m_queryVec(queryVec) + QueryVec const & queryVec, bool exactMatch) + : m_countryId(countryId) + , m_featureId(featureId) + , m_tokenRange(tokenRange) + , m_queryVec(queryVec) + , m_exactMatch(exactMatch) { } @@ -32,6 +36,7 @@ struct Locality uint32_t m_featureId = 0; TokenRange m_tokenRange; QueryVec m_queryVec; + bool m_exactMatch; }; // This struct represents a country or US- or Canadian- state. It diff --git a/search/intermediate_result.hpp b/search/intermediate_result.hpp index 6a865354f1..d3a094e8f2 100644 --- a/search/intermediate_result.hpp +++ b/search/intermediate_result.hpp @@ -15,7 +15,6 @@ #include class FeatureType; -class CategoriesHolder; namespace storage { diff --git a/search/locality_scorer.cpp b/search/locality_scorer.cpp index 95f787b6fb..409c85f062 100644 --- a/search/locality_scorer.cpp +++ b/search/locality_scorer.cpp @@ -3,6 +3,7 @@ #include "search/cbv.hpp" #include "search/geocoder_context.hpp" #include "search/idf_map.hpp" +#include "search/retrieval.hpp" #include "search/token_slice.hpp" #include "search/utils.hpp" @@ -91,15 +92,15 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte localities.clear(); - vector intersections(ctx.m_numTokens); + vector intersections(ctx.m_numTokens); vector> tokensToDf; vector, uint64_t>> prefixToDf; bool const havePrefix = ctx.m_numTokens > 0 && m_params.LastTokenIsPrefix(); size_t const nonPrefixTokens = havePrefix ? ctx.m_numTokens - 1 : ctx.m_numTokens; for (size_t i = 0; i < nonPrefixTokens; ++i) { - intersections[i] = filter.Intersect(ctx.m_features[i]); - auto const df = intersections.back().PopCount(); + intersections[i] = ctx.m_features[i].Intersect(filter); + auto const df = intersections.back().m_features.PopCount(); if (df != 0) { m_params.GetToken(i).ForEach([&tokensToDf, &df](UniString const & s) { @@ -111,8 +112,8 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte if (havePrefix) { auto const count = ctx.m_numTokens - 1; - intersections[count] = filter.Intersect(ctx.m_features[count]); - auto const prefixDf = intersections.back().PopCount(); + intersections[count] = ctx.m_features[count].Intersect(filter); + auto const prefixDf = intersections.back().m_features.PopCount(); if (prefixDf != 0) { m_params.GetToken(count).ForEach([&prefixToDf, &prefixDf](UniString const & s) { @@ -130,8 +131,8 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte auto intersection = intersections[startToken]; QueryVec::Builder builder; - for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens && !intersection.IsEmpty(); - ++endToken) + for (size_t endToken = startToken + 1; + endToken <= ctx.m_numTokens && !intersection.m_features.IsEmpty(); ++endToken) { auto const curToken = endToken - 1; auto const & token = m_params.GetToken(curToken).GetOriginal(); @@ -144,9 +145,9 @@ void LocalityScorer::GetTopLocalities(MwmSet::MwmId const & countryId, BaseConte // Skip locality candidates that match only numbers. if (!m_params.IsNumberTokens(tokenRange)) { - intersection.ForEach([&](uint64_t bit) { - auto const featureId = base::asserted_cast(bit); - localities.emplace_back(countryId, featureId, tokenRange, QueryVec(idfs, builder)); + intersection.ForEach([&](uint32_t featureId, bool exactMatch) { + localities.emplace_back(countryId, featureId, tokenRange, QueryVec(idfs, builder), + exactMatch); }); } @@ -172,8 +173,8 @@ void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit, // We don't want to read too many names for localities, so this is // the best effort - select the best features by available params - - // query norm and rank. - LeaveTopByNormAndRank(max(limit, kDefaultReadLimit) /* limitUniqueIds */, els); + // exactMatch, query norm and rank. + LeaveTopByExactMatchNormAndRank(max(limit, kDefaultReadLimit) /* limitUniqueIds */, els); sort(els.begin(), els.end(), [](ExLocality const & lhs, ExLocality const & rhs) { return lhs.GetId() < rhs.GetId(); }); @@ -206,9 +207,12 @@ void LocalityScorer::LeaveTopLocalities(IdfMap & idfs, size_t limit, ASSERT_LESS_OR_EQUAL(localities.size(), limit, ()); } -void LocalityScorer::LeaveTopByNormAndRank(size_t limitUniqueIds, vector & els) const +void LocalityScorer::LeaveTopByExactMatchNormAndRank(size_t limitUniqueIds, + vector & els) const { sort(els.begin(), els.end(), [](ExLocality const & lhs, ExLocality const & rhs) { + if (lhs.m_locality.m_exactMatch != rhs.m_locality.m_exactMatch) + return lhs.m_locality.m_exactMatch; auto const ln = lhs.m_queryNorm; auto const rn = rhs.m_queryNorm; if (ln != rn) diff --git a/search/locality_scorer.hpp b/search/locality_scorer.hpp index b693972b2b..45a0aec325 100644 --- a/search/locality_scorer.hpp +++ b/search/locality_scorer.hpp @@ -59,10 +59,10 @@ private: // combination of ranks and number of matched tokens. void LeaveTopLocalities(IdfMap & idfs, size_t limit, std::vector & localities) const; - // Selects at most |limitUniqueIds| best features by query norm and + // Selects at most |limitUniqueIds| best features by exact match, query norm and // rank, and then leaves only localities corresponding to those // features in |els|. - void LeaveTopByNormAndRank(size_t limitUniqueIds, std::vector & els) const; + void LeaveTopByExactMatchNormAndRank(size_t limitUniqueIds, std::vector & els) const; // Leaves at most |limit| unique best localities by similarity to // the query and rank. diff --git a/search/pre_ranking_info.cpp b/search/pre_ranking_info.cpp index c5a4afa668..c42af8add3 100644 --- a/search/pre_ranking_info.cpp +++ b/search/pre_ranking_info.cpp @@ -8,7 +8,7 @@ std::string DebugPrint(PreRankingInfo const & info) { std::ostringstream os; os << "PreRankingInfo ["; - os << "m_distanceToPivot:" << info.m_distanceToPivot << ", "; + os << "m_distanceToPivot: " << info.m_distanceToPivot << ", "; for (size_t i = 0; i < static_cast(Model::TYPE_COUNT); ++i) { if (info.m_tokenRange[i].Empty()) @@ -17,8 +17,10 @@ std::string DebugPrint(PreRankingInfo const & info) auto const type = static_cast(i); os << "m_tokenRange[" << DebugPrint(type) << "]:" << DebugPrint(info.m_tokenRange[i]) << ", "; } - os << "m_rank:" << static_cast(info.m_rank) << ", "; - os << "m_popularity:" << static_cast(info.m_popularity) << ", "; + os << "m_allTokensUsed: " << info.m_allTokensUsed << ", "; + os << "m_exactMatch: " << info.m_exactMatch << ", "; + os << "m_rank: " << static_cast(info.m_rank) << ", "; + os << "m_popularity: " << static_cast(info.m_popularity) << ", "; os << "m_rating: [" << static_cast(info.m_rating.first) << ", "<< info.m_rating.second << "], "; os << "m_type:" << info.m_type; os << "]"; diff --git a/search/pre_ranking_info.hpp b/search/pre_ranking_info.hpp index 8150bf0e36..18a57553d5 100644 --- a/search/pre_ranking_info.hpp +++ b/search/pre_ranking_info.hpp @@ -54,6 +54,9 @@ struct PreRankingInfo // were used when retrieving the feature. bool m_allTokensUsed = true; + // True iff all tokens retrieved from search index were matched without misprints. + bool m_exactMatch = true; + // Rank of the feature. uint8_t m_rank = 0; diff --git a/search/ranker.cpp b/search/ranker.cpp index 2e5bb0fd73..ce3ca18015 100644 --- a/search/ranker.cpp +++ b/search/ranker.cpp @@ -301,6 +301,7 @@ class RankerResultMaker info.m_rating = preInfo.m_rating; info.m_type = preInfo.m_type; info.m_allTokensUsed = preInfo.m_allTokensUsed; + info.m_exactMatch = preInfo.m_exactMatch; info.m_categorialRequest = m_params.IsCategorialRequest(); info.m_hasName = ft.HasName(); diff --git a/search/ranking_info.hpp b/search/ranking_info.hpp index 2668ea0f45..8250d6a12b 100644 --- a/search/ranking_info.hpp +++ b/search/ranking_info.hpp @@ -50,6 +50,9 @@ struct RankingInfo // were used when retrieving the feature. bool m_allTokensUsed = true; + // True iff all tokens retrieved from search index were matched without misprints. + bool m_exactMatch = true; + // Search type for the feature. Model::Type m_type = Model::TYPE_COUNT; diff --git a/search/retrieval.cpp b/search/retrieval.cpp index 4e7dc28a6c..1108d45a69 100644 --- a/search/retrieval.cpp +++ b/search/retrieval.cpp @@ -31,8 +31,6 @@ #include #include -#include -#include #include using namespace std; @@ -125,14 +123,23 @@ private: vector m_created; }; -Retrieval::ExtendedFeatures SortFeaturesAndBuildCBV(vector && features, - vector && exactlyMatchedFeatures) +Retrieval::ExtendedFeatures SortFeaturesAndBuildResult(vector && features, + vector && exactlyMatchedFeatures) { using Builder = coding::CompressedBitVectorBuilder; base::SortUnique(features); base::SortUnique(exactlyMatchedFeatures); - return {Builder::FromBitPositions(move(features)), - Builder::FromBitPositions(move(exactlyMatchedFeatures))}; + auto featuresCBV = CBV(Builder::FromBitPositions(move(features))); + auto exactlyMatchedFeaturesCBV = CBV(Builder::FromBitPositions(move(exactlyMatchedFeatures))); + return Retrieval::ExtendedFeatures(move(featuresCBV), move(exactlyMatchedFeaturesCBV)); +} + +Retrieval::ExtendedFeatures SortFeaturesAndBuildResult(vector && features) +{ + using Builder = coding::CompressedBitVectorBuilder; + base::SortUnique(features); + auto const featuresCBV = CBV(Builder::FromBitPositions(move(features))); + return Retrieval::ExtendedFeatures(featuresCBV); } template @@ -254,7 +261,7 @@ Retrieval::ExtendedFeatures RetrieveAddressFeaturesImpl(Retrieval::TrieRoot @@ -280,7 +287,7 @@ Retrieval::ExtendedFeatures RetrievePostcodeFeaturesImpl(Retrieval::TrieRoot diff --git a/search/retrieval.hpp b/search/retrieval.hpp index 757df0d7c1..52200ab421 100644 --- a/search/retrieval.hpp +++ b/search/retrieval.hpp @@ -1,5 +1,6 @@ #pragma once +#include "search/cbv.hpp" #include "search/feature_offset_match.hpp" #include "search/query_params.hpp" @@ -10,18 +11,17 @@ #include "geometry/rect2d.hpp" #include "base/cancellable.hpp" +#include "base/checked_cast.hpp" #include "base/dfa_helpers.hpp" #include "base/levenshtein_dfa.hpp" +#include +#include #include +#include class MwmValue; -namespace coding -{ -class CompressedBitVector; -} - namespace search { class MwmContext; @@ -32,10 +32,56 @@ class Retrieval public: template using TrieRoot = trie::Iterator>; - using Features = std::unique_ptr; + using Features = search::CBV; struct ExtendedFeatures { + ExtendedFeatures() = default; + ExtendedFeatures(ExtendedFeatures const &) = default; + ExtendedFeatures(ExtendedFeatures &&) = default; + + explicit ExtendedFeatures(Features const & cbv) : m_features(cbv), m_exactMatchingFeatures(cbv) + { + } + + ExtendedFeatures(Features && features, Features && exactMatchingFeatures) + : m_features(std::move(features)), m_exactMatchingFeatures(std::move(exactMatchingFeatures)) + { + } + + ExtendedFeatures & operator=(ExtendedFeatures const &) = default; + ExtendedFeatures & operator=(ExtendedFeatures &&) = default; + + ExtendedFeatures Intersect(ExtendedFeatures const & rhs) const + { + ExtendedFeatures result; + result.m_features = m_features.Intersect(rhs.m_features); + result.m_exactMatchingFeatures = + m_exactMatchingFeatures.Intersect(rhs.m_exactMatchingFeatures); + return result; + } + + ExtendedFeatures Intersect(Features const & cbv) const + { + ExtendedFeatures result; + result.m_features = m_features.Intersect(cbv); + result.m_exactMatchingFeatures = m_exactMatchingFeatures.Intersect(cbv); + return result; + } + + void SetFull() + { + m_features.SetFull(); + m_exactMatchingFeatures.SetFull(); + } + + void ForEach(std::function const & f) const + { + m_features.ForEach([&](uint64_t id) { + f(base::asserted_cast(id), m_exactMatchingFeatures.HasBit(id)); + }); + } + Features m_features; Features m_exactMatchingFeatures; }; diff --git a/search/search_integration_tests/processor_test.cpp b/search/search_integration_tests/processor_test.cpp index 349e2eb505..35563471dd 100644 --- a/search/search_integration_tests/processor_test.cpp +++ b/search/search_integration_tests/processor_test.cpp @@ -686,10 +686,10 @@ UNIT_CLASS_TEST(ProcessorTest, TestPostcodes) Retrieval retrieval(context, cancellable); auto features = retrieval.RetrievePostcodeFeatures( TokenSlice(params, TokenRange(0, params.GetNumTokens()))); - TEST_EQUAL(1, features->PopCount(), ()); + TEST_EQUAL(1, features.PopCount(), ()); uint64_t index = 0; - while (!features->GetBit(index)) + while (!features.HasBit(index)) ++index; FeaturesLoaderGuard loader(m_dataSource, countryId); @@ -1862,5 +1862,59 @@ UNIT_CLASS_TEST(ProcessorTest, RemoveDuplicatingStreets) TEST_EQUAL(GetResultsNumber(streetName, "ru"), 1, ()); } } + +UNIT_CLASS_TEST(ProcessorTest, ExactMatchTest) +{ + string const countryName = "Wonderland"; + + TestCafe lermontov(m2::PointD(1, 1), "Лермонтовъ", "ru"); + + TestCity lermontovo(m2::PointD(-1, -1), "Лермонтово", "ru", 0 /* rank */); + TestCafe cafe(m2::PointD(-1.01, -1.01), "", "ru"); + + auto worldId = BuildWorld([&](TestMwmBuilder & builder) { builder.Add(lermontovo); }); + auto wonderlandId = BuildCountry(countryName, [&](TestMwmBuilder & builder) { + builder.Add(cafe); + builder.Add(lermontov); + }); + + { + auto request = MakeRequest("cafe лермонтовъ "); + auto const & results = request->Results(); + + Rules rules{ExactMatch(wonderlandId, cafe), ExactMatch(wonderlandId, lermontov)}; + TEST(ResultsMatch(results, rules), ()); + + TEST_EQUAL(2, results.size(), ("Unexpected number of retrieved cafes.")); + TEST(ResultsMatch({results[0]}, {ExactMatch(wonderlandId, lermontov)}), ()); + TEST(results[0].GetRankingInfo().m_exactMatch, ()); + TEST(!results[1].GetRankingInfo().m_exactMatch, ()); + } + + { + auto request = MakeRequest("cafe лермонтово "); + auto const & results = request->Results(); + + Rules rules{ExactMatch(wonderlandId, cafe), ExactMatch(wonderlandId, lermontov)}; + TEST(ResultsMatch(results, rules), ()); + + TEST_EQUAL(2, results.size(), ("Unexpected number of retrieved cafes.")); + TEST(ResultsMatch({results[0]}, {ExactMatch(wonderlandId, cafe)}), ()); + TEST(results[0].GetRankingInfo().m_exactMatch, ()); + TEST(!results[1].GetRankingInfo().m_exactMatch, ()); + } + + { + auto request = MakeRequest("cafe лермонтов "); + auto const & results = request->Results(); + + Rules rules{ExactMatch(wonderlandId, cafe), ExactMatch(wonderlandId, lermontov)}; + TEST(ResultsMatch(results, rules), ()); + + TEST_EQUAL(2, results.size(), ("Unexpected number of retrieved cafes.")); + TEST(!results[0].GetRankingInfo().m_exactMatch, ()); + TEST(!results[1].GetRankingInfo().m_exactMatch, ()); + } +} } // namespace } // namespace search diff --git a/search/search_tests/locality_scorer_test.cpp b/search/search_tests/locality_scorer_test.cpp index 15d1ffc534..8444195069 100644 --- a/search/search_tests/locality_scorer_test.cpp +++ b/search/search_tests/locality_scorer_test.cpp @@ -93,7 +93,7 @@ public: }); base::SortUnique(ids); - ctx.m_features.emplace_back(coding::CompressedBitVectorBuilder::FromBitPositions(ids)); + ctx.m_features.emplace_back(CBV(coding::CompressedBitVectorBuilder::FromBitPositions(ids))); } CBV filter; diff --git a/search/streets_matcher.cpp b/search/streets_matcher.cpp index fadaadfbf4..e65719aa31 100644 --- a/search/streets_matcher.cpp +++ b/search/streets_matcher.cpp @@ -126,7 +126,7 @@ void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const & StreetTokensFilter filter([&](strings::UniString const & /* token */, size_t tag) { - auto buffer = streets.Intersect(ctx.m_features[tag]); + auto buffer = streets.Intersect(ctx.m_features[tag].m_features); if (tag < curToken) { // This is the case for delayed @@ -134,7 +134,7 @@ void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const & // |streets| is temporarily in the // incomplete state. streets = buffer; - all = all.Intersect(ctx.m_features[tag]); + all = all.Intersect(ctx.m_features[tag].m_features); emptyIntersection = false; incomplete = true; @@ -149,7 +149,7 @@ void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const & emit(); streets = buffer; - all = all.Intersect(ctx.m_features[tag]); + all = all.Intersect(ctx.m_features[tag].m_features); emptyIntersection = false; incomplete = false; });