From 994df64d580cfaa019408a2383da77a98ecca494 Mon Sep 17 00:00:00 2001 From: Yuri Gorshenin Date: Wed, 15 Feb 2017 11:52:51 +0300 Subject: [PATCH] [search] Ranking with errors. --- search/geocoder.cpp | 27 +++-- search/geocoder.hpp | 3 +- search/intersection_result.cpp | 15 +-- search/intersection_result.hpp | 17 ++-- search/pre_ranking_info.hpp | 4 + search/ranker.cpp | 98 ++++++++++++------- search/ranker.hpp | 3 +- .../search_integration_tests/CMakeLists.txt | 1 + search/search_integration_tests/helpers.cpp | 23 ++++- search/search_integration_tests/helpers.hpp | 6 ++ .../processor_test.cpp | 39 ++------ .../search_integration_tests/ranker_test.cpp | 52 ++++++++++ 12 files changed, 187 insertions(+), 101 deletions(-) create mode 100644 search/search_integration_tests/ranker_test.cpp diff --git a/search/geocoder.cpp b/search/geocoder.cpp index 41a65e01ef..14dfdd46dc 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -965,7 +965,10 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) filtered.ForEach([&](uint32_t id) { SearchModel::SearchType searchType; if (GetSearchTypeInGeocoding(ctx, id, searchType)) - EmitResult(ctx, m_context->GetId(), id, searchType, m_postcodes.m_tokenRange); + { + EmitResult(ctx, m_context->GetId(), id, searchType, m_postcodes.m_tokenRange, + nullptr /* geoParts */); + } }); return; } @@ -986,7 +989,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) if (!m_postcodes.m_features.HasBit(id)) continue; EmitResult(ctx, m_context->GetId(), id, SearchModel::SEARCH_TYPE_STREET, - layers.back().m_tokenRange); + layers.back().m_tokenRange, nullptr /* geoParts */); } } @@ -1186,15 +1189,14 @@ void Geocoder::FindPaths(BaseContext const & ctx) *m_matcher, sortedLayers, [this, &ctx, &innermostLayer](IntersectionResult const & result) { ASSERT(result.IsValid(), ()); - // TODO(@y, @m, @vng): use rest fields of IntersectionResult for - // better scoring. EmitResult(ctx, m_context->GetId(), result.InnermostResult(), innermostLayer.m_type, - innermostLayer.m_tokenRange); + innermostLayer.m_tokenRange, &result); }); } void Geocoder::EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId, - SearchModel::SearchType type, TokenRange const & tokenRange) + SearchModel::SearchType type, TokenRange const & tokenRange, + IntersectionResult const * geoParts) { FeatureID id(mwmId, ftId); @@ -1220,6 +1222,9 @@ void Geocoder::EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId, if (ctx.m_city) info.m_tokenRange[SearchModel::SEARCH_TYPE_CITY] = ctx.m_city->m_tokenRange; + if (geoParts) + info.m_geoParts = *geoParts; + m_preRanker.Emplace(id, info); } @@ -1227,12 +1232,13 @@ void Geocoder::EmitResult(BaseContext const & ctx, Region const & region, TokenRange const & tokenRange) { auto const type = Region::ToSearchType(region.m_type); - EmitResult(ctx, region.m_countryId, region.m_featureId, type, tokenRange); + EmitResult(ctx, region.m_countryId, region.m_featureId, type, tokenRange, nullptr /* geoParts */); } void Geocoder::EmitResult(BaseContext const & ctx, City const & city, TokenRange const & tokenRange) { - EmitResult(ctx, city.m_countryId, city.m_featureId, city.m_type, tokenRange); + EmitResult(ctx, city.m_countryId, city.m_featureId, city.m_type, tokenRange, + nullptr /* geoParts */); } void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) @@ -1268,7 +1274,10 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) if (!GetSearchTypeInGeocoding(ctx, featureId, searchType)) return; if (searchType == SearchModel::SEARCH_TYPE_UNCLASSIFIED) - EmitResult(ctx, m_context->GetId(), featureId, searchType, TokenRange(startToken, curToken)); + { + EmitResult(ctx, m_context->GetId(), featureId, searchType, TokenRange(startToken, curToken), + nullptr /* geoParts */); + } }; allFeatures.ForEach(emitUnclassified); } diff --git a/search/geocoder.hpp b/search/geocoder.hpp index b489ad237f..9a89912919 100644 --- a/search/geocoder.hpp +++ b/search/geocoder.hpp @@ -196,7 +196,8 @@ private: // Forms result and feeds it to |m_preRanker|. void EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId, - SearchModel::SearchType type, TokenRange const & tokenRange); + SearchModel::SearchType type, TokenRange const & tokenRange, + IntersectionResult const * m_geoParts); void EmitResult(BaseContext const & ctx, Region const & region, TokenRange const & tokenRange); void EmitResult(BaseContext const & ctx, City const & city, TokenRange const & tokenRange); diff --git a/search/intersection_result.cpp b/search/intersection_result.cpp index 72dd5bdff6..898f2c097d 100644 --- a/search/intersection_result.cpp +++ b/search/intersection_result.cpp @@ -1,17 +1,11 @@ #include "search/intersection_result.hpp" -#include "std/limits.hpp" -#include "std/sstream.hpp" +#include namespace search { // static -uint32_t const IntersectionResult::kInvalidId = numeric_limits::max(); - -IntersectionResult::IntersectionResult() - : m_poi(kInvalidId), m_building(kInvalidId), m_street(kInvalidId) -{ -} +uint32_t const IntersectionResult::kInvalidId; void IntersectionResult::Set(SearchModel::SearchType type, uint32_t id) { @@ -47,9 +41,9 @@ void IntersectionResult::Clear() m_street = kInvalidId; } -string DebugPrint(IntersectionResult const & result) +std::string DebugPrint(IntersectionResult const & result) { - ostringstream os; + std::ostringstream os; os << "IntersectionResult [ "; if (result.m_poi != IntersectionResult::kInvalidId) os << "POI:" << result.m_poi << " "; @@ -60,5 +54,4 @@ string DebugPrint(IntersectionResult const & result) os << "]"; return os.str(); } - } // namespace search diff --git a/search/intersection_result.hpp b/search/intersection_result.hpp index 6bea9335a5..3e493dff2d 100644 --- a/search/intersection_result.hpp +++ b/search/intersection_result.hpp @@ -2,8 +2,9 @@ #include "search/model.hpp" -#include "std/cstdint.hpp" -#include "std/string.hpp" +#include +#include +#include namespace search { @@ -11,9 +12,9 @@ namespace search // i.e. BUILDING and STREET for POI or STREET for BUILDING. struct IntersectionResult { - static uint32_t const kInvalidId; + static uint32_t const kInvalidId = std::numeric_limits::max(); - IntersectionResult(); + IntersectionResult() = default; void Set(SearchModel::SearchType type, uint32_t id); @@ -27,10 +28,10 @@ struct IntersectionResult // Clears all fields to an invalid state. void Clear(); - uint32_t m_poi; - uint32_t m_building; - uint32_t m_street; + uint32_t m_poi = kInvalidId; + uint32_t m_building = kInvalidId; + uint32_t m_street = kInvalidId; }; -string DebugPrint(IntersectionResult const & result); +std::string DebugPrint(IntersectionResult const & result); } // namespace search diff --git a/search/pre_ranking_info.hpp b/search/pre_ranking_info.hpp index a251cd9951..a6657bd498 100644 --- a/search/pre_ranking_info.hpp +++ b/search/pre_ranking_info.hpp @@ -1,5 +1,6 @@ #pragma once +#include "search/intersection_result.hpp" #include "search/model.hpp" #include "search/token_range.hpp" @@ -39,6 +40,9 @@ struct PreRankingInfo // Tokens match to the feature name or house number. TokenRange m_tokenRange[SearchModel::SEARCH_TYPE_COUNT]; + // Different geo-parts participated in search. + IntersectionResult m_geoParts; + // Rank of the feature. uint8_t m_rank = 0; diff --git a/search/ranker.cpp b/search/ranker.cpp index b2e708a162..ca95c91b29 100644 --- a/search/ranker.cpp +++ b/search/ranker.cpp @@ -35,6 +35,31 @@ void UpdateNameScore(vector const & tokens, TSlice const & s bestScore = score; } +NameScore GetNameScore(FeatureType const & ft, Geocoder::Params const & params, + TokenRange const & range, SearchModel::SearchType type) +{ + NameScore bestScore = NAME_SCORE_ZERO; + TokenSlice slice(params, range); + TokenSliceNoCategories sliceNoCategories(params, range); + + for (auto const & lang : params.GetLangs()) + { + string name; + if (!ft.GetName(lang, name)) + continue; + vector tokens; + PrepareStringForMatching(name, tokens); + + UpdateNameScore(tokens, slice, bestScore); + UpdateNameScore(tokens, sliceNoCategories, bestScore); + } + + if (type == SearchModel::SEARCH_TYPE_BUILDING) + UpdateNameScore(ft.GetHouseNumber(), sliceNoCategories, bestScore); + + return bestScore; +} + void RemoveDuplicatingLinear(vector & values) { PreResult2::LessLinearTypesF lessCmp; @@ -141,29 +166,36 @@ class PreResult2Maker Geocoder::Params const & m_params; storage::CountryInfoGetter const & m_infoGetter; - unique_ptr m_pFV; + unique_ptr m_loader; - // For the best performance, incoming id's should be sorted by id.first (mwm file id). - bool LoadFeature(FeatureID const & id, FeatureType & f, m2::PointD & center, string & name, - string & country) + bool LoadFeature(FeatureID const & id, FeatureType & ft) { - if (m_pFV.get() == 0 || m_pFV->GetId() != id.m_mwmId) - m_pFV.reset(new Index::FeaturesLoaderGuard(m_index, id.m_mwmId)); - - if (!m_pFV->GetFeatureByIndex(id.m_index, f)) + if (!m_loader || m_loader->GetId() != id.m_mwmId) + m_loader = make_unique(m_index, id.m_mwmId); + if (!m_loader->GetFeatureByIndex(id.m_index, ft)) return false; - f.SetID(id); + ft.SetID(id); + return true; + } - center = feature::GetCenter(f); + // For the best performance, incoming id's should be sorted by id.first (mwm file id). + bool LoadFeature(FeatureID const & id, FeatureType & ft, m2::PointD & center, string & name, + string & country) + { + if (!LoadFeature(id, ft)) + return false; - m_ranker.GetBestMatchName(f, name); + center = feature::GetCenter(ft); + m_ranker.GetBestMatchName(ft, name); - // country (region) name is a file name if feature isn't from World.mwm - if (m_pFV->IsWorld()) + // Country (region) name is a file name if feature isn't from + // World.mwm. + ASSERT(m_loader && m_loader->GetId() == id.m_mwmId, ()); + if (m_loader->IsWorld()) country.clear(); else - country = m_pFV->GetCountryFileName(); + country = m_loader->GetCountryFileName(); return true; } @@ -178,26 +210,23 @@ class PreResult2Maker info.m_distanceToPivot = MercatorBounds::DistanceOnEarth(center, pivot); info.m_rank = preInfo.m_rank; info.m_searchType = preInfo.m_searchType; - info.m_nameScore = NAME_SCORE_ZERO; + info.m_nameScore = GetNameScore(ft, m_params, preInfo.InnermostTokenRange(), info.m_searchType); - TokenSlice slice(m_params, preInfo.InnermostTokenRange()); - TokenSliceNoCategories sliceNoCategories(m_params, preInfo.InnermostTokenRange()); - - for (auto const & lang : m_params.GetLangs()) + if (info.m_searchType != SearchModel::SEARCH_TYPE_STREET && + preInfo.m_geoParts.m_street != IntersectionResult::kInvalidId) { - string name; - if (!ft.GetName(lang, name)) - continue; - vector tokens; - PrepareStringForMatching(name, tokens); - - UpdateNameScore(tokens, slice, info.m_nameScore); - UpdateNameScore(tokens, sliceNoCategories, info.m_nameScore); + auto const & mwmId = ft.GetID().m_mwmId; + FeatureType street; + if (LoadFeature(FeatureID(mwmId, preInfo.m_geoParts.m_street), street)) + { + NameScore const nameScore = + GetNameScore(street, m_params, preInfo.m_tokenRange[SearchModel::SEARCH_TYPE_STREET], + SearchModel::SEARCH_TYPE_STREET); + info.m_nameScore = min(info.m_nameScore, nameScore); + } } - if (info.m_searchType == SearchModel::SEARCH_TYPE_BUILDING) - UpdateNameScore(ft.GetHouseNumber(), sliceNoCategories, info.m_nameScore); - + TokenSlice slice(m_params, preInfo.InnermostTokenRange()); feature::TypesHolder holder(ft); vector> matched(slice.Size()); ForEachCategoryType(QuerySlice(slice), m_ranker.m_params.m_categoryLocales, @@ -312,8 +341,7 @@ bool Ranker::IsResultExists(PreResult2 const & p, vector const & v }); } -void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vector & cont, - vector & streets) +void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vector & cont) { PreResult2Maker maker(*this, m_index, m_infoGetter, geocoderParams); for (auto const & r : m_preResults1) @@ -328,9 +356,6 @@ void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vectorIsStreet()) - streets.push_back(p->GetID()); - if (!IsResultExists(*p, cont)) cont.push_back(IndexedValue(move(p))); }; @@ -492,8 +517,7 @@ void Ranker::UpdateResults(bool lastUpdate) { BailIfCancelled(); - vector streets; - MakePreResult2(m_geocoderParams, m_tentativeResults, streets); + MakePreResult2(m_geocoderParams, m_tentativeResults); RemoveDuplicatingLinear(m_tentativeResults); if (m_tentativeResults.empty()) return; diff --git a/search/ranker.hpp b/search/ranker.hpp index 0f1950f37b..6e220824df 100644 --- a/search/ranker.hpp +++ b/search/ranker.hpp @@ -81,8 +81,7 @@ public: bool IsResultExists(PreResult2 const & p, vector const & values); - void MakePreResult2(Geocoder::Params const & params, vector & cont, - vector & streets); + void MakePreResult2(Geocoder::Params const & params, vector & cont); Result MakeResult(PreResult2 const & r) const; void MakeResultHighlight(Result & res) const; diff --git a/search/search_integration_tests/CMakeLists.txt b/search/search_integration_tests/CMakeLists.txt index 1b69474920..bb22141e27 100644 --- a/search/search_integration_tests/CMakeLists.txt +++ b/search/search_integration_tests/CMakeLists.txt @@ -9,6 +9,7 @@ set( interactive_search_test.cpp pre_ranker_test.cpp processor_test.cpp + ranker_test.cpp search_edited_features_test.cpp smoke_test.cpp ) diff --git a/search/search_integration_tests/helpers.cpp b/search/search_integration_tests/helpers.cpp index 111d08718d..b1849dad69 100644 --- a/search/search_integration_tests/helpers.cpp +++ b/search/search_integration_tests/helpers.cpp @@ -67,13 +67,28 @@ bool SearchTest::ResultsMatch(string const & query, Mode mode, return MatchResults(m_engine, rules, request.Results()); } +bool SearchTest::ResultsMatch(vector const & results, TRules const & rules) +{ + return MatchResults(m_engine, rules, results); +} + +unique_ptr SearchTest::MakeRequest(string const & query) +{ + SearchParams params; + params.m_query = query; + params.m_inputLocale = "en"; + params.m_mode = Mode::Everywhere; + params.m_suggestsEnabled = false; + + auto request = make_unique(m_engine, params, m_viewport); + request->Run(); + return request; +} + size_t SearchTest::CountFeatures(m2::RectD const & rect) { size_t count = 0; - auto counter = [&count](const FeatureType & /* ft */) - { - ++count; - }; + auto counter = [&count](const FeatureType & /* ft */) { ++count; }; m_engine.ForEachInRect(counter, rect, scales::GetUpperScale()); return count; } diff --git a/search/search_integration_tests/helpers.hpp b/search/search_integration_tests/helpers.hpp index 0160bbfe7f..0c80cad599 100644 --- a/search/search_integration_tests/helpers.hpp +++ b/search/search_integration_tests/helpers.hpp @@ -2,6 +2,7 @@ #include "search/search_tests_support/test_results_matching.hpp" #include "search/search_tests_support/test_search_engine.hpp" +#include "search/search_tests_support/test_search_request.hpp" #include "generator/generator_tests_support/test_mwm_builder.hpp" @@ -101,12 +102,17 @@ public: } inline void SetViewport(m2::RectD const & viewport) { m_viewport = viewport; } + bool ResultsMatch(string const & query, TRules const & rules); bool ResultsMatch(string const & query, string const & locale, TRules const & rules); bool ResultsMatch(string const & query, Mode mode, TRules const & rules); + bool ResultsMatch(vector const & results, TRules const & rules); + + unique_ptr MakeRequest(string const & query); + size_t CountFeatures(m2::RectD const & rect); protected: diff --git a/search/search_integration_tests/processor_test.cpp b/search/search_integration_tests/processor_test.cpp index 63aa3b5342..ec892e44a6 100644 --- a/search/search_integration_tests/processor_test.cpp +++ b/search/search_integration_tests/processor_test.cpp @@ -66,25 +66,6 @@ private: class ProcessorTest : public SearchTest { -public: - unique_ptr MakeRequest(string const & query) - { - SearchParams params; - params.m_query = query; - params.m_inputLocale = "en"; - params.m_mode = Mode::Everywhere; - params.m_suggestsEnabled = false; - - auto request = make_unique(m_engine, params, m_viewport); - request->Run(); - return request; - } - - bool MatchResults(vector> rules, - vector const & actual) const - { - return ::MatchResults(m_engine, rules, actual); - } }; UNIT_CLASS_TEST(ProcessorTest, Smoke) @@ -337,7 +318,7 @@ UNIT_CLASS_TEST(ProcessorTest, DisableSuggests) request.Run(); TRules rules = {ExactMatch(worldId, london1), ExactMatch(worldId, london2)}; - TEST(MatchResults(rules, request.Results()), ()); + TEST(ResultsMatch(request.Results(), rules), ()); } } @@ -399,7 +380,7 @@ UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo) TRules rules = {ExactMatch(wonderlandId, goldenGateBridge), ExactMatch(wonderlandId, goldenGateStreet)}; - TEST(MatchResults(rules, request->Results()), ()); + TEST(ResultsMatch(request->Results(), rules), ()); for (auto const & result : request->Results()) { auto const & info = result.GetRankingInfo(); @@ -416,11 +397,11 @@ UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo) TRules rules{ExactMatch(wonderlandId, cafe1), ExactMatch(wonderlandId, cafe2), ExactMatch(wonderlandId, lermontov)}; - TEST(MatchResults(rules, results), ()); + TEST(ResultsMatch(results, rules), ()); TEST_EQUAL(3, results.size(), ("Unexpected number of retrieved cafes.")); auto const & top = results.front(); - TEST(MatchResults({ExactMatch(wonderlandId, lermontov)}, {top}), ()); + TEST(ResultsMatch({top}, {ExactMatch(wonderlandId, lermontov)}), ()); } { @@ -640,7 +621,7 @@ UNIT_CLASS_TEST(ProcessorTest, TestCategories) ExactMatch(wonderlandId, busStop)}; auto request = MakeRequest("atm"); - TEST(MatchResults(rules, request->Results()), ()); + TEST(ResultsMatch(request->Results(), rules), ()); for (auto const & result : request->Results()) { Index::FeaturesLoaderGuard loader(m_engine, wonderlandId); @@ -713,7 +694,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering) TestSearchRequest request(m_engine, params, m_viewport); request.Run(); TRules rules = {ExactMatch(id, h1), ExactMatch(id, h2), ExactMatch(id, h3), ExactMatch(id, h4)}; - TEST(MatchResults(rules, request.Results()), ()); + TEST(ResultsMatch(request.Results(), rules), ()); } using namespace hotels_filter; @@ -723,7 +704,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering) TestSearchRequest request(m_engine, params, m_viewport); request.Run(); TRules rules = {ExactMatch(id, h1), ExactMatch(id, h3)}; - TEST(MatchResults(rules, request.Results()), ()); + TEST(ResultsMatch(request.Results(), rules), ()); } params.m_hotelsFilter = Or(Eq(9.0), Le(4)); @@ -731,7 +712,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering) TestSearchRequest request(m_engine, params, m_viewport); request.Run(); TRules rules = {ExactMatch(id, h1), ExactMatch(id, h3), ExactMatch(id, h4)}; - TEST(MatchResults(rules, request.Results()), ()); + TEST(ResultsMatch(request.Results(), rules), ()); } params.m_hotelsFilter = Or(And(Eq(7.0), Eq(5)), Eq(4)); @@ -739,7 +720,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering) TestSearchRequest request(m_engine, params, m_viewport); request.Run(); TRules rules = {ExactMatch(id, h2), ExactMatch(id, h4)}; - TEST(MatchResults(rules, request.Results()), ()); + TEST(ResultsMatch(request.Results(), rules), ()); } } @@ -842,7 +823,7 @@ UNIT_CLASS_TEST(ProcessorTest, StopWords) TRules rules = {ExactMatch(id, street)}; auto const & results = request->Results(); - TEST(MatchResults(rules, results), ()); + TEST(ResultsMatch(results, rules), ()); auto const & info = results[0].GetRankingInfo(); TEST_EQUAL(info.m_nameScore, NAME_SCORE_FULL_MATCH, ()); diff --git a/search/search_integration_tests/ranker_test.cpp b/search/search_integration_tests/ranker_test.cpp new file mode 100644 index 0000000000..505e00ab3f --- /dev/null +++ b/search/search_integration_tests/ranker_test.cpp @@ -0,0 +1,52 @@ +#include "testing/testing.hpp" + +#include "search/search_integration_tests/helpers.hpp" +#include "search/search_tests_support/test_results_matching.hpp" + +#include "generator/generator_tests_support/test_feature.hpp" + + +using namespace generator::tests_support; +using namespace search::tests_support; +using namespace search; + +namespace +{ +class RankerTest : public SearchTest +{ +}; + +UNIT_CLASS_TEST(RankerTest, ErrorsInStreets) +{ + TestStreet mazurova( + vector{m2::PointD(-0.001, -0.001), m2::PointD(0, 0), m2::PointD(0.001, 0.001)}, + "Мазурова", "ru"); + TestBuilding mazurova14(m2::PointD(-0.001, -0.001), "", "14", mazurova, "ru"); + + TestStreet masherova( + vector{m2::PointD(-0.001, 0.001), m2::PointD(0, 0), m2::PointD(0.001, -0.001)}, + "Машерова", "ru"); + TestBuilding masherova14(m2::PointD(0.001, 0.001), "", "14", masherova, "ru"); + + auto id = BuildCountry("Belarus", [&](TestMwmBuilder & builder) { + builder.Add(mazurova); + builder.Add(mazurova14); + + builder.Add(masherova); + builder.Add(masherova14); + }); + + SetViewport(m2::RectD(m2::PointD(0, 0), m2::PointD(0.001, 0.001))); + { + auto request = MakeRequest("Мазурова 14"); + auto const & results = request->Results(); + + TRules rules = {ExactMatch(id, mazurova14), ExactMatch(id, masherova14)}; + TEST(ResultsMatch(results, rules), ()); + + TEST_EQUAL(results.size(), 2, ()); + TEST(ResultsMatch({results[0]}, {rules[0]}), ()); + TEST(ResultsMatch({results[1]}, {rules[1]}), ()); + } +} +} // namespace