diff --git a/base/base.pro b/base/base.pro index 934e3298ac..49052f5b1a 100644 --- a/base/base.pro +++ b/base/base.pro @@ -57,6 +57,7 @@ HEADERS += \ object_tracker.hpp \ observer_list.hpp \ range_iterator.hpp \ + ref_counted.hpp \ regexp.hpp \ rolling_hash.hpp \ scope_guard.hpp \ diff --git a/base/base_tests/base_tests.pro b/base/base_tests/base_tests.pro index c447870577..7624cc052f 100644 --- a/base/base_tests/base_tests.pro +++ b/base/base_tests/base_tests.pro @@ -28,6 +28,7 @@ SOURCES += \ mem_trie_test.cpp \ observer_list_test.cpp \ range_iterator_test.cpp \ + ref_counted_tests.cpp \ regexp_test.cpp \ rolling_hash_test.cpp \ scope_guard_test.cpp \ diff --git a/base/base_tests/ref_counted_tests.cpp b/base/base_tests/ref_counted_tests.cpp new file mode 100644 index 0000000000..f1863eb499 --- /dev/null +++ b/base/base_tests/ref_counted_tests.cpp @@ -0,0 +1,75 @@ +#include "testing/testing.hpp" + +#include "base/ref_counted.hpp" + +using namespace my; + +namespace +{ +struct Resource : public RefCounted +{ + Resource(bool & destroyed) : m_destroyed(destroyed) { m_destroyed = false; } + + ~Resource() override { m_destroyed = true; } + + bool & m_destroyed; +}; + +UNIT_TEST(RefCounted_Smoke) +{ + { + RefCountPtr p; + } + + { + bool destroyed; + { + RefCountPtr p(new Resource(destroyed)); + TEST_EQUAL(1, p->NumRefs(), ()); + TEST(!destroyed, ()); + } + TEST(destroyed, ()); + } + + { + bool destroyed; + { + RefCountPtr a(new Resource(destroyed)); + TEST_EQUAL(1, a->NumRefs(), ()); + TEST(!destroyed, ()); + + RefCountPtr b(a); + TEST(a.Get() == b.Get(), ()); + TEST_EQUAL(2, a->NumRefs(), ()); + TEST(!destroyed, ()); + + { + RefCountPtr c; + TEST(c.Get() == nullptr, ()); + + c = b; + TEST(a.Get() == b.Get(), ()); + TEST(b.Get() == c.Get(), ()); + TEST_EQUAL(3, a->NumRefs(), ()); + TEST(!destroyed, ()); + } + + TEST(a.Get() == b.Get(), ()); + TEST_EQUAL(2, a->NumRefs(), ()); + TEST(!destroyed, ()); + + RefCountPtr d(move(b)); + TEST(b.Get() == nullptr, ()); + TEST(a.Get() == d.Get(), ()); + TEST_EQUAL(2, a->NumRefs(), ()); + TEST(!destroyed, ()); + + a = a; + TEST_EQUAL(a.Get(), d.Get(), ()); + TEST_EQUAL(2, a->NumRefs(), ()); + TEST(!destroyed, ()); + } + TEST(destroyed, ()); + } +} +} // namespace diff --git a/base/ref_counted.hpp b/base/ref_counted.hpp new file mode 100644 index 0000000000..54feffc08b --- /dev/null +++ b/base/ref_counted.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include "base/macros.hpp" + +#include "std/cstdint.hpp" +#include "std/unique_ptr.hpp" + +namespace my +{ +class RefCounted +{ +public: + virtual ~RefCounted() = default; + + inline void IncRef() noexcept { ++m_refs; } + inline uint64_t DecRef() noexcept { return --m_refs; } + inline uint64_t NumRefs() const noexcept { return m_refs; } + +protected: + RefCounted() noexcept = default; + + uint64_t m_refs = 0; + + DISALLOW_COPY_AND_MOVE(RefCounted); +}; + +template +class RefCountPtr +{ +public: + RefCountPtr() noexcept = default; + + explicit RefCountPtr(T * p) noexcept : m_p(p) + { + if (m_p) + m_p->IncRef(); + } + + explicit RefCountPtr(unique_ptr p) noexcept : RefCountPtr(p.release()) {} + + RefCountPtr(RefCountPtr const & rhs) { *this = rhs; } + + RefCountPtr(RefCountPtr && rhs) { *this = move(rhs); } + + ~RefCountPtr() { Reset(); } + + RefCountPtr & operator=(unique_ptr p) + { + Reset(); + + m_p = p.release(); + if (m_p) + m_p->IncRef(); + + return *this; + } + + RefCountPtr & operator=(RefCountPtr const & rhs) + { + if (this == &rhs) + return *this; + + Reset(); + m_p = rhs.m_p; + if (m_p) + m_p->IncRef(); + + return *this; + } + + RefCountPtr & operator=(RefCountPtr && rhs) + { + if (this == &rhs) + return *this; + + Reset(); + m_p = rhs.m_p; + rhs.m_p = nullptr; + + return *this; + } + + void Reset() + { + if (!m_p) + return; + + if (m_p->DecRef() == 0) + delete m_p; + m_p = nullptr; + } + + T * Get() noexcept { return m_p; } + T const * Get() const noexcept { return m_p; } + + T & operator*() { return *m_p; } + T const & operator*() const { return *m_p; } + + T * operator->() noexcept { return m_p; } + T const * operator->() const noexcept { return m_p; } + + inline operator bool() const noexcept { return m_p != nullptr; } + +private: + T * m_p = nullptr; +}; +} // namespace my diff --git a/coding/compressed_bit_vector.hpp b/coding/compressed_bit_vector.hpp index 7e51b363aa..925759d83d 100644 --- a/coding/compressed_bit_vector.hpp +++ b/coding/compressed_bit_vector.hpp @@ -5,6 +5,7 @@ #include "coding/writer.hpp" #include "base/assert.hpp" +#include "base/ref_counted.hpp" #include "std/algorithm.hpp" #include "std/unique_ptr.hpp" @@ -13,7 +14,7 @@ namespace coding { -class CompressedBitVector +class CompressedBitVector : public my::RefCounted { public: enum class StorageStrategy @@ -198,18 +199,18 @@ public: static_cast(header); switch (strat) { - case CompressedBitVector::StorageStrategy::Dense: - { - vector bitGroups; - rw::ReadVectorOfPOD(src, bitGroups); - return DenseCBV::BuildFromBitGroups(move(bitGroups)); - } - case CompressedBitVector::StorageStrategy::Sparse: - { - vector setBits; - rw::ReadVectorOfPOD(src, setBits); - return make_unique(move(setBits)); - } + case CompressedBitVector::StorageStrategy::Dense: + { + vector bitGroups; + rw::ReadVectorOfPOD(src, bitGroups); + return DenseCBV::BuildFromBitGroups(move(bitGroups)); + } + case CompressedBitVector::StorageStrategy::Sparse: + { + vector setBits; + rw::ReadVectorOfPOD(src, setBits); + return make_unique(move(setBits)); + } } return unique_ptr(); } @@ -227,19 +228,34 @@ public: CompressedBitVector::StorageStrategy strat = cbv.GetStorageStrategy(); switch (strat) { - case CompressedBitVector::StorageStrategy::Dense: - { - DenseCBV const & denseCBV = static_cast(cbv); - denseCBV.ForEach(f); - return; - } - case CompressedBitVector::StorageStrategy::Sparse: - { - SparseCBV const & sparseCBV = static_cast(cbv); - sparseCBV.ForEach(f); - return; - } + case CompressedBitVector::StorageStrategy::Dense: + { + DenseCBV const & denseCBV = static_cast(cbv); + denseCBV.ForEach(f); + return; + } + case CompressedBitVector::StorageStrategy::Sparse: + { + SparseCBV const & sparseCBV = static_cast(cbv); + sparseCBV.ForEach(f); + return; + } } } }; + +class CompressedBitVectorHasher +{ +public: + static uint64_t Hash(CompressedBitVector const & cbv) + { + uint64_t const kBase = 127; + uint64_t hash = 0; + CompressedBitVectorEnumerator::ForEach(cbv, [&hash](uint64_t i) + { + hash = hash * kBase + i + 1; + }); + return hash; + } +}; } // namespace coding diff --git a/search/cbv.cpp b/search/cbv.cpp new file mode 100644 index 0000000000..437abe7c36 --- /dev/null +++ b/search/cbv.cpp @@ -0,0 +1,114 @@ +#include "search/cbv.hpp" + +#include "std/limits.hpp" +#include "std/vector.hpp" + +using namespace my; + +namespace search +{ +namespace +{ +uint64_t const kModulo = 18446744073709551557LLU; +} // namespace + +CBV::CBV(unique_ptr p) : m_p(move(p)) {} + +CBV::CBV(CBV && cbv) : m_p(move(cbv.m_p)), m_isFull(cbv.m_isFull) { cbv.m_isFull = false; } + +CBV & CBV::operator=(unique_ptr p) +{ + m_p = move(p); + m_isFull = false; + + return *this; +} + +CBV & CBV::operator=(CBV && rhs) +{ + if (this == &rhs) + return *this; + + m_p = move(rhs.m_p); + m_isFull = rhs.m_isFull; + + rhs.m_isFull = false; + + return *this; +} + +void CBV::SetFull() +{ + m_p.Reset(); + m_isFull = true; +} + +void CBV::Reset() +{ + m_p.Reset(); + m_isFull = false; +} + +bool CBV::HasBit(uint64_t id) const +{ + if (IsFull()) + return true; + if (IsEmpty()) + return false; + return m_p->GetBit(id); +} + +uint64_t CBV::PopCount() const +{ + ASSERT(!IsFull(), ()); + if (IsEmpty()) + return 0; + return m_p->PopCount(); +} + +CBV CBV::Union(CBV const & rhs) const +{ + if (IsFull() || rhs.IsEmpty()) + return *this; + if (IsEmpty() || rhs.IsFull()) + return rhs; + return CBV(coding::CompressedBitVector::Union(*m_p, *rhs.m_p)); +} + +CBV CBV::Intersect(CBV const & rhs) const +{ + if (IsFull() || rhs.IsEmpty()) + return rhs; + if (IsEmpty() || rhs.IsFull()) + return *this; + return CBV(coding::CompressedBitVector::Intersect(*m_p, *rhs.m_p)); +} + +CBV CBV::Take(uint64_t n) const +{ + if (IsEmpty()) + return *this; + if (IsFull()) + { + vector groups((n + 63) / 64, numeric_limits::max()); + uint64_t const r = n % 64; + if (r != 0) + { + ASSERT(!groups.empty(), ()); + groups.back() = (static_cast(1) << r) - 1; + } + return CBV(coding::DenseCBV::BuildFromBitGroups(move(groups))); + } + + return CBV(m_p->LeaveFirstSetNBits(n)); +} + +uint64_t CBV::Hash() const +{ + if (IsEmpty()) + return 0; + if (IsFull()) + return kModulo; + return coding::CompressedBitVectorHasher::Hash(*m_p) % kModulo; +} +} // namespace search diff --git a/search/cbv.hpp b/search/cbv.hpp new file mode 100644 index 0000000000..b52ff839f6 --- /dev/null +++ b/search/cbv.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "coding/compressed_bit_vector.hpp" + +#include "base/ref_counted.hpp" + +#include "std/function.hpp" +#include "std/utility.hpp" + +namespace search +{ +// A wrapper around coding::CompressedBitVector that augments the +// latter with the "full" state and uses reference counting for +// ownership sharing. +class CBV +{ +public: + CBV() = default; + explicit CBV(unique_ptr p); + CBV(CBV const & cbv) = default; + CBV(CBV && cbv); + + inline operator bool() const { return !IsEmpty(); } + CBV & operator=(unique_ptr p); + CBV & operator=(CBV const & rhs) = default; + CBV & operator=(CBV && rhs); + + void SetFull(); + void Reset(); + + inline bool IsEmpty() const { return !m_isFull && coding::CompressedBitVector::IsEmpty(m_p.Get()); } + inline bool IsFull() const { return m_isFull; } + + bool HasBit(uint64_t id) const; + uint64_t PopCount() const; + + template + void ForEach(TFn && fn) const + { + ASSERT(!m_isFull, ()); + if (!IsEmpty()) + coding::CompressedBitVectorEnumerator::ForEach(*m_p, forward(fn)); + } + + CBV Union(CBV const & rhs) const; + CBV Intersect(CBV const & rhs) const; + + // Takes first set |n| bits. + CBV Take(uint64_t n) const; + + uint64_t Hash() const; + +private: + my::RefCountPtr m_p; + + // True iff all bits are set to one. + bool m_isFull = false; +}; +} // namespace search diff --git a/search/cbv_ptr.cpp b/search/cbv_ptr.cpp deleted file mode 100644 index 7bdf63e64a..0000000000 --- a/search/cbv_ptr.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "search/cbv_ptr.hpp" - -namespace search -{ -CBVPtr::CBVPtr(coding::CompressedBitVector const * p, bool isOwner) { Set(p, isOwner); } - -void CBVPtr::Release() -{ - if (m_isOwner) - delete m_ptr; - - m_ptr = nullptr; - m_isOwner = false; - m_isFull = false; -} - -void CBVPtr::Set(coding::CompressedBitVector const * p, bool isOwner /* = false*/) -{ - Release(); - - m_ptr = p; - m_isOwner = p && isOwner; -} - -void CBVPtr::Set(unique_ptr p) -{ - Set(p.release(), true /* isOwner */); -} - -void CBVPtr::Union(coding::CompressedBitVector const * p) -{ - if (!p || m_isFull) - return; - - if (!m_ptr) - { - m_ptr = p; - m_isFull = false; - } - else - { - Set(coding::CompressedBitVector::Union(*m_ptr, *p).release(), true); - } -} - -void CBVPtr::Intersect(coding::CompressedBitVector const * p) -{ - if (!p) - { - Release(); - return; - } - - if (m_ptr) - { - Set(coding::CompressedBitVector::Intersect(*m_ptr, *p).release(), true); - } - else if (m_isFull) - { - m_ptr = p; - m_isFull = false; - } -} - -bool CBVPtr::IsEmpty() const { return !m_isFull && coding::CompressedBitVector::IsEmpty(m_ptr); } - -} // namespace search diff --git a/search/cbv_ptr.hpp b/search/cbv_ptr.hpp deleted file mode 100644 index 3c9cd68fd8..0000000000 --- a/search/cbv_ptr.hpp +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include "coding/compressed_bit_vector.hpp" - -#include "base/assert.hpp" -#include "base/macros.hpp" - -#include "std/function.hpp" -#include "std/utility.hpp" - -namespace search -{ -/// CompressedBitVector pointer class that incapsulates -/// binary operators logic and takes ownership if needed. -class CBVPtr -{ - DISALLOW_COPY_AND_MOVE(CBVPtr); - - coding::CompressedBitVector const * m_ptr = nullptr; - bool m_isOwner = false; - bool m_isFull = false; ///< True iff all bits are set to one. - - void Release(); - -public: - CBVPtr() = default; - CBVPtr(coding::CompressedBitVector const * p, bool isOwner); - ~CBVPtr() { Release(); } - - inline void SetFull() - { - Release(); - m_isFull = true; - } - - void Set(coding::CompressedBitVector const * p, bool isOwner = false); - void Set(unique_ptr p); - - inline coding::CompressedBitVector const * Get() const { return m_ptr; } - - coding::CompressedBitVector const & operator*() const { return *m_ptr; } - coding::CompressedBitVector const * operator->() const { return m_ptr; } - - bool IsEmpty() const; - - void Union(coding::CompressedBitVector const * p); - void Intersect(coding::CompressedBitVector const * p); - - template - void ForEach(TFn && fn) const - { - ASSERT(!m_isFull, ()); - if (!IsEmpty()) - coding::CompressedBitVectorEnumerator::ForEach(*m_ptr, forward(fn)); - } -}; - -} // namespace search diff --git a/search/features_filter.cpp b/search/features_filter.cpp index 0f9912ad07..5f04345b6c 100644 --- a/search/features_filter.cpp +++ b/search/features_filter.cpp @@ -1,47 +1,48 @@ #include "search/features_filter.hpp" -#include "coding/compressed_bit_vector.hpp" +#include "search/cbv.hpp" #include "std/algorithm.hpp" +#include "std/vector.hpp" namespace search { // FeaturesFilter ---------------------------------------------------------------------------------- -FeaturesFilter::FeaturesFilter(coding::CompressedBitVector const & filter, uint32_t threshold) +FeaturesFilter::FeaturesFilter(CBV const & filter, uint32_t threshold) : m_filter(filter), m_threshold(threshold) { } -bool FeaturesFilter::NeedToFilter(coding::CompressedBitVector const & cbv) const +bool FeaturesFilter::NeedToFilter(CBV const & cbv) const { + if (cbv.IsFull()) + return true; return cbv.PopCount() > m_threshold; } // LocalityFilter ---------------------------------------------------------------------------------- -LocalityFilter::LocalityFilter(coding::CompressedBitVector const & filter) +LocalityFilter::LocalityFilter(CBV const & filter) : FeaturesFilter(filter, 0 /* threshold */) { } -unique_ptr LocalityFilter::Filter( - coding::CompressedBitVector const & cbv) const +CBV LocalityFilter::Filter(CBV const & cbv) const { - return coding::CompressedBitVector::Intersect(m_filter, cbv); + return m_filter.Intersect(cbv); } // ViewportFilter ---------------------------------------------------------------------------------- -ViewportFilter::ViewportFilter(coding::CompressedBitVector const & filter, uint32_t threshold) +ViewportFilter::ViewportFilter(CBV const & filter, uint32_t threshold) : FeaturesFilter(filter, threshold) { } -unique_ptr ViewportFilter::Filter( - coding::CompressedBitVector const & cbv) const +CBV ViewportFilter::Filter(CBV const & cbv) const { - auto result = coding::CompressedBitVector::Intersect(m_filter, cbv); - if (!coding::CompressedBitVector::IsEmpty(result)) + auto result = m_filter.Intersect(cbv); + if (!result.IsEmpty()) return result; - return cbv.LeaveFirstSetNBits(m_threshold); -} + return cbv.Take(m_threshold); +} } // namespace search diff --git a/search/features_filter.hpp b/search/features_filter.hpp index fb5f425c3c..8aaf27d782 100644 --- a/search/features_filter.hpp +++ b/search/features_filter.hpp @@ -2,30 +2,26 @@ #include "std/unique_ptr.hpp" -namespace coding -{ -class CompressedBitVector; -} - namespace search { +class CBV; + // A lightweight filter of features. // // NOTE: this class and its subclasses *ARE* thread-safe. class FeaturesFilter { public: - FeaturesFilter(coding::CompressedBitVector const & filter, uint32_t threshold); + FeaturesFilter(CBV const & filter, uint32_t threshold); virtual ~FeaturesFilter() = default; - bool NeedToFilter(coding::CompressedBitVector const & features) const; + bool NeedToFilter(CBV const & features) const; - virtual unique_ptr Filter( - coding::CompressedBitVector const & cbv) const = 0; + virtual CBV Filter(CBV const & cbv) const = 0; protected: - coding::CompressedBitVector const & m_filter; + CBV const & m_filter; uint32_t const m_threshold; }; @@ -34,11 +30,10 @@ protected: class LocalityFilter : public FeaturesFilter { public: - LocalityFilter(coding::CompressedBitVector const & filter); + LocalityFilter(CBV const & filter); // FeaturesFilter overrides: - unique_ptr Filter( - coding::CompressedBitVector const & cbv) const override; + CBV Filter(CBV const & cbv) const override; }; // Fuzzy filter - tries to leave only features belonging to the set it @@ -49,11 +44,10 @@ public: class ViewportFilter : public FeaturesFilter { public: - ViewportFilter(coding::CompressedBitVector const & filter, uint32_t threshold); + ViewportFilter(CBV const & filter, uint32_t threshold); // FeaturesFilter overrides: - unique_ptr Filter( - coding::CompressedBitVector const & cbv) const override; + CBV Filter(CBV const & cbv) const override; }; } // namespace search diff --git a/search/features_layer_matcher.cpp b/search/features_layer_matcher.cpp index 0132857e76..f80b304306 100644 --- a/search/features_layer_matcher.cpp +++ b/search/features_layer_matcher.cpp @@ -34,7 +34,7 @@ void FeaturesLayerMatcher::SetContext(MwmContext * context) m_loader.SetContext(context); } -void FeaturesLayerMatcher::SetPostcodes(coding::CompressedBitVector const * postcodes) +void FeaturesLayerMatcher::SetPostcodes(CBV const * postcodes) { m_postcodes = postcodes; } diff --git a/search/features_layer_matcher.hpp b/search/features_layer_matcher.hpp index 5d69b58bbc..da9098efb5 100644 --- a/search/features_layer_matcher.hpp +++ b/search/features_layer_matcher.hpp @@ -1,6 +1,7 @@ #pragma once #include "search/cancel_exception.hpp" +#include "search/cbv.hpp" #include "search/features_layer.hpp" #include "search/house_numbers_matcher.hpp" #include "search/model.hpp" @@ -19,8 +20,6 @@ #include "geometry/point2d.hpp" #include "geometry/rect2d.hpp" -#include "coding/compressed_bit_vector.hpp" - #include "base/cancellable.hpp" #include "base/logging.hpp" #include "base/macros.hpp" @@ -61,7 +60,7 @@ public: FeaturesLayerMatcher(Index const & index, my::Cancellable const & cancellable); void SetContext(MwmContext * context); - void SetPostcodes(coding::CompressedBitVector const * postcodes); + void SetPostcodes(CBV const * postcodes); template void Match(FeaturesLayer const & child, FeaturesLayer const & parent, TFn && fn) @@ -165,7 +164,7 @@ private: MercatorBounds::RectByCenterXYAndSizeInMeters(poiCenters[i], kBuildingRadiusMeters), [&](FeatureType & ft) { - if (m_postcodes && !m_postcodes->GetBit(ft.GetID().m_index)) + if (m_postcodes && !m_postcodes->HasBit(ft.GetID().m_index)) return; if (house_numbers::HouseNumbersMatch(strings::MakeUniString(ft.GetHouseNumber()), queryParse)) @@ -252,7 +251,7 @@ private: if (binary_search(buildings.begin(), buildings.end(), id)) return true; - if (m_postcodes && !m_postcodes->GetBit(id)) + if (m_postcodes && !m_postcodes->HasBit(id)) return false; if (!loaded) @@ -344,7 +343,7 @@ private: MwmContext * m_context; - coding::CompressedBitVector const * m_postcodes; + CBV const * m_postcodes; ReverseGeocoder m_reverseGeocoder; diff --git a/search/geocoder.cpp b/search/geocoder.cpp index a2ae1e262c..72027b827c 100644 --- a/search/geocoder.cpp +++ b/search/geocoder.cpp @@ -1,6 +1,6 @@ #include "search/geocoder.hpp" -#include "search/cbv_ptr.hpp" +#include "search/cbv.hpp" #include "search/dummy_rank_table.hpp" #include "search/features_filter.hpp" #include "search/features_layer_matcher.hpp" @@ -387,14 +387,14 @@ size_t OrderCountries(m2::RectD const & pivot, vector> & inf // Performs pairwise union of adjacent bit vectors // until at most one bit vector is left. -void UniteCBVs(vector> & cbvs) +void UniteCBVs(vector & cbvs) { while (cbvs.size() > 1) { size_t i = 0; size_t j = 0; for (; j + 1 < cbvs.size(); j += 2) - cbvs[i++] = coding::CompressedBitVector::Union(*cbvs[j], *cbvs[j + 1]); + cbvs[i++] = cbvs[j].Union(cbvs[j + 1]); for (; j < cbvs.size(); ++j) cbvs[i++] = move(cbvs[j]); cbvs.resize(i); @@ -411,12 +411,10 @@ Geocoder::Geocoder(Index const & index, storage::CountryInfoGetter const & infoG : m_index(index) , m_infoGetter(infoGetter) , m_cancellable(cancellable) - , m_numTokens(0) , m_model(SearchModel::Instance()) , m_pivotRectsCache(kPivotRectsCacheSize, m_cancellable, Processor::kMaxViewportRadiusM) , m_localityRectsCache(kLocalityRectsCacheSize, m_cancellable) , m_pivotFeatures(index) - , m_villages(nullptr) , m_filter(nullptr) , m_matcher(nullptr) , m_finder(m_cancellable) @@ -446,13 +444,10 @@ void Geocoder::SetParams(Params const & params) } m_retrievalParams = m_params; - m_numTokens = m_params.m_tokens.size(); - if (!m_params.m_prefixTokens.empty()) - ++m_numTokens; // Remove all category synonyms for streets, as they're extracted // individually via LoadStreets. - for (size_t i = 0; i < m_numTokens; ++i) + for (size_t i = 0; i < m_params.GetNumTokens(); ++i) { auto & synonyms = m_params.GetTokens(i); ASSERT(!synonyms.empty(), ()); @@ -488,7 +483,7 @@ void Geocoder::GoEverywhere(PreRanker & preRanker) MY_SCOPE_GUARD(stopProfiler, &ProfilerStop); #endif - if (m_numTokens == 0) + if (m_params.GetNumTokens() == 0) return; vector> infos; @@ -499,7 +494,7 @@ void Geocoder::GoEverywhere(PreRanker & preRanker) void Geocoder::GoInViewport(PreRanker & preRanker) { - if (m_numTokens == 0) + if (m_params.GetNumTokens() == 0) return; vector> infos; @@ -533,10 +528,12 @@ void Geocoder::GoImpl(PreRanker & preRanker, vector> & infos // it's ok to save MwmId. m_worldId = handle.GetId(); m_context = make_unique(move(handle)); + if (HasSearchIndex(value)) { - PrepareAddressFeatures(); - FillLocalitiesTable(); + BaseContext ctx; + InitBaseContext(ctx); + FillLocalitiesTable(ctx); } m_context.reset(); } @@ -561,15 +558,13 @@ void Geocoder::GoImpl(PreRanker & preRanker, vector> & infos { ASSERT(context, ()); m_context = move(context); + MY_SCOPE_GUARD(cleanup, [&]() { LOG(LDEBUG, (m_context->GetName(), "geocoding complete.")); m_matcher->OnQueryFinished(); m_matcher = nullptr; m_context.reset(); - m_addressFeatures.clear(); - m_streets = nullptr; - m_villages = nullptr; }); auto it = m_matchersCache.find(m_context->GetId()); @@ -582,41 +577,32 @@ void Geocoder::GoImpl(PreRanker & preRanker, vector> & infos m_matcher = it->second.get(); m_matcher->SetContext(m_context.get()); - PrepareAddressFeatures(); + BaseContext ctx; + InitBaseContext(ctx); - coding::CompressedBitVector const * viewportCBV = nullptr; if (inViewport) - viewportCBV = RetrieveGeometryFeatures(*m_context, m_params.m_pivot, RECT_ID_PIVOT); - - if (viewportCBV) { - for (size_t i = 0; i < m_numTokens; ++i) - { - m_addressFeatures[i] = - coding::CompressedBitVector::Intersect(*m_addressFeatures[i], *viewportCBV); - } + auto const viewportCBV = + RetrieveGeometryFeatures(*m_context, m_params.m_pivot, RECT_ID_PIVOT); + for (auto & features : ctx.m_features) + features = features.Intersect(viewportCBV); } - // |m_streets| will be initialized in LimitedSearch() and its - // callees, if needed. - m_streets = nullptr; - - m_villages = LoadVillages(*m_context); + ctx.m_villages = LoadVillages(*m_context); auto citiesFromWorld = m_cities; - FillVillageLocalities(); + FillVillageLocalities(ctx); MY_SCOPE_GUARD(remove_villages, [&]() { m_cities = citiesFromWorld; }); - m_usedTokens.assign(m_numTokens, false); m_lastMatchedRegion = nullptr; - MatchRegions(REGION_TYPE_COUNTRY); + MatchRegions(ctx, REGION_TYPE_COUNTRY); if (index < numIntersectingMaps || m_preRanker->IsEmpty()) - MatchAroundPivot(); + MatchAroundPivot(ctx); }; // Iterates through all alive mwms and performs geocoding. @@ -636,17 +622,15 @@ void Geocoder::ClearCaches() m_localityRectsCache.Clear(); m_pivotFeatures.Clear(); - m_addressFeatures.clear(); m_matchersCache.clear(); m_streetsCache.clear(); - m_villages.reset(); m_postcodes.Clear(); } void Geocoder::PrepareRetrievalParams(size_t curToken, size_t endToken) { ASSERT_LESS(curToken, endToken, ()); - ASSERT_LESS_OR_EQUAL(endToken, m_numTokens, ()); + ASSERT_LESS_OR_EQUAL(endToken, m_params.GetNumTokens(), ()); m_retrievalParams.m_tokens.clear(); m_retrievalParams.m_prefixTokens.clear(); @@ -663,15 +647,16 @@ void Geocoder::PrepareRetrievalParams(size_t curToken, size_t endToken) } } -void Geocoder::PrepareAddressFeatures() +void Geocoder::InitBaseContext(BaseContext & ctx) { - m_addressFeatures.resize(m_numTokens); - for (size_t i = 0; i < m_numTokens; ++i) + ctx.m_usedTokens.assign(m_params.GetNumTokens(), false); + ctx.m_numTokens = m_params.GetNumTokens(); + ctx.m_features.resize(ctx.m_numTokens); + for (size_t i = 0; i < ctx.m_features.size(); ++i) { PrepareRetrievalParams(i, i + 1); - m_addressFeatures[i] = RetrieveAddressFeatures(m_context->GetId(), m_context->m_value, - m_cancellable, m_retrievalParams); - ASSERT(m_addressFeatures[i], ()); + ctx.m_features[i] = RetrieveAddressFeatures(m_context->GetId(), m_context->m_value, + m_cancellable, m_retrievalParams); } } @@ -688,28 +673,21 @@ void Geocoder::InitLayer(SearchModel::SearchType type, size_t startToken, size_t layer.m_lastTokenIsPrefix = (layer.m_endToken > m_params.m_tokens.size()); } -void Geocoder::FillLocalityCandidates(coding::CompressedBitVector const * filter, +void Geocoder::FillLocalityCandidates(BaseContext const & ctx, CBV const & filter, size_t const maxNumLocalities, vector & preLocalities) { preLocalities.clear(); - for (size_t startToken = 0; startToken < m_numTokens; ++startToken) + for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken) { - CBVPtr intersection; - CBVPtr unfilteredIntersection; - intersection.SetFull(); - unfilteredIntersection.SetFull(); - if (filter) - { - intersection.Intersect(filter); - unfilteredIntersection.Intersect(m_addressFeatures[startToken].get()); - } - intersection.Intersect(m_addressFeatures[startToken].get()); + CBV intersection = filter.Intersect(ctx.m_features[startToken]); if (intersection.IsEmpty()) continue; - for (size_t endToken = startToken + 1; endToken <= m_numTokens; ++endToken) + CBV unfilteredIntersection = ctx.m_features[startToken]; + + for (size_t endToken = startToken + 1; endToken <= ctx.m_numTokens; ++endToken) { // Skip locality candidates that match only numbers. if (!m_params.IsNumberTokens(startToken, endToken)) @@ -721,22 +699,19 @@ void Geocoder::FillLocalityCandidates(coding::CompressedBitVector const * filter l.m_featureId = featureId; l.m_startToken = startToken; l.m_endToken = endToken; - if (filter) - { - l.m_prob = static_cast(intersection->PopCount()) / - static_cast(unfilteredIntersection->PopCount()); - } + l.m_prob = static_cast(intersection.PopCount()) / + static_cast(unfilteredIntersection.PopCount()); preLocalities.push_back(l); }); } - if (endToken < m_numTokens) + if (endToken < ctx.m_numTokens) { - intersection.Intersect(m_addressFeatures[endToken].get()); - if (filter) - unfilteredIntersection.Intersect(m_addressFeatures[endToken].get()); + intersection = intersection.Intersect(ctx.m_features[endToken]); if (intersection.IsEmpty()) break; + + unfilteredIntersection = unfilteredIntersection.Intersect(ctx.m_features[endToken]); } } } @@ -746,10 +721,13 @@ void Geocoder::FillLocalityCandidates(coding::CompressedBitVector const * filter scorer.GetTopLocalities(maxNumLocalities, preLocalities); } -void Geocoder::FillLocalitiesTable() +void Geocoder::FillLocalitiesTable(BaseContext const & ctx) { vector preLocalities; - FillLocalityCandidates(nullptr, kMaxNumLocalities, preLocalities); + + CBV filter; + filter.SetFull(); + FillLocalityCandidates(ctx, filter, kMaxNumLocalities, preLocalities); size_t numCities = 0; size_t numStates = 0; @@ -818,10 +796,10 @@ void Geocoder::FillLocalitiesTable() } } -void Geocoder::FillVillageLocalities() +void Geocoder::FillVillageLocalities(BaseContext const & ctx) { vector preLocalities; - FillLocalityCandidates(m_villages.get(), kMaxNumVillages, preLocalities); + FillLocalityCandidates(ctx, ctx.m_villages /* filter */, kMaxNumVillages, preLocalities); size_t numVillages = 0; @@ -874,19 +852,19 @@ void Geocoder::ForEachCountry(vector> const & infos, TFn && } } -void Geocoder::MatchRegions(RegionType type) +void Geocoder::MatchRegions(BaseContext & ctx, RegionType type) { switch (type) { case REGION_TYPE_STATE: // Tries to skip state matching and go to cities matching. // Then, performs states matching. - MatchCities(); + MatchCities(ctx); break; case REGION_TYPE_COUNTRY: // Tries to skip country matching and go to states matching. // Then, performs countries matching. - MatchRegions(REGION_TYPE_STATE); + MatchRegions(ctx, REGION_TYPE_STATE); break; case REGION_TYPE_COUNT: ASSERT(false, ("Invalid region type.")); return; } @@ -903,7 +881,7 @@ void Geocoder::MatchRegions(RegionType type) size_t const startToken = p.first.first; size_t const endToken = p.first.second; - if (HasUsedTokensInRange(startToken, endToken)) + if (ctx.HasUsedTokensInRange(startToken, endToken)) continue; for (auto const & region : p.second) @@ -927,8 +905,8 @@ void Geocoder::MatchRegions(RegionType type) if (!matches) continue; - ScopedMarkTokens mark(m_usedTokens, startToken, endToken); - if (AllTokensUsed()) + ScopedMarkTokens mark(ctx.m_usedTokens, startToken, endToken); + if (ctx.AllTokensUsed()) { // Region matches to search query, we need to emit it as is. EmitResult(region, startToken, endToken); @@ -942,22 +920,22 @@ void Geocoder::MatchRegions(RegionType type) }); switch (type) { - case REGION_TYPE_STATE: MatchCities(); break; - case REGION_TYPE_COUNTRY: MatchRegions(REGION_TYPE_STATE); break; + case REGION_TYPE_STATE: MatchCities(ctx); break; + case REGION_TYPE_COUNTRY: MatchRegions(ctx, REGION_TYPE_STATE); break; case REGION_TYPE_COUNT: ASSERT(false, ("Invalid region type.")); break; } } } } -void Geocoder::MatchCities() +void Geocoder::MatchCities(BaseContext & ctx) { // Localities are ordered my (m_startToken, m_endToken) pairs. for (auto const & p : m_cities) { size_t const startToken = p.first.first; size_t const endToken = p.first.second; - if (HasUsedTokensInRange(startToken, endToken)) + if (ctx.HasUsedTokensInRange(startToken, endToken)) continue; for (auto const & city : p.second) @@ -970,8 +948,8 @@ void Geocoder::MatchCities() continue; } - ScopedMarkTokens mark(m_usedTokens, startToken, endToken); - if (AllTokensUsed()) + ScopedMarkTokens mark(ctx.m_usedTokens, startToken, endToken); + if (ctx.AllTokensUsed()) { // City matches to search query, we need to emit it as is. EmitResult(city, startToken, endToken); @@ -982,30 +960,25 @@ void Geocoder::MatchCities() if (m_context->GetInfo()->GetType() == MwmInfo::WORLD) continue; - auto const * cityFeatures = - RetrieveGeometryFeatures(*m_context, city.m_rect, RECT_ID_LOCALITY); + auto cityFeatures = RetrieveGeometryFeatures(*m_context, city.m_rect, RECT_ID_LOCALITY); - if (coding::CompressedBitVector::IsEmpty(cityFeatures)) + if (cityFeatures.IsEmpty()) continue; - LocalityFilter filter(*cityFeatures); - LimitedSearch(filter); + LocalityFilter filter(cityFeatures); + LimitedSearch(ctx, filter); } } } -void Geocoder::MatchAroundPivot() +void Geocoder::MatchAroundPivot(BaseContext & ctx) { - auto const * features = RetrieveGeometryFeatures(*m_context, m_params.m_pivot, RECT_ID_PIVOT); - - if (!features) - return; - - ViewportFilter filter(*features, m_preRanker->Limit() /* threshold */); - LimitedSearch(filter); + auto const features = RetrieveGeometryFeatures(*m_context, m_params.m_pivot, RECT_ID_PIVOT); + ViewportFilter filter(features, m_preRanker->Limit() /* threshold */); + LimitedSearch(ctx, filter); } -void Geocoder::LimitedSearch(FeaturesFilter const & filter) +void Geocoder::LimitedSearch(BaseContext & ctx, FeaturesFilter const & filter) { m_filter = &filter; MY_SCOPE_GUARD(resetFilter, [&]() @@ -1013,36 +986,36 @@ void Geocoder::LimitedSearch(FeaturesFilter const & filter) m_filter = nullptr; }); - if (!m_streets) - m_streets = LoadStreets(*m_context); + if (!ctx.m_streets) + ctx.m_streets = LoadStreets(*m_context); - MatchUnclassified(0 /* curToken */); + MatchUnclassified(ctx, 0 /* curToken */); - auto search = [this]() + auto const search = [this, &ctx]() { - GreedilyMatchStreets(); - MatchPOIsAndBuildings(0 /* curToken */); + GreedilyMatchStreets(ctx); + MatchPOIsAndBuildings(ctx, 0 /* curToken */); }; - WithPostcodes(search); + WithPostcodes(ctx, search); search(); } template -void Geocoder::WithPostcodes(TFn && fn) +void Geocoder::WithPostcodes(BaseContext & ctx, TFn && fn) { size_t const maxPostcodeTokens = GetMaxNumTokensInPostcode(); - for (size_t startToken = 0; startToken != m_numTokens; ++startToken) + for (size_t startToken = 0; startToken != ctx.m_numTokens; ++startToken) { size_t endToken = startToken; - for (size_t n = 1; startToken + n <= m_numTokens && n <= maxPostcodeTokens; ++n) + for (size_t n = 1; startToken + n <= ctx.m_numTokens && n <= maxPostcodeTokens; ++n) { - if (m_usedTokens[startToken + n - 1]) + if (ctx.m_usedTokens[startToken + n - 1]) break; TokenSlice slice(m_params, startToken, startToken + n); - auto const isPrefix = startToken + n == m_numTokens; + auto const isPrefix = startToken + n == ctx.m_numTokens; if (LooksLikePostcode(QuerySlice(slice), isPrefix)) endToken = startToken + n; } @@ -1056,9 +1029,9 @@ void Geocoder::WithPostcodes(TFn && fn) m_postcodes.Clear(); }); - if (!coding::CompressedBitVector::IsEmpty(postcodes)) + if (!postcodes.IsEmpty()) { - ScopedMarkTokens mark(m_usedTokens, startToken, endToken); + ScopedMarkTokens mark(ctx.m_usedTokens, startToken, endToken); m_postcodes.Clear(); m_postcodes.m_startToken = startToken; @@ -1070,140 +1043,57 @@ void Geocoder::WithPostcodes(TFn && fn) } } -void Geocoder::GreedilyMatchStreets() +void Geocoder::GreedilyMatchStreets(BaseContext & ctx) { - for (size_t startToken = 0; startToken < m_numTokens; ++startToken) - { - if (m_usedTokens[startToken]) - continue; + vector predictions; + StreetsMatcher::Go(ctx, *m_filter, m_params, predictions); - // Here we try to match as many tokens as possible while - // intersection is a non-empty bit vector of streets. Single - // tokens that are synonyms to streets are ignored. Moreover, - // each time a token that looks like a beginning of a house number - // is met, we try to use current intersection of tokens as a - // street layer and try to match BUILDINGs or POIs. - CBVPtr allFeatures(m_streets, false /* isOwner */); - - size_t curToken = startToken; - - // This variable is used for prevention of duplicate calls to - // CreateStreetsLayerAndMatchLowerLayers() with the same - // arguments. - size_t lastToken = startToken; - - // When true, no bit vectors were intersected with allFeatures at - // all. - bool emptyIntersection = true; - - // When true, allFeatures is in the incomplete state and can't be - // used for creation of street layers. - bool incomplete = false; - - auto createStreetsLayerAndMatchLowerLayers = [&]() - { - if (!allFeatures.IsEmpty() && !emptyIntersection && !incomplete && lastToken != curToken) - { - CreateStreetsLayerAndMatchLowerLayers(startToken, curToken, *allFeatures); - lastToken = curToken; - } - }; - - StreetTokensFilter filter([&](strings::UniString const & /* token */, size_t tag) - { - auto buffer = coding::CompressedBitVector::Intersect( - *allFeatures, *m_addressFeatures[tag]); - if (tag < curToken) - { - // This is the case for delayed - // street synonym. Therefore, - // allFeatures is temporarily in the - // incomplete state. - allFeatures.Set(move(buffer)); - emptyIntersection = false; - - incomplete = true; - return; - } - ASSERT_EQUAL(tag, curToken, ()); - - // |allFeatures| will become empty - // after the intersection. Therefore - // we need to create streets layer - // right now. - if (coding::CompressedBitVector::IsEmpty(buffer)) - createStreetsLayerAndMatchLowerLayers(); - - allFeatures.Set(move(buffer)); - emptyIntersection = false; - incomplete = false; - }); - - for (; curToken < m_numTokens && !m_usedTokens[curToken] && !allFeatures.IsEmpty(); ++curToken) - { - auto const & token = m_params.GetTokens(curToken).front(); - bool const isPrefix = curToken >= m_params.m_tokens.size(); - - if (house_numbers::LooksLikeHouseNumber(token, isPrefix)) - createStreetsLayerAndMatchLowerLayers(); - - filter.Put(token, isPrefix, curToken); - } - createStreetsLayerAndMatchLowerLayers(); - } + for (auto const & prediction : predictions) + CreateStreetsLayerAndMatchLowerLayers(ctx, prediction); } -void Geocoder::CreateStreetsLayerAndMatchLowerLayers( - size_t startToken, size_t endToken, coding::CompressedBitVector const & features) +void Geocoder::CreateStreetsLayerAndMatchLowerLayers(BaseContext & ctx, + StreetsMatcher::Prediction const & prediction) { ASSERT(m_layers.empty(), ()); - if (coding::CompressedBitVector::IsEmpty(&features)) - return; - - CBVPtr filtered(&features, false /* isOwner */); - if (m_filter->NeedToFilter(features)) - filtered.Set(m_filter->Filter(features).release(), true /* isOwner */); - m_layers.emplace_back(); MY_SCOPE_GUARD(cleanupGuard, bind(&vector::pop_back, &m_layers)); auto & layer = m_layers.back(); - InitLayer(SearchModel::SEARCH_TYPE_STREET, startToken, endToken, layer); + InitLayer(SearchModel::SEARCH_TYPE_STREET, prediction.m_startToken, prediction.m_endToken, layer); vector sortedFeatures; - sortedFeatures.reserve(features.PopCount()); - filtered.ForEach(MakeBackInsertFunctor(sortedFeatures)); + sortedFeatures.reserve(prediction.m_features.PopCount()); + prediction.m_features.ForEach(MakeBackInsertFunctor(sortedFeatures)); layer.m_sortedFeatures = &sortedFeatures; - ScopedMarkTokens mark(m_usedTokens, startToken, endToken); - MatchPOIsAndBuildings(0 /* curToken */); + ScopedMarkTokens mark(ctx.m_usedTokens, prediction.m_startToken, prediction.m_endToken); + MatchPOIsAndBuildings(ctx, 0 /* curToken */); } -void Geocoder::MatchPOIsAndBuildings(size_t curToken) +void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken) { BailIfCancelled(); - curToken = SkipUsedTokens(curToken); - if (curToken == m_numTokens) + curToken = ctx.SkipUsedTokens(curToken); + if (curToken == ctx.m_numTokens) { // All tokens were consumed, find paths through layers, emit // features. - if (m_postcodes.IsEmpty()) + if (m_postcodes.m_features.IsEmpty()) return FindPaths(); // When there are no layers but user entered a postcode, we have // to emit all features matching to the postcode. if (m_layers.size() == 0) { - CBVPtr filtered; - if (m_filter->NeedToFilter(*m_postcodes.m_features)) - filtered.Set(m_filter->Filter(*m_postcodes.m_features)); - else - filtered.Set(m_postcodes.m_features.get(), false /* isOwner */); + CBV filtered = m_postcodes.m_features; + if (m_filter->NeedToFilter(m_postcodes.m_features)) + filtered = m_filter->Filter(m_postcodes.m_features); filtered.ForEach([&](uint32_t id) { - EmitResult(m_context->GetId(), id, GetSearchTypeInGeocoding(id), + EmitResult(m_context->GetId(), id, GetSearchTypeInGeocoding(ctx, id), m_postcodes.m_startToken, m_postcodes.m_endToken); }); return; @@ -1222,7 +1112,7 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) { for (auto const & id : *m_layers.back().m_sortedFeatures) { - if (!m_postcodes.Has(id)) + if (!m_postcodes.m_features.HasBit(id)) continue; EmitResult(m_context->GetId(), id, SearchModel::SEARCH_TYPE_STREET, m_layers.back().m_startToken, m_layers.back().m_endToken); @@ -1239,8 +1129,7 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) layer); vector features; - coding::CompressedBitVectorEnumerator::ForEach(*m_postcodes.m_features, - MakeBackInsertFunctor(features)); + m_postcodes.m_features.ForEach(MakeBackInsertFunctor(features)); layer.m_sortedFeatures = &features; return FindPaths(); } @@ -1257,23 +1146,23 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) // any. auto clusterize = [&](uint32_t featureId) { - auto const searchType = GetSearchTypeInGeocoding(featureId); + auto const searchType = GetSearchTypeInGeocoding(ctx, featureId); // All SEARCH_TYPE_CITY features were filtered in // MatchCities(). All SEARCH_TYPE_STREET features were // filtered in GreedilyMatchStreets(). if (searchType < kNumClusters) { - if (m_postcodes.IsEmpty() || m_postcodes.m_features->GetBit(featureId)) + if (m_postcodes.m_features.IsEmpty() || m_postcodes.m_features.HasBit(featureId)) clusters[searchType].push_back(featureId); } }; - CBVPtr features; + CBV features; features.SetFull(); // Try to consume [curToken, m_numTokens) tokens range. - for (size_t n = 1; curToken + n <= m_numTokens && !m_usedTokens[curToken + n - 1]; ++n) + for (size_t n = 1; curToken + n <= ctx.m_numTokens && !ctx.m_usedTokens[curToken + n - 1]; ++n) { // At this point |features| is the intersection of // m_addressFeatures[curToken], m_addressFeatures[curToken + 1], @@ -1286,15 +1175,11 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) InitLayer(layer.m_type, curToken, curToken + n, layer); } - features.Intersect(m_addressFeatures[curToken + n - 1].get()); - ASSERT(features.Get(), ()); + features = features.Intersect(ctx.m_features[curToken + n - 1]); - CBVPtr filtered; - if (m_filter->NeedToFilter(*features)) - filtered.Set(m_filter->Filter(*features)); - else - filtered.Set(features.Get(), false /* isOwner */); - ASSERT(filtered.Get(), ()); + CBV filtered = features; + if (m_filter->NeedToFilter(features)) + filtered = m_filter->Filter(features); bool const looksLikeHouseNumber = house_numbers::LooksLikeHouseNumber( m_layers.back().m_subQuery, m_layers.back().m_lastTokenIsPrefix); @@ -1310,7 +1195,7 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) { auto noFeature = [&filtered](uint32_t featureId) -> bool { - return !filtered->GetBit(featureId); + return !filtered.HasBit(featureId); }; for (auto & cluster : clusters) my::EraseIf(cluster, noFeature); @@ -1357,7 +1242,7 @@ void Geocoder::MatchPOIsAndBuildings(size_t curToken) layer.m_type = static_cast(i); if (IsLayerSequenceSane()) - MatchPOIsAndBuildings(curToken + n); + MatchPOIsAndBuildings(ctx, curToken + n); } } } @@ -1422,7 +1307,10 @@ void Geocoder::FindPaths() auto const & innermostLayer = *sortedLayers.front(); - m_matcher->SetPostcodes(m_postcodes.m_features.get()); + if (!m_postcodes.m_features.IsEmpty()) + m_matcher->SetPostcodes(&m_postcodes.m_features); + else + m_matcher->SetPostcodes(nullptr); m_finder.ForEachReachableVertex( *m_matcher, sortedLayers, [this, &innermostLayer](IntersectionResult const & result) { @@ -1505,7 +1393,7 @@ void Geocoder::FillMissingFieldsInResults() } } -void Geocoder::MatchUnclassified(size_t curToken) +void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken) { ASSERT(m_layers.empty(), ()); @@ -1516,36 +1404,32 @@ void Geocoder::MatchUnclassified(size_t curToken) // adjacent tokens will be matched to "Hyde Park", whereas it's not // ok to match something to "Park London Hyde", because tokens // "Park" and "Hyde" are not adjacent. - if (NumUnusedTokensGroups() != 1) + if (ctx.NumUnusedTokenGroups() != 1) return; - CBVPtr allFeatures; + CBV allFeatures; allFeatures.SetFull(); auto startToken = curToken; - for (curToken = SkipUsedTokens(curToken); curToken < m_numTokens && !m_usedTokens[curToken]; - ++curToken) + for (curToken = ctx.SkipUsedTokens(curToken); + curToken < ctx.m_numTokens && !ctx.m_usedTokens[curToken]; ++curToken) { - allFeatures.Intersect(m_addressFeatures[curToken].get()); + allFeatures = allFeatures.Intersect(ctx.m_features[curToken]); } - if (m_filter->NeedToFilter(*allFeatures)) - allFeatures.Set(m_filter->Filter(*allFeatures).release(), true /* isOwner */); - - if (allFeatures.IsEmpty()) - return; + if (m_filter->NeedToFilter(allFeatures)) + allFeatures = m_filter->Filter(allFeatures); auto emitUnclassified = [&](uint32_t featureId) { - auto type = GetSearchTypeInGeocoding(featureId); + auto type = GetSearchTypeInGeocoding(ctx, featureId); if (type == SearchModel::SEARCH_TYPE_UNCLASSIFIED) EmitResult(m_context->GetId(), featureId, type, startToken, curToken); }; allFeatures.ForEach(emitUnclassified); } -unique_ptr Geocoder::LoadCategories( - MwmContext & context, vector const & categories) +CBV Geocoder::LoadCategories(MwmContext & context, vector const & categories) { ASSERT(context.m_handle.IsAlive(), ()); ASSERT(HasSearchIndex(context.m_value), ()); @@ -1554,71 +1438,69 @@ unique_ptr Geocoder::LoadCategories( m_retrievalParams.m_tokens[0].resize(1); m_retrievalParams.m_prefixTokens.clear(); - vector> cbvs; + vector cbvs; for_each(categories.begin(), categories.end(), [&](strings::UniString const & category) { m_retrievalParams.m_tokens[0][0] = category; - auto cbv = RetrieveAddressFeatures(context.GetId(), context.m_value, m_cancellable, - m_retrievalParams); - if (!coding::CompressedBitVector::IsEmpty(cbv)) + CBV cbv(RetrieveAddressFeatures(context.GetId(), context.m_value, m_cancellable, + m_retrievalParams)); + if (!cbv.IsEmpty()) cbvs.push_back(move(cbv)); }); UniteCBVs(cbvs); if (cbvs.empty()) - cbvs.push_back(make_unique()); + cbvs.emplace_back(); return move(cbvs[0]); } -coding::CompressedBitVector const * Geocoder::LoadStreets(MwmContext & context) +CBV Geocoder::LoadStreets(MwmContext & context) { if (!context.m_handle.IsAlive() || !HasSearchIndex(context.m_value)) - return nullptr; + return CBV(); auto mwmId = context.m_handle.GetId(); auto const it = m_streetsCache.find(mwmId); if (it != m_streetsCache.cend()) - return it->second.get(); + return it->second; auto streets = LoadCategories(context, StreetCategories::Instance().GetCategories()); - - auto const * result = streets.get(); - m_streetsCache[mwmId] = move(streets); - return result; + m_streetsCache[mwmId] = streets; + return streets; } -unique_ptr Geocoder::LoadVillages(MwmContext & context) +CBV Geocoder::LoadVillages(MwmContext & context) { if (!context.m_handle.IsAlive() || !HasSearchIndex(context.m_value)) - return make_unique(); + return CBV(); return LoadCategories(context, GetVillageCategories()); } -unique_ptr Geocoder::RetrievePostcodeFeatures( - MwmContext const & context, TokenSlice const & slice) +CBV Geocoder::RetrievePostcodeFeatures(MwmContext const & context, TokenSlice const & slice) { - return ::search::RetrievePostcodeFeatures(context.GetId(), context.m_value, m_cancellable, slice); + return CBV( + ::search::RetrievePostcodeFeatures(context.GetId(), context.m_value, m_cancellable, slice)); } -coding::CompressedBitVector const * Geocoder::RetrieveGeometryFeatures(MwmContext const & context, - m2::RectD const & rect, - RectId id) +CBV Geocoder::RetrieveGeometryFeatures(MwmContext const & context, m2::RectD const & rect, + RectId id) { switch (id) { case RECT_ID_PIVOT: return m_pivotRectsCache.Get(context, rect, m_params.m_scale); case RECT_ID_LOCALITY: return m_localityRectsCache.Get(context, rect, m_params.m_scale); - case RECT_ID_COUNT: ASSERT(false, ("Invalid RectId.")); return nullptr; + case RECT_ID_COUNT: ASSERT(false, ("Invalid RectId.")); return CBV(); } } -SearchModel::SearchType Geocoder::GetSearchTypeInGeocoding(uint32_t featureId) +SearchModel::SearchType Geocoder::GetSearchTypeInGeocoding(BaseContext const & ctx, + uint32_t featureId) { - if (m_streets->GetBit(featureId)) + if (ctx.m_streets.HasBit(featureId)) return SearchModel::SEARCH_TYPE_STREET; - if (m_villages->GetBit(featureId)) + if (ctx.m_villages.HasBit(featureId)) return SearchModel::SEARCH_TYPE_VILLAGE; FeatureType feature; @@ -1626,34 +1508,6 @@ SearchModel::SearchType Geocoder::GetSearchTypeInGeocoding(uint32_t featureId) return m_model.GetSearchType(feature); } -bool Geocoder::AllTokensUsed() const -{ - return all_of(m_usedTokens.begin(), m_usedTokens.end(), IdFunctor()); -} - -bool Geocoder::HasUsedTokensInRange(size_t from, size_t to) const -{ - return any_of(m_usedTokens.begin() + from, m_usedTokens.begin() + to, IdFunctor()); -} - -size_t Geocoder::NumUnusedTokensGroups() const -{ - size_t numGroups = 0; - for (size_t i = 0; i < m_usedTokens.size(); ++i) - { - if (!m_usedTokens[i] && (i == 0 || m_usedTokens[i - 1])) - ++numGroups; - } - return numGroups; -} - -size_t Geocoder::SkipUsedTokens(size_t curToken) const -{ - while (curToken != m_usedTokens.size() && m_usedTokens[curToken]) - ++curToken; - return curToken; -} - string DebugPrint(Geocoder::Locality const & locality) { ostringstream os; @@ -1661,5 +1515,4 @@ string DebugPrint(Geocoder::Locality const & locality) << ", startToken=" << locality.m_startToken << ", endToken=" << locality.m_endToken << "]"; return os.str(); } - } // namespace search diff --git a/search/geocoder.hpp b/search/geocoder.hpp index c771402ad9..6347c03cf6 100644 --- a/search/geocoder.hpp +++ b/search/geocoder.hpp @@ -3,6 +3,7 @@ #include "search/cancel_exception.hpp" #include "search/features_layer.hpp" #include "search/features_layer_path_finder.hpp" +#include "search/geocoder_context.hpp" #include "search/geometry_cache.hpp" #include "search/mode.hpp" #include "search/model.hpp" @@ -11,6 +12,7 @@ #include "search/pre_ranking_info.hpp" #include "search/query_params.hpp" #include "search/ranking_utils.hpp" +#include "search/streets_matcher.hpp" #include "indexer/index.hpp" #include "indexer/mwm_set.hpp" @@ -170,16 +172,12 @@ private: { m_startToken = 0; m_endToken = 0; - m_features.reset(); + m_features.Reset(); } - inline bool Has(uint64_t id) const { return m_features->GetBit(id); } - - inline bool IsEmpty() const { return coding::CompressedBitVector::IsEmpty(m_features); } - size_t m_startToken = 0; size_t m_endToken = 0; - unique_ptr m_features; + CBV m_features; }; void GoImpl(PreRanker & preRanker, vector> & infos, bool inViewport); @@ -195,60 +193,57 @@ private: // Creates a cache of posting lists corresponding to features in m_context // for each token and saves it to m_addressFeatures. - void PrepareAddressFeatures(); + void InitBaseContext(BaseContext & ctx); void InitLayer(SearchModel::SearchType type, size_t startToken, size_t endToken, FeaturesLayer & layer); - void FillLocalityCandidates(coding::CompressedBitVector const * filter, + void FillLocalityCandidates(BaseContext const & ctx, CBV const & filter, size_t const maxNumLocalities, vector & preLocalities); - void FillLocalitiesTable(); + void FillLocalitiesTable(BaseContext const & ctx); - void FillVillageLocalities(); + void FillVillageLocalities(BaseContext const & ctx); template void ForEachCountry(vector> const & infos, TFn && fn); // Throws CancelException if cancelled. - inline void BailIfCancelled() - { - ::search::BailIfCancelled(m_cancellable); - } + inline void BailIfCancelled() { ::search::BailIfCancelled(m_cancellable); } // Tries to find all countries and states in a search query and then // performs matching of cities in found maps. - void MatchRegions(RegionType type); + void MatchRegions(BaseContext & ctx, RegionType type); // Tries to find all cities in a search query and then performs // matching of streets in found cities. - void MatchCities(); + void MatchCities(BaseContext & ctx); // Tries to do geocoding without localities, ie. find POIs, // BUILDINGs and STREETs without knowledge about country, state, // city or village. If during the geocoding too many features are // retrieved, viewport is used to throw away excess features. - void MatchAroundPivot(); + void MatchAroundPivot(BaseContext & ctx); // Tries to do geocoding in a limited scope, assuming that knowledge // about high-level features, like cities or countries, is // incorporated into |filter|. - void LimitedSearch(FeaturesFilter const & filter); + void LimitedSearch(BaseContext & ctx, FeaturesFilter const & filter); template - void WithPostcodes(TFn && fn); + void WithPostcodes(BaseContext & ctx, TFn && fn); // Tries to match some adjacent tokens in the query as streets and // then performs geocoding in street vicinities. - void GreedilyMatchStreets(); + void GreedilyMatchStreets(BaseContext & ctx); - void CreateStreetsLayerAndMatchLowerLayers(size_t startToken, size_t endToken, - coding::CompressedBitVector const & features); + void CreateStreetsLayerAndMatchLowerLayers(BaseContext & ctx, + StreetsMatcher::Prediction const & prediction); // Tries to find all paths in a search tree, where each edge is // marked with some substring of the query tokens. These paths are // called "layer sequence" and current path is stored in |m_layers|. - void MatchPOIsAndBuildings(size_t curToken); + void MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken); // Returns true if current path in the search tree (see comment for // MatchPOIsAndBuildings()) looks sane. This method is used as a fast @@ -271,40 +266,23 @@ private: // Tries to match unclassified objects from lower layers, like // parks, forests, lakes, rivers, etc. This method finds all // UNCLASSIFIED objects that match to all currently unused tokens. - void MatchUnclassified(size_t curToken); + void MatchUnclassified(BaseContext & ctx, size_t curToken); - unique_ptr LoadCategories( - MwmContext & context, vector const & categories); + CBV LoadCategories(MwmContext & context, vector const & categories); - coding::CompressedBitVector const * LoadStreets(MwmContext & context); + CBV LoadStreets(MwmContext & context); - unique_ptr LoadVillages(MwmContext & context); + CBV LoadVillages(MwmContext & context); // A wrapper around RetrievePostcodeFeatures. - unique_ptr RetrievePostcodeFeatures(MwmContext const & context, - TokenSlice const & slice); + CBV RetrievePostcodeFeatures(MwmContext const & context, TokenSlice const & slice); // A caching wrapper around Retrieval::RetrieveGeometryFeatures. - coding::CompressedBitVector const * RetrieveGeometryFeatures(MwmContext const & context, - m2::RectD const & rect, RectId id); + CBV RetrieveGeometryFeatures(MwmContext const & context, m2::RectD const & rect, RectId id); // This is a faster wrapper around SearchModel::GetSearchType(), as // it uses pre-loaded lists of streets and villages. - SearchModel::SearchType GetSearchTypeInGeocoding(uint32_t featureId); - - // Returns true iff all tokens are used. - bool AllTokensUsed() const; - - // Returns true if there exists at least one used token in [from, - // to). - bool HasUsedTokensInRange(size_t from, size_t to) const; - - // Counts number of groups of consecutive unused tokens. - size_t NumUnusedTokensGroups() const; - - // Advances |curToken| to the nearest unused token, or to the end of - // |m_usedTokens| if there are no unused tokens. - size_t SkipUsedTokens(size_t curToken) const; + SearchModel::SearchType GetSearchTypeInGeocoding(BaseContext const & ctx, uint32_t featureId); Index const & m_index; @@ -315,9 +293,6 @@ private: // Geocoder params. Params m_params; - // Total number of search query tokens. - size_t m_numTokens; - // This field is used to map features to a limited number of search // classes. SearchModel const & m_model; @@ -344,30 +319,12 @@ private: // Cache of nested rects used to estimate distance from a feature to the pivot. NestedRectsCache m_pivotFeatures; - // Cache of posting lists for each token in the query. TODO (@y, - // @m, @vng): consider to update this cache lazily, as user inputs - // tokens one-by-one. - vector> m_addressFeatures; - // Cache of street ids in mwms. - map> m_streetsCache; - - // Street features in the mwm that is currently being processed. - // The initialization of m_streets is postponed in order to gain - // some speed. Therefore m_streets may be used only in - // LimitedSearch() and in all its callees. - coding::CompressedBitVector const * m_streets; - - // Village features in the mwm that is currently being processed. - unique_ptr m_villages; + map m_streetsCache; // Postcodes features in the mwm that is currently being processed. Postcodes m_postcodes; - // This vector is used to indicate what tokens were matched by - // locality and can't be re-used during the geocoding process. - vector m_usedTokens; - // This filter is used to throw away excess features. FeaturesFilter const * m_filter; diff --git a/search/geocoder_context.cpp b/search/geocoder_context.cpp new file mode 100644 index 0000000000..9f03510916 --- /dev/null +++ b/search/geocoder_context.cpp @@ -0,0 +1,36 @@ +#include "search/geocoder_context.hpp" + +#include "base/stl_add.hpp" + +#include "std/algorithm.hpp" + +namespace search +{ +size_t BaseContext::SkipUsedTokens(size_t curToken) const +{ + while (curToken != m_usedTokens.size() && m_usedTokens[curToken]) + ++curToken; + return curToken; +} + +bool BaseContext::AllTokensUsed() const +{ + return all_of(m_usedTokens.begin(), m_usedTokens.end(), IdFunctor()); +} + +bool BaseContext::HasUsedTokensInRange(size_t from, size_t to) const +{ + return any_of(m_usedTokens.begin() + from, m_usedTokens.begin() + to, IdFunctor()); +} + +size_t BaseContext::NumUnusedTokenGroups() const +{ + size_t numGroups = 0; + for (size_t i = 0; i < m_usedTokens.size(); ++i) + { + if (!m_usedTokens[i] && (i == 0 || m_usedTokens[i - 1])) + ++numGroups; + } + return numGroups; +} +} // namespace search diff --git a/search/geocoder_context.hpp b/search/geocoder_context.hpp new file mode 100644 index 0000000000..63b7d4696f --- /dev/null +++ b/search/geocoder_context.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "search/cbv.hpp" + +#include "std/unique_ptr.hpp" +#include "std/vector.hpp" + +namespace search +{ +class FeaturesFilter; + +struct BaseContext +{ + // Advances |curToken| to the nearest unused token, or to the end of + // |m_usedTokens| if there are no unused tokens. + size_t SkipUsedTokens(size_t curToken) const; + + // Returns true iff all tokens are used. + bool AllTokensUsed() const; + + // Returns true if there exists at least one used token in [from, + // to). + bool HasUsedTokensInRange(size_t from, size_t to) const; + + // Counts number of groups of consecutive unused tokens. + size_t NumUnusedTokenGroups() const; + + // List of bit-vectors of features, where i-th element of the list + // corresponds to the i-th token in the search query. + vector m_features; + CBV m_villages; + CBV m_streets; + + // This vector is used to indicate what tokens were already matched + // and can't be re-used during the geocoding process. + vector m_usedTokens; + + // Number of tokens in the query. + size_t m_numTokens = 0; +}; +} // namespace search diff --git a/search/geometry_cache.cpp b/search/geometry_cache.cpp index 275c2ff875..990cae77cd 100644 --- a/search/geometry_cache.cpp +++ b/search/geometry_cache.cpp @@ -35,8 +35,7 @@ PivotRectsCache::PivotRectsCache(size_t maxNumEntries, my::Cancellable const & c { } -coding::CompressedBitVector const * PivotRectsCache::Get(MwmContext const & context, - m2::RectD const & rect, int scale) +CBV PivotRectsCache::Get(MwmContext const & context, m2::RectD const & rect, int scale) { auto p = FindOrCreateEntry( context.GetId(), [&rect, &scale](Entry const & entry) @@ -53,7 +52,7 @@ coding::CompressedBitVector const * PivotRectsCache::Get(MwmContext const & cont normRect = rect; InitEntry(context, normRect, scale, entry); } - return entry.m_cbv.get(); + return entry.m_cbv; } // LocalityRectsCache ------------------------------------------------------------------------------ @@ -62,8 +61,7 @@ LocalityRectsCache::LocalityRectsCache(size_t maxNumEntries, my::Cancellable con { } -coding::CompressedBitVector const * LocalityRectsCache::Get(MwmContext const & context, - m2::RectD const & rect, int scale) +CBV LocalityRectsCache::Get(MwmContext const & context, m2::RectD const & rect, int scale) { auto p = FindOrCreateEntry(context.GetId(), [&rect, &scale](Entry const & entry) { @@ -73,7 +71,7 @@ coding::CompressedBitVector const * LocalityRectsCache::Get(MwmContext const & c auto & entry = p.first; if (p.second) InitEntry(context, rect, scale, entry); - return entry.m_cbv.get(); + return entry.m_cbv; } } // namespace search diff --git a/search/geometry_cache.hpp b/search/geometry_cache.hpp index d5d305301b..2b36707cc2 100644 --- a/search/geometry_cache.hpp +++ b/search/geometry_cache.hpp @@ -1,8 +1,8 @@ #pragma once -#include "indexer/mwm_set.hpp" +#include "search/cbv.hpp" -#include "coding/compressed_bit_vector.hpp" +#include "indexer/mwm_set.hpp" #include "geometry/rect2d.hpp" @@ -35,8 +35,7 @@ public: // Returns (hopefully, cached) list of features in a given // rect. Note that return value may be invalidated on next calls to // this method. - virtual coding::CompressedBitVector const * Get(MwmContext const & context, - m2::RectD const & rect, int scale) = 0; + virtual CBV Get(MwmContext const & context, m2::RectD const & rect, int scale) = 0; inline void Clear() { m_entries.clear(); } @@ -44,7 +43,7 @@ protected: struct Entry { m2::RectD m_rect; - unique_ptr m_cbv; + CBV m_cbv; int m_scale = 0; }; @@ -87,8 +86,7 @@ public: double maxRadiusMeters); // GeometryCache overrides: - coding::CompressedBitVector const * Get(MwmContext const & context, m2::RectD const & rect, - int scale) override; + CBV Get(MwmContext const & context, m2::RectD const & rect, int scale) override; private: double const m_maxRadiusMeters; @@ -100,8 +98,7 @@ public: LocalityRectsCache(size_t maxNumEntries, my::Cancellable const & cancellable); // GeometryCache overrides: - coding::CompressedBitVector const * Get(MwmContext const & context, m2::RectD const & rect, - int scale) override; + CBV Get(MwmContext const & context, m2::RectD const & rect, int scale) override; }; } // namespace search diff --git a/search/mwm_context.hpp b/search/mwm_context.hpp index 200f5cd053..ca9b50b524 100644 --- a/search/mwm_context.hpp +++ b/search/mwm_context.hpp @@ -90,5 +90,4 @@ private: DISALLOW_COPY_AND_MOVE(MwmContext); }; - } // namespace search diff --git a/search/query_params.hpp b/search/query_params.hpp index f3bc6f90eb..2f00bc6c34 100644 --- a/search/query_params.hpp +++ b/search/query_params.hpp @@ -36,6 +36,10 @@ struct QueryParams TSynonymsVector const & GetTokens(size_t i) const; TSynonymsVector & GetTokens(size_t i); + inline size_t GetNumTokens() const + { + return m_prefixTokens.empty() ? m_tokens.size() : m_tokens.size() + 1; + } /// @return true if all tokens in [start, end) range has number synonym. bool IsNumberTokens(size_t start, size_t end) const; diff --git a/search/search.pro b/search/search.pro index eec3a48784..b25b324e4e 100644 --- a/search/search.pro +++ b/search/search.pro @@ -12,7 +12,7 @@ HEADERS += \ algos.hpp \ approximate_string_match.hpp \ cancel_exception.hpp \ - cbv_ptr.hpp \ + cbv.hpp \ common.hpp \ dummy_rank_table.hpp \ engine.hpp \ @@ -22,6 +22,7 @@ HEADERS += \ features_layer_matcher.hpp \ features_layer_path_finder.hpp \ geocoder.hpp \ + geocoder_context.hpp \ geometry_cache.hpp \ geometry_utils.hpp \ house_detector.hpp \ @@ -60,6 +61,7 @@ HEADERS += \ search_trie.hpp \ stats_cache.hpp \ street_vicinity_loader.hpp \ + streets_matcher.hpp \ string_intersection.hpp \ suggest.hpp \ token_slice.hpp \ @@ -68,7 +70,7 @@ HEADERS += \ SOURCES += \ approximate_string_match.cpp \ - cbv_ptr.cpp \ + cbv.cpp \ dummy_rank_table.cpp \ engine.cpp \ features_filter.cpp \ @@ -76,6 +78,7 @@ SOURCES += \ features_layer_matcher.cpp \ features_layer_path_finder.cpp \ geocoder.cpp \ + geocoder_context.cpp \ geometry_cache.cpp \ geometry_utils.cpp \ house_detector.cpp \ @@ -109,5 +112,6 @@ SOURCES += \ retrieval.cpp \ reverse_geocoder.cpp \ street_vicinity_loader.cpp \ + streets_matcher.cpp \ token_slice.cpp \ types_skipper.cpp \ diff --git a/search/streets_matcher.cpp b/search/streets_matcher.cpp new file mode 100644 index 0000000000..dfd524f327 --- /dev/null +++ b/search/streets_matcher.cpp @@ -0,0 +1,168 @@ +#include "search/streets_matcher.hpp" +#include "search/features_filter.hpp" +#include "search/house_numbers_matcher.hpp" +#include "search/query_params.hpp" + +#include "indexer/search_string_utils.hpp" + +#include "base/logging.hpp" +#include "base/stl_helpers.hpp" + +namespace search +{ +namespace +{ +bool LessByHash(StreetsMatcher::Prediction const & lhs, StreetsMatcher::Prediction const & rhs) +{ + if (lhs.m_hash != rhs.m_hash) + return lhs.m_hash < rhs.m_hash; + + if (lhs.m_prob != rhs.m_prob) + return lhs.m_prob > rhs.m_prob; + + if (lhs.GetNumTokens() != rhs.GetNumTokens()) + return lhs.GetNumTokens() > rhs.GetNumTokens(); + + return lhs.m_startToken < rhs.m_startToken; +} +} // namespace + +// static +void StreetsMatcher::Go(BaseContext const & ctx, FeaturesFilter const & filter, + QueryParams const & params, vector & predictions) +{ + size_t const kMaxNumOfImprobablePredictions = 3; + double const kTailProbability = 0.05; + + predictions.clear(); + FindStreets(ctx, filter, params, predictions); + + if (predictions.empty()) + return; + + sort(predictions.begin(), predictions.end(), &LessByHash); + predictions.erase( + unique(predictions.begin(), predictions.end(), my::EqualsBy(&Prediction::m_hash)), + predictions.end()); + + sort(predictions.rbegin(), predictions.rend(), my::LessBy(&Prediction::m_prob)); + while (predictions.size() > kMaxNumOfImprobablePredictions && + predictions.back().m_prob < kTailProbability) + { + predictions.pop_back(); + } +} + +// static +void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const & filter, + QueryParams const & params, vector & predictions) +{ + for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken) + { + if (ctx.m_usedTokens[startToken]) + continue; + + // Here we try to match as many tokens as possible while + // intersection is a non-empty bit vector of streets. Single + // tokens that are synonyms to streets are ignored. Moreover, + // each time a token that looks like a beginning of a house number + // is met, we try to use current intersection of tokens as a + // street layer and try to match BUILDINGs or POIs. + CBV streets(ctx.m_streets); + + CBV all; + all.SetFull(); + + size_t curToken = startToken; + + // This variable is used for prevention of duplicate calls to + // CreateStreetsLayerAndMatchLowerLayers() with the same + // arguments. + size_t lastToken = startToken; + + // When true, no bit vectors were intersected with |streets| at all. + bool emptyIntersection = true; + + // When true, |streets| is in the incomplete state and can't be + // used for creation of street layers. + bool incomplete = false; + + auto emit = [&]() + { + if (!streets.IsEmpty() && !emptyIntersection && !incomplete && lastToken != curToken) + { + CBV fs(streets); + CBV fa(all); + + ASSERT(!fs.IsFull(), ()); + ASSERT(!fa.IsFull(), ()); + + if (filter.NeedToFilter(fs)) + fs = filter.Filter(fs); + + if (fs.IsEmpty()) + return; + + if (filter.NeedToFilter(fa)) + fa = filter.Filter(fa).Union(fs); + + predictions.emplace_back(); + auto & prediction = predictions.back(); + + prediction.m_startToken = startToken; + prediction.m_endToken = curToken; + + ASSERT_NOT_EQUAL(fs.PopCount(), 0, ()); + ASSERT_LESS_OR_EQUAL(fs.PopCount(), fa.PopCount(), ()); + prediction.m_prob = static_cast(fs.PopCount()) / static_cast(fa.PopCount()); + + prediction.m_features = move(fs); + prediction.m_hash = prediction.m_features.Hash(); + } + }; + + StreetTokensFilter filter([&](strings::UniString const & /* token */, size_t tag) + { + auto buffer = streets.Intersect(ctx.m_features[tag]); + if (tag < curToken) + { + // This is the case for delayed + // street synonym. Therefore, + // |streets| is temporarily in the + // incomplete state. + streets = buffer; + all = all.Intersect(ctx.m_features[tag]); + emptyIntersection = false; + + incomplete = true; + return; + } + ASSERT_EQUAL(tag, curToken, ()); + + // |streets| will become empty after + // the intersection. Therefore we need + // to create streets layer right now. + if (buffer.IsEmpty()) + emit(); + + streets = buffer; + all = all.Intersect(ctx.m_features[tag]); + emptyIntersection = false; + incomplete = false; + }); + + for (; curToken < ctx.m_numTokens && !ctx.m_usedTokens[curToken] && !streets.IsEmpty(); + ++curToken) + { + auto const & token = params.GetTokens(curToken).front(); + bool const isPrefix = curToken >= params.m_tokens.size(); + + if (house_numbers::LooksLikeHouseNumber(token, isPrefix)) + emit(); + + filter.Put(token, isPrefix, curToken); + } + emit(); + } +} +} // namespace search diff --git a/search/streets_matcher.hpp b/search/streets_matcher.hpp new file mode 100644 index 0000000000..7918380774 --- /dev/null +++ b/search/streets_matcher.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "search/cbv.hpp" +#include "search/geocoder_context.hpp" + +#include "std/vector.hpp" + +namespace search +{ +class FeaturesFilter; +struct QueryParams; + +class StreetsMatcher +{ +public: + struct Prediction + { + inline size_t GetNumTokens() const { return m_endToken - m_startToken; } + + CBV m_features; + + size_t m_startToken = 0; + size_t m_endToken = 0; + + double m_prob = 0.0; + + uint64_t m_hash = 0; + }; + + static void Go(BaseContext const & ctx, FeaturesFilter const & filter, QueryParams const & params, + vector & predictions); + +private: + static void FindStreets(BaseContext const & ctx, FeaturesFilter const & filter, + QueryParams const & params, vector & prediction); +}; +} // namespace search