[search] Tracer.

This commit is contained in:
Yuri Gorshenin 2017-11-08 17:04:39 +03:00 committed by burivuh
parent a04e41f176
commit d108a4c2cc
17 changed files with 335 additions and 51 deletions

View file

@ -132,6 +132,8 @@ set(
token_range.hpp
token_slice.cpp
token_slice.hpp
tracer.cpp
tracer.hpp
types_skipper.cpp
types_skipper.hpp
utils.cpp

View file

@ -10,6 +10,7 @@
#include "search/processor.hpp"
#include "search/retrieval.hpp"
#include "search/token_slice.hpp"
#include "search/tracer.hpp"
#include "search/utils.hpp"
#include "indexer/classificator.hpp"
@ -78,33 +79,38 @@ UniString const kUniSpace(MakeUniString(" "));
struct ScopedMarkTokens
{
ScopedMarkTokens(vector<bool> & usedTokens, TokenRange const & range)
: m_usedTokens(usedTokens), m_range(range)
static BaseContext::TokenType constexpr kUnused = BaseContext::TOKEN_TYPE_COUNT;
ScopedMarkTokens(vector<BaseContext::TokenType> & tokens, BaseContext::TokenType type,
TokenRange const & range)
: m_tokens(tokens), m_type(type), m_range(range)
{
ASSERT(m_range.IsValid(), ());
ASSERT_LESS_OR_EQUAL(m_range.End(), m_usedTokens.size(), ());
ASSERT_LESS_OR_EQUAL(m_range.End(), m_tokens.size(), ());
#if defined(DEBUG)
for (size_t i : m_range)
ASSERT(!m_usedTokens[i], (i));
ASSERT_EQUAL(m_tokens[i], kUnused, (i));
#endif
fill(m_usedTokens.begin() + m_range.Begin(), m_usedTokens.begin() + m_range.End(),
true /* used */);
fill(m_tokens.begin() + m_range.Begin(), m_tokens.begin() + m_range.End(), m_type);
}
~ScopedMarkTokens()
{
#if defined(DEBUG)
for (size_t i : m_range)
ASSERT(m_usedTokens[i], (i));
ASSERT_EQUAL(m_tokens[i], m_type, (i));
#endif
fill(m_usedTokens.begin() + m_range.Begin(), m_usedTokens.begin() + m_range.End(),
false /* used */);
fill(m_tokens.begin() + m_range.Begin(), m_tokens.begin() + m_range.End(), kUnused);
}
vector<bool> & m_usedTokens;
vector<search::BaseContext::TokenType> & m_tokens;
search::BaseContext::TokenType const m_type;
TokenRange const m_range;
};
// static
BaseContext::TokenType constexpr ScopedMarkTokens::kUnused;
class LazyRankTable : public RankTable
{
public:
@ -573,7 +579,7 @@ void Geocoder::InitBaseContext(BaseContext & ctx)
{
Retrieval retrieval(*m_context, m_cancellable);
ctx.m_usedTokens.assign(m_params.GetNumTokens(), false);
ctx.m_tokens.assign(m_params.GetNumTokens(), BaseContext::TOKEN_TYPE_COUNT);
ctx.m_numTokens = m_params.GetNumTokens();
ctx.m_features.resize(ctx.m_numTokens);
for (size_t i = 0; i < ctx.m_features.size(); ++i)
@ -849,7 +855,7 @@ void Geocoder::MatchRegions(BaseContext & ctx, Region::Type type)
ctx.m_regions.push_back(&region);
MY_SCOPE_GUARD(cleanup, [&ctx]() { ctx.m_regions.pop_back(); });
ScopedMarkTokens mark(ctx.m_usedTokens, tokenRange);
ScopedMarkTokens mark(ctx.m_tokens, BaseContext::FromRegionType(type), tokenRange);
if (ctx.AllTokensUsed())
{
// Region matches to search query, we need to emit it as is.
@ -888,7 +894,7 @@ void Geocoder::MatchCities(BaseContext & ctx)
continue;
}
ScopedMarkTokens mark(ctx.m_usedTokens, tokenRange);
ScopedMarkTokens mark(ctx.m_tokens, BaseContext::TOKEN_TYPE_CITY, tokenRange);
ctx.m_city = &city;
MY_SCOPE_GUARD(cleanup, [&ctx]() { ctx.m_city = nullptr; });
@ -954,7 +960,7 @@ void Geocoder::WithPostcodes(BaseContext & ctx, TFn && fn)
size_t endToken = startToken;
for (size_t n = 1; startToken + n <= ctx.m_numTokens && n <= maxPostcodeTokens; ++n)
{
if (ctx.m_usedTokens[startToken + n - 1])
if (ctx.IsTokenUsed(startToken + n - 1))
break;
TokenSlice slice(m_params, TokenRange(startToken, startToken + n));
@ -972,7 +978,7 @@ void Geocoder::WithPostcodes(BaseContext & ctx, TFn && fn)
if (!postcodes.IsEmpty())
{
ScopedMarkTokens mark(ctx.m_usedTokens, tokenRange);
ScopedMarkTokens mark(ctx.m_tokens, BaseContext::TOKEN_TYPE_POSTCODE, tokenRange);
m_postcodes.Clear();
m_postcodes.m_tokenRange = tokenRange;
@ -1011,7 +1017,7 @@ void Geocoder::CreateStreetsLayerAndMatchLowerLayers(BaseContext & ctx,
});
layer.m_sortedFeatures = &sortedFeatures;
ScopedMarkTokens mark(ctx.m_usedTokens, prediction.m_tokenRange);
ScopedMarkTokens mark(ctx.m_tokens, BaseContext::TOKEN_TYPE_STREET, prediction.m_tokenRange);
MatchPOIsAndBuildings(ctx, 0 /* curToken */);
}
@ -1114,7 +1120,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken)
features.SetFull();
// Try to consume [curToken, m_numTokens) tokens range.
for (size_t n = 1; curToken + n <= ctx.m_numTokens && !ctx.m_usedTokens[curToken + n - 1]; ++n)
for (size_t n = 1; curToken + n <= ctx.m_numTokens && !ctx.IsTokenUsed(curToken + n - 1); ++n)
{
// At this point |features| is the intersection of
// m_addressFeatures[curToken], m_addressFeatures[curToken + 1],
@ -1194,6 +1200,8 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken)
}
layer.m_type = static_cast<Model::Type>(i);
ScopedMarkTokens mark(ctx.m_tokens, BaseContext::FromModelType(layer.m_type),
TokenRange(curToken, curToken + n));
if (IsLayerSequenceSane(layers))
MatchPOIsAndBuildings(ctx, curToken + n);
}
@ -1283,6 +1291,9 @@ void Geocoder::EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId,
if (m_params.m_cianMode && type != Model::TYPE_BUILDING)
return;
if (m_params.m_tracer)
m_params.m_tracer->EmitParse(ctx.m_tokens);
// Distance and rank will be filled at the end, for all results at once.
//
// TODO (@y, @m): need to skip zero rank features that are too
@ -1343,7 +1354,7 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken)
auto startToken = curToken;
for (curToken = ctx.SkipUsedTokens(curToken);
curToken < ctx.m_numTokens && !ctx.m_usedTokens[curToken]; ++curToken)
curToken < ctx.m_numTokens && !ctx.IsTokenUsed(curToken); ++curToken)
{
allFeatures = allFeatures.Intersect(ctx.m_features[curToken]);
}

View file

@ -37,6 +37,7 @@
#include "std/limits.hpp"
#include "std/set.hpp"
#include "std/shared_ptr.hpp"
#include "std/string.hpp"
#include "std/unique_ptr.hpp"
#include "std/unordered_map.hpp"
@ -52,11 +53,11 @@ class CountryInfoGetter;
namespace search
{
class PreRanker;
class FeaturesFilter;
class FeaturesLayerMatcher;
class PreRanker;
class TokenSlice;
class Tracer;
// This class is used to retrieve all features corresponding to a
// search query. Search query is represented as a sequence of tokens
@ -83,6 +84,7 @@ public:
shared_ptr<hotels_filter::Rule> m_hotelsFilter;
bool m_cianMode = false;
set<uint32_t> m_preferredTypes;
shared_ptr<Tracer> m_tracer;
};
Geocoder(Index const & index, storage::CountryInfoGetter const & infoGetter,

View file

@ -5,37 +5,97 @@
#include "base/assert.hpp"
#include "base/stl_add.hpp"
#include <algorithm>
using namespace std;
namespace search
{
// static
BaseContext::TokenType BaseContext::FromModelType(Model::Type type)
{
switch (type)
{
case Model::TYPE_POI: return TOKEN_TYPE_POI;
case Model::TYPE_BUILDING: return TOKEN_TYPE_BUILDING;
case Model::TYPE_STREET: return TOKEN_TYPE_STREET;
case Model::TYPE_UNCLASSIFIED: return TOKEN_TYPE_UNCLASSIFIED;
case Model::TYPE_VILLAGE: return TOKEN_TYPE_VILLAGE;
case Model::TYPE_CITY: return TOKEN_TYPE_CITY;
case Model::TYPE_STATE: return TOKEN_TYPE_STATE;
case Model::TYPE_COUNTRY: return TOKEN_TYPE_COUNTRY;
case Model::TYPE_COUNT: return TOKEN_TYPE_COUNT;
}
}
// static
BaseContext::TokenType BaseContext::FromRegionType(Region::Type type)
{
switch (type)
{
case Region::TYPE_STATE: return TOKEN_TYPE_STATE;
case Region::TYPE_COUNTRY: return TOKEN_TYPE_COUNTRY;
case Region::TYPE_COUNT: return TOKEN_TYPE_COUNT;
}
}
size_t BaseContext::SkipUsedTokens(size_t curToken) const
{
while (curToken != m_usedTokens.size() && m_usedTokens[curToken])
while (curToken != m_tokens.size() && IsTokenUsed(curToken))
++curToken;
return curToken;
}
bool BaseContext::IsTokenUsed(size_t token) const
{
ASSERT_LESS(token, m_tokens.size(), ());
return m_tokens[token] != TOKEN_TYPE_COUNT;
}
bool BaseContext::AllTokensUsed() const
{
return std::all_of(m_usedTokens.begin(), m_usedTokens.end(), IdFunctor());
for (size_t i = 0; i < m_tokens.size(); ++i)
{
if (!IsTokenUsed(i))
return false;
}
return true;
}
bool BaseContext::HasUsedTokensInRange(TokenRange const & range) const
{
ASSERT(range.IsValid(), (range));
return std::any_of(m_usedTokens.begin() + range.Begin(), m_usedTokens.begin() + range.End(),
IdFunctor());
for (size_t i = range.Begin(); i < range.End(); ++i)
{
if (IsTokenUsed(i))
return true;
}
return false;
}
size_t BaseContext::NumUnusedTokenGroups() const
{
size_t numGroups = 0;
for (size_t i = 0; i < m_usedTokens.size(); ++i)
for (size_t i = 0; i < m_tokens.size(); ++i)
{
if (!m_usedTokens[i] && (i == 0 || m_usedTokens[i - 1]))
if (!IsTokenUsed(i) && (i == 0 || IsTokenUsed(i - 1)))
++numGroups;
}
return numGroups;
}
string DebugPrint(BaseContext::TokenType type)
{
switch (type)
{
case BaseContext::TOKEN_TYPE_POI: return "POI";
case BaseContext::TOKEN_TYPE_BUILDING: return "BUILDING";
case BaseContext::TOKEN_TYPE_STREET: return "STREET";
case BaseContext::TOKEN_TYPE_UNCLASSIFIED: return "UNCLASSIFIED";
case BaseContext::TOKEN_TYPE_VILLAGE: return "VILLAGE";
case BaseContext::TOKEN_TYPE_CITY: return "CITY";
case BaseContext::TOKEN_TYPE_STATE: return "STATE";
case BaseContext::TOKEN_TYPE_COUNTRY: return "COUNTRY";
case BaseContext::TOKEN_TYPE_POSTCODE: return "POSTCODE";
case BaseContext::TOKEN_TYPE_COUNT: return "COUNT";
}
}
} // namespace search

View file

@ -4,9 +4,11 @@
#include "search/features_layer.hpp"
#include "search/geocoder_locality.hpp"
#include "search/hotels_filter.hpp"
#include "search/model.hpp"
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
namespace search
@ -16,10 +18,31 @@ class TokenRange;
struct BaseContext
{
enum TokenType
{
TOKEN_TYPE_POI,
TOKEN_TYPE_BUILDING,
TOKEN_TYPE_STREET,
TOKEN_TYPE_UNCLASSIFIED,
TOKEN_TYPE_VILLAGE,
TOKEN_TYPE_CITY,
TOKEN_TYPE_STATE,
TOKEN_TYPE_COUNTRY,
TOKEN_TYPE_POSTCODE,
TOKEN_TYPE_COUNT
};
static TokenType FromModelType(Model::Type type);
static TokenType FromRegionType(Region::Type type);
// Advances |curToken| to the nearest unused token, or to the end of
// |m_usedTokens| if there are no unused tokens.
size_t SkipUsedTokens(size_t curToken) const;
// Returns true if |token| is marked as used.
bool IsTokenUsed(size_t token) const;
// Returns true iff all tokens are used.
bool AllTokensUsed() const;
@ -45,11 +68,13 @@ struct BaseContext
// This vector is used to indicate what tokens were already matched
// and can't be re-used during the geocoding process.
std::vector<bool> m_usedTokens;
std::vector<TokenType> m_tokens;
// Number of tokens in the query.
size_t m_numTokens = 0;
std::unique_ptr<hotels_filter::HotelsFilter::ScopedFilter> m_hotelsFilter;
};
std::string DebugPrint(BaseContext::TokenType type);
} // namespace search

View file

@ -491,6 +491,7 @@ void Processor::InitGeocoder(Geocoder::Params & geocoderParams, SearchParams con
geocoderParams.m_hotelsFilter = searchParams.m_hotelsFilter;
geocoderParams.m_cianMode = searchParams.m_cianMode;
geocoderParams.m_preferredTypes = m_preferredTypes;
geocoderParams.m_tracer = searchParams.m_tracer;
m_geocoder.SetParams(geocoderParams);
}

View file

@ -81,6 +81,7 @@ HEADERS += \
suggest.hpp \
token_range.hpp \
token_slice.hpp \
tracer.hpp \
types_skipper.hpp \
utils.hpp \
viewport_search_callback.hpp \
@ -145,6 +146,7 @@ SOURCES += \
streets_matcher.cpp \
suggest.cpp \
token_slice.cpp \
tracer.cpp \
types_skipper.cpp \
utils.cpp \
viewport_search_callback.cpp \

View file

@ -12,6 +12,7 @@ set(
ranker_test.cpp
search_edited_features_test.cpp
smoke_test.cpp
tracer_tests.cpp
)
omim_add_test(${PROJECT_NAME} ${SRC})

View file

@ -5,6 +5,8 @@
#include "search/search_tests_support/test_search_request.hpp"
#include "search/search_tests_support/test_with_custom_mwms.hpp"
#include "generator/generator_tests_support/test_feature.hpp"
#include "indexer/indexer_tests_support/helpers.hpp"
#include "geometry/rect2d.hpp"
@ -57,4 +59,28 @@ protected:
m2::RectD m_viewport;
};
class TestCafe : public generator::tests_support::TestPOI
{
public:
TestCafe(m2::PointD const & center, std::string const & name, std::string const & lang)
: TestPOI(center, name, lang)
{
SetTypes({{"amenity", "cafe"}});
}
explicit TestCafe(m2::PointD const & center) : TestCafe(center, "cafe", "en") {}
};
class TestHotel : public generator::tests_support::TestPOI
{
public:
TestHotel(m2::PointD const & center, std::string const & name, std::string const & lang)
: TestPOI(center, name, lang)
{
SetTypes({{"tourism", "hotel"}});
}
explicit TestHotel(m2::PointD const & center) : TestHotel(center, "hotel", "en") {}
};
} // namespace search

View file

@ -1,7 +1,5 @@
#include "testing/testing.hpp"
#include "generator/generator_tests_support/test_feature.hpp"
#include "search/viewport_search_callback.hpp"
#include "search/mode.hpp"
#include "search/search_integration_tests/helpers.hpp"
@ -17,24 +15,6 @@ namespace search
{
namespace
{
class TestCafe : public TestPOI
{
public:
TestCafe(m2::PointD const & center) : TestPOI(center, "cafe", "en")
{
SetTypes({{"amenity", "cafe"}});
}
};
class TestHotel : public TestPOI
{
public:
TestHotel(m2::PointD const & center) : TestPOI(center, "hotel", "en")
{
SetTypes({{"tourism", "hotel"}});
}
};
class TestDelegate : public ViewportSearchCallback::Delegate
{
public:

View file

@ -33,6 +33,7 @@ SOURCES += \
processor_test.cpp \
search_edited_features_test.cpp \
smoke_test.cpp \
tracer_tests.cpp \
HEADERS += \
helpers.hpp \

View file

@ -0,0 +1,62 @@
#include "testing/testing.hpp"
#include "search/geocoder_context.hpp"
#include "search/search_integration_tests/helpers.hpp"
#include "search/search_tests_support/test_results_matching.hpp"
#include "search/tracer.hpp"
#include "generator/generator_tests_support/test_feature.hpp"
#include <memory>
#include <vector>
using namespace generator::tests_support;
using namespace search::tests_support;
using namespace search;
using namespace std;
namespace
{
class TracerTest : public SearchTest
{
};
UNIT_CLASS_TEST(TracerTest, Smoke)
{
using TokenType = BaseContext::TokenType;
TestCity moscow(m2::PointD(0, 0), "Moscow", "en", 100 /* rank */);
TestCafe regularCafe(m2::PointD(0, 0));
TestCafe moscowCafe(m2::PointD(0, 0), "Moscow", "en");
BuildWorld([&](TestMwmBuilder & builder) { builder.Add(moscow); });
auto const id = BuildCountry("Wonderland", [&](TestMwmBuilder & builder) {
builder.Add(regularCafe);
builder.Add(moscowCafe);
});
auto tracer = make_shared<Tracer>();
SearchParams params;
params.m_query = "moscow cafe";
params.m_inputLocale = "en";
params.m_viewport = m2::RectD(-1, -1, 1, 1);
params.m_mode = Mode::Everywhere;
params.m_tracer = tracer;
TestSearchRequest request(m_engine, params);
request.Run();
TRules rules = {ExactMatch(id, regularCafe), ExactMatch(id, moscowCafe)};
TEST(ResultsMatch(request.Results(), rules), ());
auto const actual = tracer->GetUniqueParses();
vector<Tracer::Parse> const expected{
Tracer::Parse{{{TokenType::TOKEN_TYPE_POI, TokenRange(0, 2)}}},
Tracer::Parse{{{TokenType::TOKEN_TYPE_CITY, TokenRange(0, 1)},
{TokenType::TOKEN_TYPE_POI, TokenRange(1, 2)}}}};
TEST_EQUAL(expected, actual, ());
}
} // namespace

View file

@ -16,6 +16,7 @@
namespace search
{
class Results;
class Tracer;
struct SearchParams
{
@ -56,7 +57,10 @@ struct SearchParams
bool m_needHighlighting = false;
std::shared_ptr<hotels_filter::Rule> m_hotelsFilter;
bool m_cianMode = false;
std::shared_ptr<Tracer> m_tracer;
};
std::string DebugPrint(SearchParams const & params);

View file

@ -64,7 +64,7 @@ public:
void GetTopLocalities(size_t limit)
{
BaseContext ctx;
ctx.m_usedTokens.assign(m_params.GetNumTokens(), false);
ctx.m_tokens.assign(m_params.GetNumTokens(), BaseContext::TOKEN_TYPE_COUNT);
ctx.m_numTokens = m_params.GetNumTokens();
for (size_t i = 0; i < m_params.GetNumTokens(); ++i)

View file

@ -59,7 +59,7 @@ void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const &
{
for (size_t startToken = 0; startToken < ctx.m_numTokens; ++startToken)
{
if (ctx.m_usedTokens[startToken])
if (ctx.IsTokenUsed(startToken))
continue;
// Here we try to match as many tokens as possible while
@ -150,7 +150,7 @@ void StreetsMatcher::FindStreets(BaseContext const & ctx, FeaturesFilter const &
incomplete = false;
});
for (; curToken < ctx.m_numTokens && !ctx.m_usedTokens[curToken] && !streets.IsEmpty();
for (; curToken < ctx.m_numTokens && !ctx.IsTokenUsed(curToken) && !streets.IsEmpty();
++curToken)
{
auto const & token = params.GetToken(curToken).m_original;

65
search/tracer.cpp Normal file
View file

@ -0,0 +1,65 @@
#include "search/tracer.hpp"
#include "base/stl_helpers.hpp"
#include <cstddef>
#include <sstream>
using namespace std;
namespace search
{
// Tracer::Parse -----------------------------------------------------------------------------------
Tracer::Parse::Parse(std::vector<TokenType> const & types)
{
size_t i = 0;
while (i != types.size())
{
auto const type = types[i];
auto j = i + 1;
while (j != types.size() && types[j] == type)
++j;
m_ranges[type] = TokenRange(i, j);
i = j;
}
}
Tracer::Parse::Parse(vector<pair<TokenType, TokenRange>> const & ranges)
{
for (auto const & kv : ranges)
m_ranges[kv.first] = kv.second;
}
std::string DebugPrint(Tracer::Parse const & parse)
{
using TokenType = Tracer::Parse::TokenType;
ostringstream os;
os << "Parse [";
bool first = true;
for (size_t i = 0; i < TokenType::TOKEN_TYPE_COUNT; ++i)
{
auto const & range = parse.m_ranges[i];
if (range.Begin() == range.End())
continue;
if (!first)
os << ", ";
os << DebugPrint(static_cast<TokenType>(i)) << ": " << DebugPrint(range);
first = false;
}
os << "]";
return os.str();
}
// Tracer ------------------------------------------------------------------------------------------
vector<Tracer::Parse> Tracer::GetUniqueParses() const
{
auto parses = m_parses;
my::SortUnique(parses);
return parses;
}
} // namespace search

42
search/tracer.hpp Normal file
View file

@ -0,0 +1,42 @@
#pragma once
#include "search/geocoder_context.hpp"
#include "search/token_range.hpp"
#include <array>
#include <string>
#include <utility>
#include <vector>
namespace search
{
class Tracer
{
public:
struct Parse
{
using TokenType = BaseContext::TokenType;
explicit Parse(std::vector<TokenType> const & types);
explicit Parse(std::vector<std::pair<TokenType, TokenRange>> const & ranges);
bool operator==(Parse const & rhs) const { return m_ranges == rhs.m_ranges; }
bool operator<(Parse const & rhs) const { return m_ranges < rhs.m_ranges; }
std::array<TokenRange, TokenType::TOKEN_TYPE_COUNT> m_ranges;
};
template <typename ...Args>
void EmitParse(Args &&... args)
{
m_parses.emplace_back(std::forward<Args>(args)...);
}
std::vector<Parse> GetUniqueParses() const;
private:
std::vector<Parse> m_parses;
};
std::string DebugPrint(Tracer::Parse const & parse);
} // namespace search