[search] Switch from std::{vector|set} to CBV in retrieval.

This commit is contained in:
Yuri Gorshenin 2015-10-22 12:20:01 +03:00 committed by Sergey Yershov
parent 07185d0cc5
commit 1bef3f5d89
6 changed files with 191 additions and 102 deletions

View file

@ -42,7 +42,8 @@ class RankTable
public:
enum Version
{
V0 = 0
V0 = 0,
VERSION_COUNT
};
virtual ~RankTable() = default;

View file

@ -9,6 +9,7 @@
#include "indexer/scales.hpp"
#include "indexer/search_trie.hpp"
#include "coding/compressed_bit_vector.hpp"
#include "coding/reader_wrapper.hpp"
#include "base/assert.hpp"
@ -42,6 +43,19 @@ struct CancelException : public exception
{
};
// Sorts |v| in ascending order (using operator<) and drops duplicate
// elements in place, so that afterwards |v| holds each distinct
// value exactly once.
template <typename T>
void SortUnique(std::vector<T> & v)
{
  auto const first = v.begin();
  auto const last = v.end();

  sort(first, last);
  // unique() only compacts the range; the leftover tail still has to
  // be erased from the vector.
  v.erase(unique(first, last), last);
}
// Consumes |features| (a list of feature ids, possibly unordered and
// with repetitions), sorts and deduplicates it, and builds a
// compressed bit vector from the resulting bit positions.
unique_ptr<coding::CompressedBitVector> SortFeaturesAndBuildCBV(vector<uint64_t> && features)
{
  using TBuilder = coding::CompressedBitVectorBuilder;

  // NOTE(review): presumably FromBitPositions() expects sorted,
  // duplicate-free positions - confirm against its contract.
  SortUnique(features);
  return TBuilder::FromBitPositions(move(features));
}
void CoverRect(m2::RectD const & rect, int scale, covering::IntervalsT & result)
{
covering::CoveringGetter covering(rect, covering::ViewportWithLowLevels);
@ -51,12 +65,10 @@ void CoverRect(m2::RectD const & rect, int scale, covering::IntervalsT & result)
// Retrieves from the search index corresponding to |handle| all
// features matching |params|.
template <typename ToDo>
void RetrieveAddressFeatures(MwmSet::MwmHandle const & handle, SearchQueryParams const & params,
ToDo && toDo)
unique_ptr<coding::CompressedBitVector> RetrieveAddressFeatures(MwmSet::MwmHandle const & handle,
my::Cancellable const & cancellable,
SearchQueryParams const & params)
{
auto emptyFilter = [](uint32_t /* featureId */) { return true; };
auto * value = handle.GetValue<MwmValue>();
ASSERT(value, ());
serial::CodingParams codingParams(trie::GetCodingParams(value->GetHeader().GetDefCodingParams()));
@ -64,22 +76,48 @@ void RetrieveAddressFeatures(MwmSet::MwmHandle const & handle, SearchQueryParams
auto const trieRoot = trie::ReadTrie(SubReaderWrapper<Reader>(searchReader.GetPtr()),
trie::ValueReader(codingParams));
MatchFeaturesInTrie(params, *trieRoot, emptyFilter, forward<ToDo>(toDo));
auto emptyFilter = [](uint32_t /* featureId */)
{
return true;
};
// TODO (@y, @m): remove this code as soon as search index will have native support for bit
// vectors.
vector<uint64_t> features;
auto collector = [&](trie::ValueReader::ValueType const & value)
{
if (cancellable.IsCancelled())
throw CancelException();
features.push_back(value.m_featureId);
};
MatchFeaturesInTrie(params, *trieRoot, emptyFilter, collector);
return SortFeaturesAndBuildCBV(move(features));
}
// Retrieves from the geometry index corresponding to |handle| all
// features in (and, possibly, around) viewport and executes |toDo| on
// them.
template <typename ToDo>
void RetrieveGeometryFeatures(MwmSet::MwmHandle const & handle,
covering::IntervalsT const & covering, int scale, ToDo && toDo)
// features in (and, possibly, around) viewport.
unique_ptr<coding::CompressedBitVector> RetrieveGeometryFeatures(
MwmSet::MwmHandle const & handle, my::Cancellable const & cancellable,
covering::IntervalsT const & covering, int scale)
{
auto * value = handle.GetValue<MwmValue>();
ASSERT(value, ());
// TODO (@y, @m): remove this code as soon as search index will have native support for bit
// vectors.
vector<uint64_t> features;
auto collector = [&](uint64_t featureId)
{
if (cancellable.IsCancelled())
throw CancelException();
features.push_back(featureId);
};
ScaleIndex<ModelReaderPtr> index(value->m_cont.GetReader(INDEX_FILE_TAG), value->m_factory);
for (auto const & interval : covering)
index.ForEachInIntervalAndScale(toDo, interval.first, interval.second, scale);
index.ForEachInIntervalAndScale(collector, interval.first, interval.second, scale);
return SortFeaturesAndBuildCBV(move(features));
}
// This class represents a fast retrieval strategy. When number of
@ -89,23 +127,31 @@ class FastPathStrategy : public Retrieval::Strategy
{
public:
FastPathStrategy(Index const & index, MwmSet::MwmHandle & handle, m2::RectD const & viewport,
vector<uint32_t> const & addressFeatures)
unique_ptr<coding::CompressedBitVector> && addressFeatures)
: Strategy(handle, viewport), m_lastReported(0)
{
ASSERT(addressFeatures.get(),
("Strategy must be initialized with valid address features set."));
m2::PointD const center = m_viewport.Center();
Index::FeaturesLoaderGuard loader(index, m_handle.GetId());
for (auto const & featureId : addressFeatures)
{
FeatureType feature;
loader.GetFeatureByIndex(featureId, feature);
m_features.emplace_back(featureId, feature::GetCenter(feature, FeatureType::WORST_GEOMETRY));
}
coding::CompressedBitVectorEnumerator::ForEach(
*addressFeatures, [&](uint64_t featureId)
{
ASSERT_LESS_OR_EQUAL(featureId, numeric_limits<uint32_t>::max(), ());
FeatureType feature;
loader.GetFeatureByIndex(featureId, feature);
m_features.emplace_back(featureId,
feature::GetCenter(feature, FeatureType::WORST_GEOMETRY));
});
// Order features by distance from the center of |viewport|.
sort(m_features.begin(), m_features.end(),
[&center](pair<uint32_t, m2::PointD> const & lhs, pair<uint32_t, m2::PointD> const & rhs)
{
return lhs.second.SquareLength(center) < rhs.second.SquareLength(center);
});
{
return lhs.second.SquareLength(center) < rhs.second.SquareLength(center);
});
}
// Retrieval::Strategy overrides:
@ -115,7 +161,7 @@ public:
m2::RectD viewport = m_viewport;
viewport.Scale(scale);
vector<uint32_t> features;
vector<uint64_t> features;
ASSERT_LESS_OR_EQUAL(m_lastReported, m_features.size(), ());
while (m_lastReported < m_features.size() &&
@ -125,7 +171,8 @@ public:
++m_lastReported;
}
callback(features);
auto cbv = SortFeaturesAndBuildCBV(move(features));
callback(*cbv);
return true;
}
@ -144,13 +191,19 @@ class SlowPathStrategy : public Retrieval::Strategy
{
public:
SlowPathStrategy(MwmSet::MwmHandle & handle, m2::RectD const & viewport,
SearchQueryParams const & params, vector<uint32_t> const & addressFeatures)
: Strategy(handle, viewport), m_params(params), m_coverageScale(0)
SearchQueryParams const & params,
unique_ptr<coding::CompressedBitVector> && addressFeatures)
: Strategy(handle, viewport)
, m_params(params)
, m_nonReported(move(addressFeatures))
, m_coverageScale(0)
{
if (addressFeatures.empty())
return;
ASSERT(m_nonReported.get(), ("Strategy must be initialized with valid address features set."));
m_nonReported.insert(addressFeatures.begin(), addressFeatures.end());
// No need to initialize slow path strategy when there're no
// features at all.
if (m_nonReported->PopCount() == 0)
return;
auto * value = m_handle.GetValue<MwmValue>();
ASSERT(value, ());
@ -173,39 +226,27 @@ public:
// Early exit when all features from this mwm were already
// reported.
if (m_nonReported.empty())
if (!m_nonReported || m_nonReported->PopCount() == 0)
return true;
vector<uint32_t> geometryFeatures;
unique_ptr<coding::CompressedBitVector> geometryFeatures;
// Early exit when whole mwm is inside scaled viewport.
if (currViewport.IsRectInside(m_bounds))
{
geometryFeatures.assign(m_nonReported.begin(), m_nonReported.end());
m_nonReported.clear();
callback(geometryFeatures);
geometryFeatures.swap(m_nonReported);
callback(*geometryFeatures);
return true;
}
try
{
auto collector = [&](uint32_t feature)
{
if (cancellable.IsCancelled())
throw CancelException();
if (m_nonReported.count(feature) != 0)
{
geometryFeatures.push_back(feature);
m_nonReported.erase(feature);
}
};
if (m_prevScale < 0)
{
covering::IntervalsT coverage;
CoverRect(currViewport, m_coverageScale, coverage);
RetrieveGeometryFeatures(m_handle, coverage, m_coverageScale, collector);
geometryFeatures =
RetrieveGeometryFeatures(m_handle, cancellable, coverage, m_coverageScale);
for (auto const & interval : coverage)
m_visited.Add(interval);
}
@ -225,21 +266,21 @@ public:
CoverRect(c, m_coverageScale, coverage);
CoverRect(d, m_coverageScale, coverage);
sort(coverage.begin(), coverage.end());
coverage.erase(unique(coverage.begin(), coverage.end()), coverage.end());
SortUnique(coverage);
coverage = covering::SortAndMergeIntervals(coverage);
coverage.erase(
remove_if(coverage.begin(), coverage.end(), [this](covering::IntervalT const & interval)
{
return m_visited.Elems().count(interval) != 0;
}),
coverage.end());
coverage.erase(remove_if(coverage.begin(), coverage.end(),
[this](covering::IntervalT const & interval)
{
return m_visited.Elems().count(interval) != 0;
}),
coverage.end());
covering::IntervalsT reducedCoverage;
for (auto const & interval : coverage)
m_visited.SubtractFrom(interval, reducedCoverage);
RetrieveGeometryFeatures(m_handle, reducedCoverage, m_coverageScale, collector);
geometryFeatures =
RetrieveGeometryFeatures(m_handle, cancellable, reducedCoverage, m_coverageScale);
for (auto const & interval : reducedCoverage)
m_visited.Add(interval);
@ -250,14 +291,16 @@ public:
return false;
}
callback(geometryFeatures);
auto toReport = coding::CompressedBitVector::Intersect(*m_nonReported, *geometryFeatures);
m_nonReported = coding::CompressedBitVector::Subtract(*m_nonReported, *toReport);
callback(*toReport);
return true;
}
private:
SearchQueryParams const & m_params;
set<uint32_t> m_nonReported;
unique_ptr<coding::CompressedBitVector> m_nonReported;
// This set is used to accumulate all read intervals from mwm and
// prevent further reads from the same offsets.
@ -433,7 +476,7 @@ bool Retrieval::RetrieveForScale(Bucket & bucket, double scale, Callback & callb
return true;
}
auto wrapper = [&](vector<uint32_t> & features)
auto wrapper = [&](coding::CompressedBitVector const & features)
{
ReportFeatures(bucket, features, scale, callback);
};
@ -460,34 +503,29 @@ bool Retrieval::InitBucketStrategy(Bucket & bucket, double scale)
ASSERT(!bucket.m_strategy, ());
ASSERT_EQUAL(0, bucket.m_numAddressFeatures, ());
vector<uint32_t> addressFeatures;
unique_ptr<coding::CompressedBitVector> addressFeatures;
try
{
auto collector = [&](trie::ValueReader::ValueType const & value)
{
if (IsCancelled())
throw CancelException();
addressFeatures.push_back(value.m_featureId);
};
RetrieveAddressFeatures(bucket.m_handle, m_params, collector);
addressFeatures = RetrieveAddressFeatures(bucket.m_handle, *this /* cancellable */, m_params);
}
catch (CancelException &)
{
return false;
}
bucket.m_numAddressFeatures = addressFeatures.size();
ASSERT(addressFeatures.get(), ("Can't retrieve address features."));
bucket.m_numAddressFeatures = addressFeatures->PopCount();
if (bucket.m_numAddressFeatures < kFastPathThreshold)
{
bucket.m_strategy.reset(
new FastPathStrategy(*m_index, bucket.m_handle, m_viewport, addressFeatures));
new FastPathStrategy(*m_index, bucket.m_handle, m_viewport, move(addressFeatures)));
}
else
{
bucket.m_strategy.reset(
new SlowPathStrategy(bucket.m_handle, m_viewport, m_params, addressFeatures));
new SlowPathStrategy(bucket.m_handle, m_viewport, m_params, move(addressFeatures)));
}
return true;
@ -516,22 +554,17 @@ bool Retrieval::Finished() const
return true;
}
void Retrieval::ReportFeatures(Bucket & bucket, vector<uint32_t> & featureIds, double scale,
Callback & callback)
void Retrieval::ReportFeatures(Bucket & bucket, coding::CompressedBitVector const & features,
double scale, Callback & callback)
{
if (m_limits.IsMaxNumFeaturesSet())
{
if (m_featuresReported >= m_limits.GetMaxNumFeatures())
return;
uint64_t rest = m_limits.GetMaxNumFeatures() - m_featuresReported;
if (rest < featureIds.size())
featureIds.resize(static_cast<size_t>(rest));
}
if (featureIds.empty())
if (m_limits.IsMaxNumFeaturesSet() && m_featuresReported >= m_limits.GetMaxNumFeatures())
return;
callback.OnFeaturesRetrieved(bucket.m_handle.GetId(), scale, featureIds);
bucket.m_featuresReported += featureIds.size();
m_featuresReported += featureIds.size();
if (features.PopCount() == 0)
return;
callback.OnFeaturesRetrieved(bucket.m_handle.GetId(), scale, features);
bucket.m_featuresReported += features.PopCount();
m_featuresReported += features.PopCount();
}
} // namespace search

View file

@ -15,6 +15,11 @@
class Index;
namespace coding
{
class CompressedBitVector;
}
namespace search
{
class Retrieval : public my::Cancellable
@ -29,7 +34,7 @@ public:
// This method may be called several times for the same mwm,
// reporting disjoint sets of features.
virtual void OnFeaturesRetrieved(MwmSet::MwmId const & id, double scale,
vector<uint32_t> const & featureIds) = 0;
coding::CompressedBitVector const & features) = 0;
// Called when all matching features for an mwm were retrieved and
// reported. Clients may assume that this method is called no more
@ -73,7 +78,7 @@ public:
class Strategy
{
public:
using TCallback = function<void(vector<uint32_t> &)>;
using TCallback = function<void(coding::CompressedBitVector const &)>;
Strategy(MwmSet::MwmHandle & handle, m2::RectD const & viewport);
@ -130,7 +135,7 @@ private:
unique_ptr<Strategy> m_strategy;
size_t m_featuresReported;
size_t m_numAddressFeatures;
uint32_t m_numAddressFeatures;
bool m_intersectsWithViewport : 1;
bool m_finished : 1;
};
@ -152,7 +157,7 @@ private:
bool Finished() const;
// Reports features, updates bucket's stats.
void ReportFeatures(Bucket & bucket, vector<uint32_t> & featureIds, double scale,
void ReportFeatures(Bucket & bucket, coding::CompressedBitVector const & features, double scale,
Callback & callback);
Index * m_index;

View file

@ -21,11 +21,14 @@
#include "platform/local_country_file_utils.hpp"
#include "platform/platform.hpp"
#include "coding/compressed_bit_vector.hpp"
#include "base/scope_guard.hpp"
#include "base/string_utils.hpp"
#include "std/algorithm.hpp"
#include "std/initializer_list.hpp"
#include "std/limits.hpp"
#include "std/sstream.hpp"
#include "std/shared_ptr.hpp"
@ -189,11 +192,16 @@ public:
// search::Retrieval::Callback overrides:
void OnFeaturesRetrieved(MwmSet::MwmId const & id, double scale,
vector<uint32_t> const & offsets) override
coding::CompressedBitVector const & features) override
{
TEST_EQUAL(m_id, id, ());
m_triggered = true;
m_offsets.insert(m_offsets.end(), offsets.begin(), offsets.end());
coding::CompressedBitVectorEnumerator::ForEach(
features, [&](uint64_t featureId)
{
CHECK_LESS(featureId, numeric_limits<uint32_t>::max(), ());
m_offsets.push_back(static_cast<uint32_t>(featureId));
});
}
void OnMwmProcessed(MwmSet::MwmId const & /* id */) override {}
@ -216,13 +224,13 @@ public:
// search::Retrieval::Callback overrides:
void OnFeaturesRetrieved(MwmSet::MwmId const & id, double /* scale */,
vector<uint32_t> const & offsets) override
coding::CompressedBitVector const & features) override
{
auto const it = find(m_ids.cbegin(), m_ids.cend(), id);
TEST(it != m_ids.cend(), ("Unknown mwm:", id));
m_retrieved.insert(id);
m_numFeatures += offsets.size();
m_numFeatures += features.PopCount();
}
void OnMwmProcessed(MwmSet::MwmId const & /* id */) override {}

View file

@ -27,10 +27,12 @@
#include "platform/preferred_languages.hpp"
#include "coding/compressed_bit_vector.hpp"
#include "coding/multilang_utf8_string.hpp"
#include "coding/reader_wrapper.hpp"
#include "base/logging.hpp"
#include "base/macros.hpp"
#include "base/scope_guard.hpp"
#include "base/stl_add.hpp"
#include "base/string_utils.hpp"
@ -38,6 +40,7 @@
#include "std/algorithm.hpp"
#include "std/function.hpp"
#include "std/iterator.hpp"
#include "std/limits.hpp"
#define LONG_OP(op) \
{ \
@ -130,6 +133,33 @@ public:
impl::PreResult2 const & operator*() const { return *m_val; }
};
// This dummy rank table is used instead of a normal rank table when
// the latter can't be loaded. It should not be serialized and can't
// be loaded.
//
// Get() reports rank 0 for every feature. The remaining RankTable
// methods are not meant to be called on this class: each of them
// invokes NOTIMPLEMENTED() and, where a value has to be returned,
// returns a placeholder.
class DummyRankTable : public RankTable
{
public:
  // RankTable overrides:

  // Returns zero rank regardless of the feature index. The parameter
  // is intentionally unnamed to avoid unused-parameter warnings.
  uint8_t Get(uint64_t /* i */) const override { return 0; }

  // Not supported: the dummy table has no real size. The returned
  // value is a placeholder.
  uint64_t Size() const override
  {
    NOTIMPLEMENTED();
    return numeric_limits<uint64_t>::max();
  }

  // Not supported: the dummy table has no on-disk format, so the
  // VERSION_COUNT sentinel is returned instead of a real version.
  Version GetVersion() const override
  {
    NOTIMPLEMENTED();
    return RankTable::VERSION_COUNT;
  }

  // Not supported: the dummy table must never be written out.
  void Serialize(Writer & /* writer */, bool /* preserveHostEndianness */) override
  {
    NOTIMPLEMENTED();
  }
};
string DebugPrint(IndexedValue const & value)
{
string index;
@ -164,20 +194,24 @@ Query::RetrievalCallback::RetrievalCallback(Index & index, Query & query, Viewpo
}
void Query::RetrievalCallback::OnFeaturesRetrieved(MwmSet::MwmId const & id, double scale,
vector<uint32_t> const & featureIds)
coding::CompressedBitVector const & features)
{
auto const * table = LoadTable(id);
static DummyRankTable dummyTable;
auto const * table = LoadTable(id);
if (!table)
{
LOG(LWARNING, ("Can't get rank table for:", id));
for (auto const & featureId : featureIds)
m_query.AddPreResult1(id, featureId, 0 /* rank */, scale, m_viewportId);
return;
table = &dummyTable;
}
for (auto const & featureId : featureIds)
m_query.AddPreResult1(id, featureId, table->Get(featureId), scale, m_viewportId);
coding::CompressedBitVectorEnumerator::ForEach(
features, [&](uint64_t featureId)
{
ASSERT_LESS_OR_EQUAL(featureId, numeric_limits<uint32_t>::max(), ());
m_query.AddPreResult1(id, static_cast<uint32_t>(featureId), table->Get(featureId), scale,
m_viewportId);
});
}
void Query::RetrievalCallback::OnMwmProcessed(MwmSet::MwmId const & id) { UnloadTable(id); }

View file

@ -37,7 +37,15 @@
class FeatureType;
class CategoriesHolder;
namespace storage { class CountryInfoGetter; }
namespace coding
{
class CompressedBitVector;
}
namespace storage
{
class CountryInfoGetter;
}
namespace search
{
@ -124,7 +132,7 @@ private:
// Retrieval::Callback overrides:
void OnFeaturesRetrieved(MwmSet::MwmId const & id, double scale,
vector<uint32_t> const & featureIds) override;
coding::CompressedBitVector const & features) override;
void OnMwmProcessed(MwmSet::MwmId const & id) override;