From 00366cc5d8e522efa9294ab14b72992f4fe2be86 Mon Sep 17 00:00:00 2001 From: Viktor Govako Date: Sun, 31 Jul 2022 18:21:00 +0300 Subject: [PATCH] [search] Use indices vector to sort and select results in PreRanker. Signed-off-by: Viktor Govako --- indexer/feature_decl.cpp | 18 ++-- indexer/feature_decl.hpp | 9 ++ search/pre_ranker.cpp | 86 +++++++------------ search/pre_ranker.hpp | 15 ++-- .../pre_ranker_test.cpp | 45 +++++----- 5 files changed, 81 insertions(+), 92 deletions(-) diff --git a/indexer/feature_decl.cpp b/indexer/feature_decl.cpp index 191852c59f..8033760adb 100644 --- a/indexer/feature_decl.cpp +++ b/indexer/feature_decl.cpp @@ -1,12 +1,12 @@ #include "indexer/feature_decl.hpp" +#include + #include -using namespace std; - namespace feature { -string DebugPrint(GeomType type) +std::string DebugPrint(GeomType type) { switch (type) { @@ -19,7 +19,7 @@ string DebugPrint(GeomType type) } } // namespace feature -string DebugPrint(FeatureID const & id) +std::string DebugPrint(FeatureID const & id) { return "{ " + DebugPrint(id.m_mwmId) + ", " + std::to_string(id.m_index) + " }"; } @@ -30,7 +30,7 @@ char const * const FeatureID::kInvalidFileName = "INVALID"; int64_t const FeatureID::kInvalidMwmVersion = -1; -string FeatureID::GetMwmName() const +std::string FeatureID::GetMwmName() const { return IsValid() ? m_mwmId.GetInfo()->GetCountryName() : kInvalidFileName; } @@ -39,3 +39,11 @@ int64_t FeatureID::GetMwmVersion() const { return IsValid() ? m_mwmId.GetInfo()->GetVersion() : kInvalidMwmVersion; } + +size_t std::hash::operator()(FeatureID const & fID) const +{ + size_t seed = 0; + boost::hash_combine(seed, fID.m_mwmId.GetInfo()); + boost::hash_combine(seed, fID.m_index); + return seed; +} diff --git a/indexer/feature_decl.hpp b/indexer/feature_decl.hpp index 9088b879a4..ce180a6b3d 100644 --- a/indexer/feature_decl.hpp +++ b/indexer/feature_decl.hpp @@ -50,3 +50,12 @@ struct FeatureID }; std::string DebugPrint(FeatureID const & id); + +namespace std +{ +template <> +struct hash +{ + size_t operator()(FeatureID const & fID) const; +}; +} // namespace std diff --git a/search/pre_ranker.cpp b/search/pre_ranker.cpp index cc497b4fca..1ca2b4911f 100644 --- a/search/pre_ranker.cpp +++ b/search/pre_ranker.cpp @@ -24,7 +24,7 @@ using namespace std; namespace { -void SweepNearbyResults(m2::PointD const & eps, set const & prevEmit, +void SweepNearbyResults(m2::PointD const & eps, unordered_set const & prevEmit, vector & results) { m2::NearbyPointsSweeper sweeper(eps.x, eps.y); @@ -143,6 +143,22 @@ void PreRanker::FillMissingFieldsInPreResults() }); } +namespace +{ +template class CompareIndices +{ + CompT m_cmp; + ContT const & m_cont; + +public: + CompareIndices(CompT const & cmp, ContT const & cont) : m_cmp(cmp), m_cont(cont) {} + bool operator()(size_t l, size_t r) const + { + return m_cmp(m_cont[l], m_cont[r]); + } +}; +} // namespace + void PreRanker::Filter(bool viewportSearch) { auto const lessForUnique = [](PreRankerResult const & lhs, PreRankerResult const & rhs) @@ -167,64 +183,22 @@ void PreRanker::Filter(bool viewportSearch) if (m_results.size() <= BatchSize()) return; - sort(m_results.begin(), m_results.end(), &PreRankerResult::LessDistance); + vector indices(m_results.size()); + generate(indices.begin(), indices.end(), [n = 0] () mutable { return n++; }); + unordered_set filtered; - /// @todo To have any benefit from the next sort-shuffle code block, we should have at least 2 *strictly* equal - /// (distance in meters) results in the middle of m_results vector. The probability of that is -> 0. - /// This code had sence, when we had some approximated viewport distance before centers table. - /*{ - // Priority is some kind of distance from the viewport or - // position, therefore if we have a bunch of results with the same - // priority, we have no idea here which results are relevant. To - // prevent bias from previous search routines (like sorting by - // feature id) this code randomly selects tail of the - // sorted-by-priority list of pre-results. - - double const last = m_results[BatchSize()].GetDistance(); - - auto b = m_results.begin() + BatchSize(); - for (; b != m_results.begin() && b->GetDistance() == last; --b) - ; - if (b->GetDistance() != last) - ++b; - - auto e = m_results.begin() + BatchSize(); - for (; e != m_results.end() && e->GetDistance() == last; ++e) - ; - - // The main reason of shuffling here is to select a random subset - // from the low-priority results. We're using a linear - // congruential method with default seed because it is fast, - // simple and doesn't need an external entropy source. - // - // TODO (@y, @m, @vng): consider to take some kind of hash from - // features and then select a subset with smallest values of this - // hash. In this case this subset of results will be persistent - // to small changes in the original set. - shuffle(b, e, m_rng); - }*/ - - struct LessFeatureID - { - inline bool operator()(PreRankerResult const & lhs, PreRankerResult const & rhs) const - { - return lhs.GetId() < rhs.GetId(); - } - }; - set filtered; - - auto const numResults = min(m_results.size(), BatchSize()); - auto const iBeg = m_results.begin(); - auto const iMiddle = iBeg + numResults; - auto const iEnd = m_results.end(); + auto const iBeg = indices.begin(); + auto const iMiddle = iBeg + BatchSize(); + auto const iEnd = indices.end(); + nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessDistance, m_results)); filtered.insert(iBeg, iMiddle); if (!m_params.m_categorialRequest) { - nth_element(iBeg, iMiddle, iEnd, &PreRankerResult::LessRankAndPopularity); + nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessRankAndPopularity, m_results)); filtered.insert(iBeg, iMiddle); - nth_element(iBeg, iMiddle, iEnd, &PreRankerResult::LessByExactMatch); + nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessByExactMatch, m_results)); filtered.insert(iBeg, iMiddle); } else @@ -238,11 +212,15 @@ void PreRanker::Filter(bool viewportSearch) 2 * kPedestrianRadiusMeters; comparator.m_viewport = m_params.m_viewport; - nth_element(iBeg, iMiddle, iEnd, comparator); + nth_element(iBeg, iMiddle, iEnd, CompareIndices(comparator, m_results)); filtered.insert(iBeg, iMiddle); } - m_results.assign(make_move_iterator(filtered.begin()), make_move_iterator(filtered.end())); + PreResultsContainerT newResults; + newResults.reserve(filtered.size()); + for (size_t idx : filtered) + newResults.push_back(m_results[idx]); + m_results.swap(newResults); } void PreRanker::UpdateResults(bool lastUpdate) diff --git a/search/pre_ranker.hpp b/search/pre_ranker.hpp index f235c90ea4..305df7c62c 100644 --- a/search/pre_ranker.hpp +++ b/search/pre_ranker.hpp @@ -12,10 +12,9 @@ #include #include #include -#include #include #include -#include +#include #include class DataSource; @@ -144,8 +143,10 @@ private: DataSource const & m_dataSource; Ranker & m_ranker; - std::vector m_results; - std::vector m_relaxedResults; + + using PreResultsContainerT = std::vector; + PreResultsContainerT m_results, m_relaxedResults; + Params m_params; // Amount of results sent up the pipeline. @@ -161,12 +162,10 @@ private: /// search session. They're used for filtering of current search, because we need to give more priority /// to results that were on map previously, to avoid result's annoying blinking/flickering on map. /// @{ - std::set m_currEmit; - std::set m_prevEmit; + std::unordered_set m_currEmit; + std::unordered_set m_prevEmit; /// @} - std::minstd_rand m_rng; - DISALLOW_COPY_AND_MOVE(PreRanker); }; } // namespace search diff --git a/search/search_integration_tests/pre_ranker_test.cpp b/search/search_integration_tests/pre_ranker_test.cpp index fb0ba0295f..e06085551c 100644 --- a/search/search_integration_tests/pre_ranker_test.cpp +++ b/search/search_integration_tests/pre_ranker_test.cpp @@ -27,31 +27,19 @@ #include "base/assert.hpp" #include "base/cancellable.hpp" -#include "base/limited_priority_queue.hpp" #include "base/math.hpp" #include "base/stl_helpers.hpp" #include -#include -#include #include #include +namespace pre_ranker_test +{ using namespace generator::tests_support; -using namespace search::tests_support; +using namespace search; using namespace std; -class DataSource; - -namespace storage -{ -class CountryInfoGetter; -} - -namespace search -{ -namespace -{ class TestRanker : public Ranker { public: @@ -105,12 +93,15 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke) // emits results nearest to the pivot. m2::PointD const kPivot(0, 0); - m2::RectD const kViewport(m2::PointD(-5, -5), m2::PointD(5, 5)); + m2::RectD const kViewport(-5, -5, 5, 5); - size_t const kBatchSize = 50; + /// @todo Well, I'm not sure that 50 results will have unique distances to pivot. + /// 7x7 grid is 49, so potentially it can be 51 (north and south) or (east and west). + /// But we should consider circle (ellipse) around pivot and I can't say, + /// how it goes in meters radius on integer mercator grid. + size_t constexpr kBatchSize = 50; vector pois; - for (int x = -5; x <= 5; ++x) { for (int y = -5; y <= 5; ++y) @@ -122,7 +113,8 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke) TEST_LESS(kBatchSize, pois.size(), ()); - auto mwmId = BuildCountry("Cafeland", [&](TestMwmBuilder & builder) { + auto mwmId = BuildCountry("Cafeland", [&](TestMwmBuilder & builder) + { for (auto const & poi : pois) builder.Add(poi); }); @@ -150,7 +142,8 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke) vector emit(pois.size()); FeaturesVectorTest fv(mwmId.GetInfo()->GetLocalFile().GetPath(MapFileType::Map)); - fv.GetVector().ForEach([&](FeatureType & ft, uint32_t index) { + fv.GetVector().ForEach([&](FeatureType & ft, uint32_t index) + { FeatureID id(mwmId, index); ResultTracer::Provenance provenance; preRanker.Emplace(id, PreRankingInfo(Model::TYPE_SUBPOI, TokenRange(0, 1)), provenance); @@ -164,19 +157,21 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke) TEST(all_of(emit.begin(), emit.end(), base::IdFunctor()), (emit)); TEST(ranker.Finished(), ()); - TEST_EQUAL(results.size(), kBatchSize, ()); + + size_t const count = results.size(); + // See todo comment above for details. + TEST(count == kBatchSize || count == kBatchSize + 1, (count)); vector checked(pois.size()); - for (size_t i = 0; i < results.size(); ++i) + for (size_t i = 0; i < count; ++i) { size_t const index = results[i].GetId().m_index; TEST_LESS(index, pois.size(), ()); TEST(!checked[index], (index)); - TEST(base::AlmostEqualAbs(distances[index], results[i].GetDistance(), 1.0), + TEST(base::AlmostEqualAbs(distances[index], results[i].GetDistance(), 1.0 /* 1 meter epsilon */), (distances[index], results[i].GetDistance())); checked[index] = true; } } -} // namespace -} // namespace search +} // namespace pre_ranker_test