[search] Optimized viewport retrieval.

This commit is contained in:
Yuri Gorshenin 2015-10-01 19:29:08 +03:00 committed by Sergey Yershov
parent d39493a758
commit 583882d23d
8 changed files with 401 additions and 94 deletions

View file

@ -42,6 +42,7 @@ HEADERS += \
const_helper.hpp \
exception.hpp \
internal/message.hpp \
interval_set.hpp \
limited_priority_queue.hpp \
logging.hpp \
macros.hpp \

View file

@ -19,6 +19,7 @@ SOURCES += \
condition_test.cpp \
const_helper.cpp \
containers_test.cpp \
interval_set_test.cpp \
logging_test.cpp \
math_test.cpp \
matrix_test.cpp \

View file

@ -0,0 +1,117 @@
#include "testing/testing.hpp"
#include "base/interval_set.hpp"
#include "std/initializer_list.hpp"
using namespace my;
namespace
{
// Shorthand for the interval type stored by IntervalSet<TElem>.
template <typename TElem>
using TInterval = typename IntervalSet<TElem>::TInterval;

// Verifies that |actual| holds exactly the listed |intervals|
// (order of the list does not matter, they are compared as sets).
template <typename TElem>
void CheckSet(IntervalSet<TElem> const & actual, initializer_list<TInterval<TElem>> intervals)
{
  set<TInterval<TElem>> const expectedElems(intervals.begin(), intervals.end());
  TEST_EQUAL(actual.Elems(), expectedElems, ());
}
}  // namespace
UNIT_TEST(IntervalSet_Add)
{
  using TInt = TInterval<int>;

  IntervalSet<int> intervals;
  TEST(intervals.Elems().empty(), ());

  // First interval is stored as is.
  intervals.Add(TInt(0, 2));
  CheckSet(intervals, {TInt(0, 2)});

  // Overlapping interval extends the existing one to the right.
  intervals.Add(TInt(1, 3));
  CheckSet(intervals, {TInt(0, 3)});

  // Interval touching the left border is merged too.
  intervals.Add(TInt(-2, 0));
  CheckSet(intervals, {TInt(-2, 3)});

  // Disjoint intervals on both sides are kept separately.
  intervals.Add(TInt(-4, -3));
  CheckSet(intervals, {TInt(-4, -3), TInt(-2, 3)});

  intervals.Add(TInt(7, 10));
  CheckSet(intervals, {TInt(-4, -3), TInt(-2, 3), TInt(7, 10)});

  // Filling the gap glues two neighbours together.
  intervals.Add(TInt(-3, -2));
  CheckSet(intervals, {TInt(-4, 3), TInt(7, 10)});

  // Interval bridging two stored intervals collapses them into one.
  intervals.Add(TInt(2, 8));
  CheckSet(intervals, {TInt(-4, 10)});

  // Adding an already covered interval changes nothing.
  intervals.Add(TInt(2, 3));
  CheckSet(intervals, {TInt(-4, 10)});
}
UNIT_TEST(IntervalSet_SubtractFrom)
{
  using TInt = TInterval<int>;

  IntervalSet<int> intervals;
  TEST(intervals.Elems().empty(), ());

  intervals.Add(TInt(0, 2));
  intervals.Add(TInt(4, 7));
  intervals.Add(TInt(10, 11));
  CheckSet(intervals, {TInt(0, 2), TInt(4, 7), TInt(10, 11)});

  // Subtracts the set from |minuend| into a fresh vector and compares
  // the result with |expected|.
  auto const checkDifference = [&intervals](TInt const & minuend, vector<TInt> const & expected)
  {
    vector<TInt> difference;
    intervals.SubtractFrom(minuend, difference);
    TEST_EQUAL(difference, expected, ());
  };

  // Both ends of the minuend lie inside stored intervals.
  checkDifference(TInt(1, 5), vector<TInt>{TInt(2, 4)});

  // Minuend does not touch the set at all.
  checkDifference(TInt(-10, -5), vector<TInt>{TInt(-10, -5)});

  // Minuend spans the whole set.
  checkDifference(TInt(0, 11), vector<TInt>{TInt(2, 4), TInt(7, 10)});

  // Minuend sticks out on the left.
  checkDifference(TInt(-1, 11), vector<TInt>{TInt(-1, 0), TInt(2, 4), TInt(7, 10)});

  // Repeated query gives the same answer, the set is not modified.
  checkDifference(TInt(1, 5), vector<TInt>{TInt(2, 4)});

  // Minuend is fully covered by a stored interval.
  checkDifference(TInt(5, 7), vector<TInt>{});
  checkDifference(TInt(4, 7), vector<TInt>{});

  // Only the left part of the minuend is uncovered.
  checkDifference(TInt(3, 7), vector<TInt>{TInt(3, 4)});
}

141
base/interval_set.hpp Normal file
View file

@ -0,0 +1,141 @@
#pragma once
#include "std/algorithm.hpp"
#include "std/set.hpp"
#include "std/utility.hpp"
#include "std/vector.hpp"
namespace my
{
// This class represents a set of disjoint half-open intervals
// [first, second). Intervals that overlap or touch each other are
// merged on insertion, so stored intervals are pairwise disjoint and
// non-adjacent.
template <typename TElem>
class IntervalSet
{
public:
  using TInterval = std::pair<TElem, TElem>;

  // Adds an |interval| to the set. Empty (first == second) and
  // inverted (first > second) intervals are ignored.
  //
  // Complexity: O(num of intervals intersecting with |interval| +
  // log(total number of intervals)).
  void Add(TInterval const & interval);

  // Subtracts set from an |interval| and appends result to
  // |difference|. Resulting pieces are appended in increasing
  // order. |difference| is not cleared before appending.
  //
  // Complexity: O(num of intervals intersecting with |interval| +
  // log(total number of intervals)).
  void SubtractFrom(TInterval const & interval, std::vector<TInterval> & difference) const;

  // Returns all elements of a set as a set of intervals.
  //
  // Complexity: O(1).
  inline std::set<TInterval> const & Elems() const { return m_intervals; }

private:
  // const_iterator is used on purpose: Cover() is a const method and
  // the standard does not guarantee that set::iterator and
  // set::const_iterator are the same type.
  using TIterator = typename std::set<TInterval>::const_iterator;

  // Calculates range [begin, end) of stored intervals that have
  // non-empty intersection with a given |interval|.
  void Cover(TInterval const & interval, TIterator & begin, TIterator & end) const;

  // This is a set of disjoint intervals.
  std::set<TInterval> m_intervals;
};

template <typename TElem>
void IntervalSet<TElem>::Add(TInterval const & interval)
{
  // Skips empty and inverted intervals. An inverted interval must
  // never reach the code below: Cover() may return |begin| located
  // after |end| for it, and the interval itself would break the
  // invariant of the set if inserted.
  if (!(interval.first < interval.second))
    return;

  TIterator begin;
  TIterator end;
  Cover(interval, begin, end);

  TElem from = interval.first;
  TElem to = interval.second;

  // Update |from| and |to| in accordance with corner intervals (if any).
  if (begin != end)
  {
    if (begin->first < from)
      from = begin->first;

    auto last = end;
    --last;
    if (last->second > to)
      to = last->second;
  }

  // Now all elements [from, to) can be added to the set as a single
  // interval which replaces all intervals in [begin, end). But note
  // that it can be possible to merge the new interval with its
  // neighbors, so we need to check it.
  if (begin != m_intervals.begin())
  {
    auto prevBegin = begin;
    --prevBegin;
    if (prevBegin->second == from)
    {
      begin = prevBegin;
      from = prevBegin->first;
    }
  }
  if (end != m_intervals.end() && end->first == to)
  {
    to = end->second;
    ++end;
  }

  m_intervals.erase(begin, end);
  m_intervals.emplace(from, to);
}

template <typename TElem>
void IntervalSet<TElem>::SubtractFrom(TInterval const & interval,
                                      std::vector<TInterval> & difference) const
{
  TIterator begin;
  TIterator end;
  Cover(interval, begin, end);

  // |from| is the left border of the part of |interval| that has not
  // been examined yet.
  TElem from = interval.first;
  TElem const to = interval.second;

  for (auto it = begin; it != end && from < to; ++it)
  {
    if (it->first > from)
    {
      // There is a gap between |from| and the current stored interval.
      difference.emplace_back(from, std::min(it->first, to));
      from = it->second;
    }
    else
    {
      from = std::max(from, it->second);
    }
  }

  // Tail of |interval| to the right of all intersecting intervals.
  if (from < to)
    difference.emplace_back(from, to);
}

template <typename TElem>
void IntervalSet<TElem>::Cover(TInterval const & interval, TIterator & begin,
                               TIterator & end) const
{
  TElem const & from = interval.first;
  TElem const & to = interval.second;

  // First stored interval that starts at or after |from|.
  begin = m_intervals.lower_bound(TInterval(from, from));

  // The previous interval starts before |from| but may still reach
  // into |interval|.
  if (begin != m_intervals.begin())
  {
    auto prev = begin;
    --prev;
    if (prev->second > from)
      begin = prev;
  }

  // First stored interval that starts at or after |to|, i.e. the
  // first one that can not intersect with [from, to).
  end = m_intervals.lower_bound(TInterval(to, to));
}
}  // namespace my

View file

@ -11,7 +11,8 @@ class FeatureType;
namespace covering
{
typedef vector<pair<int64_t, int64_t> > IntervalsT;
typedef pair<int64_t, int64_t> IntervalT;
typedef vector<IntervalT> IntervalsT;
// Cover feature with RectIds and return their integer representations.
vector<int64_t> CoverFeature(FeatureType const & feature,

View file

@ -10,7 +10,7 @@
#include "coding/reader_wrapper.hpp"
#include "base/assert.hpp"
#include "base/logging.hpp"
#include "base/interval_set.hpp"
#include "std/algorithm.hpp"
#include "std/cmath.hpp"
@ -34,10 +34,12 @@ struct CancelException : public exception
// Otherwise, slow path is used.
uint64_t constexpr kFastPathThreshold = 100;
struct EmptyFilter
void CoverRect(m2::RectD const & rect, int scale, covering::IntervalsT & result)
{
inline bool operator()(uint32_t /* featureId */) const { return true; }
};
covering::CoveringGetter covering(rect, covering::ViewportWithLowLevels);
auto const & intervals = covering.Get(scale);
result.insert(result.end(), intervals.begin(), intervals.end());
}
// Retrieves from the search index corresponding to |handle| all
// features matching to |params|.
@ -45,6 +47,8 @@ template <typename ToDo>
void RetrieveAddressFeatures(MwmSet::MwmHandle const & handle, SearchQueryParams const & params,
ToDo && toDo)
{
auto emptyFilter = [](uint32_t /* featureId */) { return true; };
auto * value = handle.GetValue<MwmValue>();
ASSERT(value, ());
serial::CodingParams codingParams(trie::GetCodingParams(value->GetHeader().GetDefCodingParams()));
@ -52,31 +56,21 @@ void RetrieveAddressFeatures(MwmSet::MwmHandle const & handle, SearchQueryParams
auto const trieRoot = trie::ReadTrie(SubReaderWrapper<Reader>(searchReader.GetPtr()),
trie::ValueReader(codingParams));
MatchFeaturesInTrie(params, *trieRoot, EmptyFilter(), forward<ToDo>(toDo));
MatchFeaturesInTrie(params, *trieRoot, emptyFilter, forward<ToDo>(toDo));
}
// Retrieves from the geometry index corresponding to handle all
// features in (and, possibly, around) viewport and executes |toDo| on
// them.
template <typename ToDo>
void RetrieveGeometryFeatures(MwmSet::MwmHandle const & handle, m2::RectD viewport,
SearchQueryParams const & params, ToDo && toDo)
void RetrieveGeometryFeatures(MwmSet::MwmHandle const & handle,
covering::IntervalsT const & covering, int scale, ToDo && toDo)
{
auto * value = handle.GetValue<MwmValue>();
ASSERT(value, ());
feature::DataHeader const & header = value->GetHeader();
if (!viewport.Intersect(header.GetBounds()))
return;
auto const scaleRange = header.GetScaleRange();
int const scale = min(max(params.m_scale, scaleRange.first), scaleRange.second);
covering::CoveringGetter covering(viewport, covering::ViewportWithLowLevels);
covering::IntervalsT const & intervals = covering.Get(scale);
ScaleIndex<ModelReaderPtr> index(value->m_cont.GetReader(INDEX_FILE_TAG), value->m_factory);
for (auto const & interval : intervals)
for (auto const & interval : covering)
index.ForEachInIntervalAndScale(toDo, interval.first, interval.second, scale);
}
@ -143,14 +137,19 @@ class SlowPathStrategy : public Retrieval::Strategy
public:
SlowPathStrategy(MwmSet::MwmHandle & handle, m2::RectD const & viewport,
SearchQueryParams const & params, vector<uint32_t> const & addressFeatures)
: Strategy(handle, viewport), m_params(params)
: Strategy(handle, viewport), m_params(params), m_coverageScale(0)
{
if (addressFeatures.empty())
return;
m_nonReported.resize(*max_element(addressFeatures.begin(), addressFeatures.end()) + 1);
for (auto const & featureId : addressFeatures)
m_nonReported[featureId] = true;
m_nonReported.insert(addressFeatures.begin(), addressFeatures.end());
auto * value = m_handle.GetValue<MwmValue>();
ASSERT(value, ());
feature::DataHeader const & header = value->GetHeader();
auto const scaleRange = header.GetScaleRange();
m_coverageScale = min(max(m_params.m_scale, scaleRange.first), scaleRange.second);
m_bounds = header.GetBounds();
}
// Retrieval::Strategy overrides:
@ -160,8 +159,26 @@ public:
m2::RectD currViewport = m_viewport;
currViewport.Scale(scale);
// Early exit when scaled viewport does not intersect mwm bounds.
if (!currViewport.Intersect(m_bounds))
return true;
// Early exit when all features from this mwm were already
// reported.
if (m_nonReported.empty())
return true;
vector<uint32_t> geometryFeatures;
// Early exit when whole mwm is inside scaled viewport.
if (currViewport.IsRectInside(m_bounds))
{
geometryFeatures.assign(m_nonReported.begin(), m_nonReported.end());
m_nonReported.clear();
callback(geometryFeatures);
return true;
}
try
{
auto collector = [&](uint32_t feature)
@ -169,16 +186,20 @@ public:
if (cancellable.IsCancelled())
throw CancelException();
if (feature < m_nonReported.size() && m_nonReported[feature])
if (m_nonReported.count(feature) != 0)
{
geometryFeatures.push_back(feature);
m_nonReported[feature] = false;
m_nonReported.erase(feature);
}
};
if (m_prevScale < 0)
{
RetrieveGeometryFeatures(m_handle, currViewport, m_params, collector);
covering::IntervalsT coverage;
CoverRect(currViewport, m_coverageScale, coverage);
RetrieveGeometryFeatures(m_handle, coverage, m_coverageScale, collector);
for (auto const & interval : coverage)
m_visited.Add(interval);
}
else
{
@ -190,10 +211,30 @@ public:
m2::RectD b(a.RightTop(), c.RightTop());
m2::RectD d(a.LeftBottom(), c.LeftBottom());
RetrieveGeometryFeatures(m_handle, a, m_params, collector);
RetrieveGeometryFeatures(m_handle, b, m_params, collector);
RetrieveGeometryFeatures(m_handle, c, m_params, collector);
RetrieveGeometryFeatures(m_handle, d, m_params, collector);
covering::IntervalsT coverage;
CoverRect(a, m_coverageScale, coverage);
CoverRect(b, m_coverageScale, coverage);
CoverRect(c, m_coverageScale, coverage);
CoverRect(d, m_coverageScale, coverage);
sort(coverage.begin(), coverage.end());
coverage.erase(unique(coverage.begin(), coverage.end()), coverage.end());
coverage = covering::SortAndMergeIntervals(coverage);
coverage.erase(
remove_if(coverage.begin(), coverage.end(), [this](covering::IntervalT const & interval)
{
return m_visited.Elems().count(interval) != 0;
}),
coverage.end());
covering::IntervalsT reducedCoverage;
for (auto const & interval : coverage)
m_visited.SubtractFrom(interval, reducedCoverage);
RetrieveGeometryFeatures(m_handle, reducedCoverage, m_coverageScale, collector);
for (auto const & interval : reducedCoverage)
m_visited.Add(interval);
}
}
catch (CancelException &)
@ -208,7 +249,14 @@ public:
private:
SearchQueryParams const & m_params;
vector<bool> m_nonReported;
set<uint32_t> m_nonReported;
// This set is used to accumulate all read intervals from mwm and
// prevent further reads from the same offsets.
my::IntervalSet<int64_t> m_visited;
m2::RectD m_bounds;
int m_coverageScale;
};
} // namespace
@ -320,11 +368,15 @@ void Retrieval::Go(Callback & callback)
if (m_limits.IsMaxViewportScaleSet() && reducedScale >= m_limits.GetMaxViewportScale())
reducedScale = m_limits.GetMaxViewportScale();
if (!RetrieveForScale(reducedScale, callback))
break;
for (auto & bucket : m_buckets)
{
if (!RetrieveForScale(bucket, reducedScale, callback))
break;
}
if (Finished())
break;
if (m_limits.IsMaxViewportScaleSet() && reducedScale >= m_limits.GetMaxViewportScale())
break;
if (m_limits.IsMaxNumFeaturesSet() && m_featuresReported >= m_limits.GetMaxNumFeatures())
@ -334,59 +386,55 @@ void Retrieval::Go(Callback & callback)
}
}
bool Retrieval::RetrieveForScale(double scale, Callback & callback)
bool Retrieval::RetrieveForScale(Bucket & bucket, double scale, Callback & callback)
{
m2::RectD viewport = m_viewport;
viewport.Scale(scale);
for (auto & bucket : m_buckets)
if (IsCancelled())
return false;
if (bucket.m_finished || !viewport.IsIntersect(bucket.m_bounds))
return true;
if (!bucket.m_intersectsWithViewport)
{
if (IsCancelled())
// This is the first time viewport intersects with
// mwm. Initialize bucket's retrieval strategy.
if (!InitBucketStrategy(bucket))
return false;
bucket.m_intersectsWithViewport = true;
if (bucket.m_addressFeatures.empty())
bucket.m_finished = true;
}
if (bucket.m_finished || !viewport.IsIntersect(bucket.m_bounds))
continue;
ASSERT(bucket.m_intersectsWithViewport, ());
ASSERT_LESS_OR_EQUAL(bucket.m_featuresReported, bucket.m_addressFeatures.size(), ());
if (bucket.m_featuresReported == bucket.m_addressFeatures.size())
{
// All features were reported for the bucket, mark it as
// finished and move to the next bucket.
FinishBucket(bucket, callback);
return true;
}
if (!bucket.m_intersectsWithViewport)
{
// This is the first time viewport intersects with
// mwm. Initialize bucket's retrieval strategy.
if (!InitBucketStrategy(bucket))
return false;
bucket.m_intersectsWithViewport = true;
if (bucket.m_addressFeatures.empty())
bucket.m_finished = true;
}
auto wrapper = [&](vector<uint32_t> & features)
{
ReportFeatures(bucket, features, scale, callback);
};
ASSERT(bucket.m_intersectsWithViewport, ());
ASSERT_LESS_OR_EQUAL(bucket.m_featuresReported, bucket.m_addressFeatures.size(), ());
if (bucket.m_featuresReported == bucket.m_addressFeatures.size())
{
// All features were reported for the bucket, mark it as
// finished and move to the next bucket.
FinishBucket(bucket, callback);
continue;
}
if (!bucket.m_strategy->Retrieve(scale, *this /* cancellable */, wrapper))
return false;
auto wrapper = [&](vector<uint32_t> & features)
{
ReportFeatures(bucket, features, scale, callback);
};
if (!bucket.m_strategy->Retrieve(scale, *this /* cancellable */, wrapper))
return false;
if (viewport.IsRectInside(bucket.m_bounds))
{
// Viewport completely covers the bucket, so mark it as finished
// and switch to the next bucket. Note that "viewport covers the
// bucket" is not the same as "all features from the bucket were
// reported", because of scale parameter. Search index reports
// all matching features, but geometry index can skip features
// from more detailed scales.
FinishBucket(bucket, callback);
continue;
}
if (viewport.IsRectInside(bucket.m_bounds))
{
// Viewport completely covers the bucket, so mark it as finished
// and switch to the next bucket. Note that "viewport covers the
// bucket" is not the same as "all features from the bucket were
// reported", because of scale parameter. Search index reports
// all matching features, but geometry index can skip features
// from more detailed scales.
FinishBucket(bucket, callback);
}
return true;
@ -431,6 +479,7 @@ void Retrieval::FinishBucket(Bucket & bucket, Callback & callback)
if (bucket.m_finished)
return;
bucket.m_finished = true;
bucket.m_handle = MwmSet::MwmHandle();
callback.OnMwmProcessed(bucket.m_handle.GetId());
}
@ -447,18 +496,10 @@ bool Retrieval::Finished() const
void Retrieval::ReportFeatures(Bucket & bucket, vector<uint32_t> & featureIds, double scale,
Callback & callback)
{
ASSERT(!m_limits.IsMaxNumFeaturesSet() || m_featuresReported <= m_limits.GetMaxNumFeatures(), ());
if (m_limits.IsMaxNumFeaturesSet())
{
uint64_t const delta = m_limits.GetMaxNumFeatures() - m_featuresReported;
if (featureIds.size() > delta)
featureIds.resize(delta);
}
if (!featureIds.empty())
{
callback.OnFeaturesRetrieved(bucket.m_handle.GetId(), scale, featureIds);
bucket.m_featuresReported += featureIds.size();
m_featuresReported += featureIds.size();
}
if (featureIds.empty())
return;
callback.OnFeaturesRetrieved(bucket.m_handle.GetId(), scale, featureIds);
bucket.m_featuresReported += featureIds.size();
m_featuresReported += featureIds.size();
}
} // namespace search

View file

@ -135,7 +135,7 @@ private:
//
// *NOTE* |scale| of successive calls of this method should be
// non-decreasing.
WARN_UNUSED_RESULT bool RetrieveForScale(double scale, Callback & callback);
WARN_UNUSED_RESULT bool RetrieveForScale(Bucket & bucket, double scale, Callback & callback);
// Inits bucket retrieval strategy. Returns false when cancelled.
WARN_UNUSED_RESULT bool InitBucketStrategy(Bucket & bucket);

View file

@ -274,6 +274,8 @@ UNIT_TEST(Retrieval_Smoke)
}
// Retrieve all whiskey bars from the left-bottom 5 x 5 square.
// Note that due to current coverage algorithm the number of
// retrieved results can be greater than 36.
{
TestCallback callback(id);
search::Retrieval::Limits limits;
@ -282,11 +284,14 @@ UNIT_TEST(Retrieval_Smoke)
retrieval.Init(index, infos, m2::RectD(m2::PointD(0, 0), m2::PointD(1, 1)), params, limits);
retrieval.Go(callback);
TEST(callback.WasTriggered(), ());
TEST_EQUAL(36 /* number of whiskey bars in a 5 x 5 square (border is counted) */,
callback.Offsets().size(), ());
TEST_GREATER_OR_EQUAL(callback.Offsets().size(),
36 /* number of whiskey bars in a 5 x 5 square (border is counted) */,
());
}
// Retrieve exactly 8 whiskey bars from the center.
// Retrieve exactly 8 whiskey bars from the center. Note that due
// to current coverage algorithm the number of retrieved results can
// be greater than 8.
{
TestCallback callback(id);
search::Retrieval::Limits limits;
@ -296,7 +301,7 @@ UNIT_TEST(Retrieval_Smoke)
limits);
retrieval.Go(callback);
TEST(callback.WasTriggered(), ());
TEST_EQUAL(callback.Offsets().size(), 8, ());
TEST_GREATER_OR_EQUAL(callback.Offsets().size(), 8, ());
}
}