[search] Ranking with errors.

This commit is contained in:
Yuri Gorshenin 2017-02-15 11:52:51 +03:00
parent 3c40cdbf1e
commit 994df64d58
12 changed files with 187 additions and 101 deletions

View file

@ -965,7 +965,10 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken)
filtered.ForEach([&](uint32_t id) {
SearchModel::SearchType searchType;
if (GetSearchTypeInGeocoding(ctx, id, searchType))
EmitResult(ctx, m_context->GetId(), id, searchType, m_postcodes.m_tokenRange);
{
EmitResult(ctx, m_context->GetId(), id, searchType, m_postcodes.m_tokenRange,
nullptr /* geoParts */);
}
});
return;
}
@ -986,7 +989,7 @@ void Geocoder::MatchPOIsAndBuildings(BaseContext & ctx, size_t curToken)
if (!m_postcodes.m_features.HasBit(id))
continue;
EmitResult(ctx, m_context->GetId(), id, SearchModel::SEARCH_TYPE_STREET,
layers.back().m_tokenRange);
layers.back().m_tokenRange, nullptr /* geoParts */);
}
}
@ -1186,15 +1189,14 @@ void Geocoder::FindPaths(BaseContext const & ctx)
*m_matcher, sortedLayers, [this, &ctx, &innermostLayer](IntersectionResult const & result)
{
ASSERT(result.IsValid(), ());
// TODO(@y, @m, @vng): use rest fields of IntersectionResult for
// better scoring.
EmitResult(ctx, m_context->GetId(), result.InnermostResult(), innermostLayer.m_type,
innermostLayer.m_tokenRange);
innermostLayer.m_tokenRange, &result);
});
}
void Geocoder::EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId,
SearchModel::SearchType type, TokenRange const & tokenRange)
SearchModel::SearchType type, TokenRange const & tokenRange,
IntersectionResult const * geoParts)
{
FeatureID id(mwmId, ftId);
@ -1220,6 +1222,9 @@ void Geocoder::EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId,
if (ctx.m_city)
info.m_tokenRange[SearchModel::SEARCH_TYPE_CITY] = ctx.m_city->m_tokenRange;
if (geoParts)
info.m_geoParts = *geoParts;
m_preRanker.Emplace(id, info);
}
@ -1227,12 +1232,13 @@ void Geocoder::EmitResult(BaseContext const & ctx, Region const & region,
TokenRange const & tokenRange)
{
auto const type = Region::ToSearchType(region.m_type);
EmitResult(ctx, region.m_countryId, region.m_featureId, type, tokenRange);
EmitResult(ctx, region.m_countryId, region.m_featureId, type, tokenRange, nullptr /* geoParts */);
}
void Geocoder::EmitResult(BaseContext const & ctx, City const & city, TokenRange const & tokenRange)
{
EmitResult(ctx, city.m_countryId, city.m_featureId, city.m_type, tokenRange);
EmitResult(ctx, city.m_countryId, city.m_featureId, city.m_type, tokenRange,
nullptr /* geoParts */);
}
void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken)
@ -1268,7 +1274,10 @@ void Geocoder::MatchUnclassified(BaseContext & ctx, size_t curToken)
if (!GetSearchTypeInGeocoding(ctx, featureId, searchType))
return;
if (searchType == SearchModel::SEARCH_TYPE_UNCLASSIFIED)
EmitResult(ctx, m_context->GetId(), featureId, searchType, TokenRange(startToken, curToken));
{
EmitResult(ctx, m_context->GetId(), featureId, searchType, TokenRange(startToken, curToken),
nullptr /* geoParts */);
}
};
allFeatures.ForEach(emitUnclassified);
}

View file

@ -196,7 +196,8 @@ private:
// Forms result and feeds it to |m_preRanker|. |geoParts| is optional:
// callers that have no intersection data pass nullptr, otherwise the
// pointed-to IntersectionResult is copied into the pre-ranking info.
void EmitResult(BaseContext const & ctx, MwmSet::MwmId const & mwmId, uint32_t ftId,
SearchModel::SearchType type, TokenRange const & tokenRange);
SearchModel::SearchType type, TokenRange const & tokenRange,
IntersectionResult const * geoParts);
void EmitResult(BaseContext const & ctx, Region const & region, TokenRange const & tokenRange);
void EmitResult(BaseContext const & ctx, City const & city, TokenRange const & tokenRange);

View file

@ -1,17 +1,11 @@
#include "search/intersection_result.hpp"
#include "std/limits.hpp"
#include "std/sstream.hpp"
#include <sstream>
namespace search
{
// static
uint32_t const IntersectionResult::kInvalidId = numeric_limits<uint32_t>::max();
IntersectionResult::IntersectionResult()
: m_poi(kInvalidId), m_building(kInvalidId), m_street(kInvalidId)
{
}
uint32_t const IntersectionResult::kInvalidId;
void IntersectionResult::Set(SearchModel::SearchType type, uint32_t id)
{
@ -47,9 +41,9 @@ void IntersectionResult::Clear()
m_street = kInvalidId;
}
string DebugPrint(IntersectionResult const & result)
std::string DebugPrint(IntersectionResult const & result)
{
ostringstream os;
std::ostringstream os;
os << "IntersectionResult [ ";
if (result.m_poi != IntersectionResult::kInvalidId)
os << "POI:" << result.m_poi << " ";
@ -60,5 +54,4 @@ string DebugPrint(IntersectionResult const & result)
os << "]";
return os.str();
}
} // namespace search

View file

@ -2,8 +2,9 @@
#include "search/model.hpp"
#include "std/cstdint.hpp"
#include "std/string.hpp"
#include <cstdint>
#include <limits>
#include <string>
namespace search
{
@ -11,9 +12,9 @@ namespace search
// i.e. BUILDING and STREET for POI or STREET for BUILDING.
struct IntersectionResult
{
static uint32_t const kInvalidId;
static uint32_t const kInvalidId = std::numeric_limits<uint32_t>::max();
IntersectionResult();
IntersectionResult() = default;
void Set(SearchModel::SearchType type, uint32_t id);
@ -27,10 +28,10 @@ struct IntersectionResult
// Clears all fields to an invalid state.
void Clear();
uint32_t m_poi;
uint32_t m_building;
uint32_t m_street;
uint32_t m_poi = kInvalidId;
uint32_t m_building = kInvalidId;
uint32_t m_street = kInvalidId;
};
string DebugPrint(IntersectionResult const & result);
std::string DebugPrint(IntersectionResult const & result);
} // namespace search

View file

@ -1,5 +1,6 @@
#pragma once
#include "search/intersection_result.hpp"
#include "search/model.hpp"
#include "search/token_range.hpp"
@ -39,6 +40,9 @@ struct PreRankingInfo
// Tokens match to the feature name or house number.
TokenRange m_tokenRange[SearchModel::SEARCH_TYPE_COUNT];
// Geo-parts (POI, building, street) that participated in the search.
IntersectionResult m_geoParts;
// Rank of the feature.
uint8_t m_rank = 0;

View file

@ -35,6 +35,31 @@ void UpdateNameScore(vector<strings::UniString> const & tokens, TSlice const & s
bestScore = score;
}
// Scores how well the names of |ft| match the query tokens in |range|.
//
// Every name of |ft| available in the languages from |params| is
// tokenized and fed to UpdateNameScore() twice: once against the
// plain token slice and once against the no-categories slice. For
// buildings the house number is also scored against the
// no-categories slice. Starts from NAME_SCORE_ZERO and returns the
// accumulated score.
NameScore GetNameScore(FeatureType const & ft, Geocoder::Params const & params,
                       TokenRange const & range, SearchModel::SearchType type)
{
  TokenSlice fullSlice(params, range);
  TokenSliceNoCategories noCategoriesSlice(params, range);

  NameScore result = NAME_SCORE_ZERO;
  for (auto const & lang : params.GetLangs())
  {
    string localizedName;
    if (ft.GetName(lang, localizedName))
    {
      vector<strings::UniString> nameTokens;
      PrepareStringForMatching(localizedName, nameTokens);
      UpdateNameScore(nameTokens, fullSlice, result);
      UpdateNameScore(nameTokens, noCategoriesSlice, result);
    }
  }

  // House numbers only matter when the feature is matched as a building.
  bool const isBuilding = type == SearchModel::SEARCH_TYPE_BUILDING;
  if (isBuilding)
    UpdateNameScore(ft.GetHouseNumber(), noCategoriesSlice, result);

  return result;
}
void RemoveDuplicatingLinear(vector<IndexedValue> & values)
{
PreResult2::LessLinearTypesF lessCmp;
@ -141,29 +166,36 @@ class PreResult2Maker
Geocoder::Params const & m_params;
storage::CountryInfoGetter const & m_infoGetter;
unique_ptr<Index::FeaturesLoaderGuard> m_pFV;
unique_ptr<Index::FeaturesLoaderGuard> m_loader;
// For the best performance, incoming ids should be sorted by id.m_mwmId (mwm file id).
bool LoadFeature(FeatureID const & id, FeatureType & f, m2::PointD & center, string & name,
string & country)
bool LoadFeature(FeatureID const & id, FeatureType & ft)
{
if (m_pFV.get() == 0 || m_pFV->GetId() != id.m_mwmId)
m_pFV.reset(new Index::FeaturesLoaderGuard(m_index, id.m_mwmId));
if (!m_pFV->GetFeatureByIndex(id.m_index, f))
if (!m_loader || m_loader->GetId() != id.m_mwmId)
m_loader = make_unique<Index::FeaturesLoaderGuard>(m_index, id.m_mwmId);
if (!m_loader->GetFeatureByIndex(id.m_index, ft))
return false;
f.SetID(id);
ft.SetID(id);
return true;
}
center = feature::GetCenter(f);
// For the best performance, incoming ids should be sorted by id.m_mwmId (mwm file id).
bool LoadFeature(FeatureID const & id, FeatureType & ft, m2::PointD & center, string & name,
string & country)
{
if (!LoadFeature(id, ft))
return false;
m_ranker.GetBestMatchName(f, name);
center = feature::GetCenter(ft);
m_ranker.GetBestMatchName(ft, name);
// country (region) name is a file name if feature isn't from World.mwm
if (m_pFV->IsWorld())
// Country (region) name is a file name if feature isn't from
// World.mwm.
ASSERT(m_loader && m_loader->GetId() == id.m_mwmId, ());
if (m_loader->IsWorld())
country.clear();
else
country = m_pFV->GetCountryFileName();
country = m_loader->GetCountryFileName();
return true;
}
@ -178,26 +210,23 @@ class PreResult2Maker
info.m_distanceToPivot = MercatorBounds::DistanceOnEarth(center, pivot);
info.m_rank = preInfo.m_rank;
info.m_searchType = preInfo.m_searchType;
info.m_nameScore = NAME_SCORE_ZERO;
info.m_nameScore = GetNameScore(ft, m_params, preInfo.InnermostTokenRange(), info.m_searchType);
TokenSlice slice(m_params, preInfo.InnermostTokenRange());
TokenSliceNoCategories sliceNoCategories(m_params, preInfo.InnermostTokenRange());
for (auto const & lang : m_params.GetLangs())
if (info.m_searchType != SearchModel::SEARCH_TYPE_STREET &&
preInfo.m_geoParts.m_street != IntersectionResult::kInvalidId)
{
string name;
if (!ft.GetName(lang, name))
continue;
vector<strings::UniString> tokens;
PrepareStringForMatching(name, tokens);
UpdateNameScore(tokens, slice, info.m_nameScore);
UpdateNameScore(tokens, sliceNoCategories, info.m_nameScore);
auto const & mwmId = ft.GetID().m_mwmId;
FeatureType street;
if (LoadFeature(FeatureID(mwmId, preInfo.m_geoParts.m_street), street))
{
NameScore const nameScore =
GetNameScore(street, m_params, preInfo.m_tokenRange[SearchModel::SEARCH_TYPE_STREET],
SearchModel::SEARCH_TYPE_STREET);
info.m_nameScore = min(info.m_nameScore, nameScore);
}
}
if (info.m_searchType == SearchModel::SEARCH_TYPE_BUILDING)
UpdateNameScore(ft.GetHouseNumber(), sliceNoCategories, info.m_nameScore);
TokenSlice slice(m_params, preInfo.InnermostTokenRange());
feature::TypesHolder holder(ft);
vector<pair<size_t, size_t>> matched(slice.Size());
ForEachCategoryType(QuerySlice(slice), m_ranker.m_params.m_categoryLocales,
@ -312,8 +341,7 @@ bool Ranker::IsResultExists(PreResult2 const & p, vector<IndexedValue> const & v
});
}
void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vector<IndexedValue> & cont,
vector<FeatureID> & streets)
void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vector<IndexedValue> & cont)
{
PreResult2Maker maker(*this, m_index, m_infoGetter, geocoderParams);
for (auto const & r : m_preResults1)
@ -328,9 +356,6 @@ void Ranker::MakePreResult2(Geocoder::Params const & geocoderParams, vector<Inde
continue;
}
if (p->IsStreet())
streets.push_back(p->GetID());
if (!IsResultExists(*p, cont))
cont.push_back(IndexedValue(move(p)));
};
@ -492,8 +517,7 @@ void Ranker::UpdateResults(bool lastUpdate)
{
BailIfCancelled();
vector<FeatureID> streets;
MakePreResult2(m_geocoderParams, m_tentativeResults, streets);
MakePreResult2(m_geocoderParams, m_tentativeResults);
RemoveDuplicatingLinear(m_tentativeResults);
if (m_tentativeResults.empty())
return;

View file

@ -81,8 +81,7 @@ public:
bool IsResultExists(PreResult2 const & p, vector<IndexedValue> const & values);
void MakePreResult2(Geocoder::Params const & params, vector<IndexedValue> & cont,
vector<FeatureID> & streets);
void MakePreResult2(Geocoder::Params const & params, vector<IndexedValue> & cont);
Result MakeResult(PreResult2 const & r) const;
void MakeResultHighlight(Result & res) const;

View file

@ -9,6 +9,7 @@ set(
interactive_search_test.cpp
pre_ranker_test.cpp
processor_test.cpp
ranker_test.cpp
search_edited_features_test.cpp
smoke_test.cpp
)

View file

@ -67,13 +67,28 @@ bool SearchTest::ResultsMatch(string const & query, Mode mode,
return MatchResults(m_engine, rules, request.Results());
}
// Checks already-obtained |results| against |rules| by delegating to
// MatchResults() with this fixture's engine.
bool SearchTest::ResultsMatch(vector<search::Result> const & results, TRules const & rules)
{
  bool const matched = MatchResults(m_engine, rules, results);
  return matched;
}
// Creates a search request for |query| with the "en" input locale,
// Mode::Everywhere and suggests disabled, bound to this fixture's
// engine and viewport. Run() is called on the request before it is
// handed back to the caller.
unique_ptr<tests_support::TestSearchRequest> SearchTest::MakeRequest(string const & query)
{
  SearchParams searchParams;
  searchParams.m_suggestsEnabled = false;
  searchParams.m_mode = Mode::Everywhere;
  searchParams.m_inputLocale = "en";
  searchParams.m_query = query;

  unique_ptr<tests_support::TestSearchRequest> pending =
      make_unique<tests_support::TestSearchRequest>(m_engine, searchParams, m_viewport);
  pending->Run();
  return pending;
}
size_t SearchTest::CountFeatures(m2::RectD const & rect)
{
size_t count = 0;
auto counter = [&count](const FeatureType & /* ft */)
{
++count;
};
auto counter = [&count](const FeatureType & /* ft */) { ++count; };
m_engine.ForEachInRect(counter, rect, scales::GetUpperScale());
return count;
}

View file

@ -2,6 +2,7 @@
#include "search/search_tests_support/test_results_matching.hpp"
#include "search/search_tests_support/test_search_engine.hpp"
#include "search/search_tests_support/test_search_request.hpp"
#include "generator/generator_tests_support/test_mwm_builder.hpp"
@ -101,12 +102,17 @@ public:
}
inline void SetViewport(m2::RectD const & viewport) { m_viewport = viewport; }
bool ResultsMatch(string const & query, TRules const & rules);
bool ResultsMatch(string const & query, string const & locale, TRules const & rules);
bool ResultsMatch(string const & query, Mode mode, TRules const & rules);
bool ResultsMatch(vector<search::Result> const & results, TRules const & rules);
unique_ptr<tests_support::TestSearchRequest> MakeRequest(string const & query);
size_t CountFeatures(m2::RectD const & rect);
protected:

View file

@ -66,25 +66,6 @@ private:
class ProcessorTest : public SearchTest
{
public:
unique_ptr<TestSearchRequest> MakeRequest(string const & query)
{
SearchParams params;
params.m_query = query;
params.m_inputLocale = "en";
params.m_mode = Mode::Everywhere;
params.m_suggestsEnabled = false;
auto request = make_unique<TestSearchRequest>(m_engine, params, m_viewport);
request->Run();
return request;
}
bool MatchResults(vector<shared_ptr<MatchingRule>> rules,
vector<search::Result> const & actual) const
{
return ::MatchResults(m_engine, rules, actual);
}
};
UNIT_CLASS_TEST(ProcessorTest, Smoke)
@ -337,7 +318,7 @@ UNIT_CLASS_TEST(ProcessorTest, DisableSuggests)
request.Run();
TRules rules = {ExactMatch(worldId, london1), ExactMatch(worldId, london2)};
TEST(MatchResults(rules, request.Results()), ());
TEST(ResultsMatch(request.Results(), rules), ());
}
}
@ -399,7 +380,7 @@ UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo)
TRules rules = {ExactMatch(wonderlandId, goldenGateBridge),
ExactMatch(wonderlandId, goldenGateStreet)};
TEST(MatchResults(rules, request->Results()), ());
TEST(ResultsMatch(request->Results(), rules), ());
for (auto const & result : request->Results())
{
auto const & info = result.GetRankingInfo();
@ -416,11 +397,11 @@ UNIT_CLASS_TEST(ProcessorTest, TestRankingInfo)
TRules rules{ExactMatch(wonderlandId, cafe1), ExactMatch(wonderlandId, cafe2),
ExactMatch(wonderlandId, lermontov)};
TEST(MatchResults(rules, results), ());
TEST(ResultsMatch(results, rules), ());
TEST_EQUAL(3, results.size(), ("Unexpected number of retrieved cafes."));
auto const & top = results.front();
TEST(MatchResults({ExactMatch(wonderlandId, lermontov)}, {top}), ());
TEST(ResultsMatch({top}, {ExactMatch(wonderlandId, lermontov)}), ());
}
{
@ -640,7 +621,7 @@ UNIT_CLASS_TEST(ProcessorTest, TestCategories)
ExactMatch(wonderlandId, busStop)};
auto request = MakeRequest("atm");
TEST(MatchResults(rules, request->Results()), ());
TEST(ResultsMatch(request->Results(), rules), ());
for (auto const & result : request->Results())
{
Index::FeaturesLoaderGuard loader(m_engine, wonderlandId);
@ -713,7 +694,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering)
TestSearchRequest request(m_engine, params, m_viewport);
request.Run();
TRules rules = {ExactMatch(id, h1), ExactMatch(id, h2), ExactMatch(id, h3), ExactMatch(id, h4)};
TEST(MatchResults(rules, request.Results()), ());
TEST(ResultsMatch(request.Results(), rules), ());
}
using namespace hotels_filter;
@ -723,7 +704,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering)
TestSearchRequest request(m_engine, params, m_viewport);
request.Run();
TRules rules = {ExactMatch(id, h1), ExactMatch(id, h3)};
TEST(MatchResults(rules, request.Results()), ());
TEST(ResultsMatch(request.Results(), rules), ());
}
params.m_hotelsFilter = Or(Eq<Rating>(9.0), Le<PriceRate>(4));
@ -731,7 +712,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering)
TestSearchRequest request(m_engine, params, m_viewport);
request.Run();
TRules rules = {ExactMatch(id, h1), ExactMatch(id, h3), ExactMatch(id, h4)};
TEST(MatchResults(rules, request.Results()), ());
TEST(ResultsMatch(request.Results(), rules), ());
}
params.m_hotelsFilter = Or(And(Eq<Rating>(7.0), Eq<PriceRate>(5)), Eq<PriceRate>(4));
@ -739,7 +720,7 @@ UNIT_CLASS_TEST(ProcessorTest, HotelsFiltering)
TestSearchRequest request(m_engine, params, m_viewport);
request.Run();
TRules rules = {ExactMatch(id, h2), ExactMatch(id, h4)};
TEST(MatchResults(rules, request.Results()), ());
TEST(ResultsMatch(request.Results(), rules), ());
}
}
@ -842,7 +823,7 @@ UNIT_CLASS_TEST(ProcessorTest, StopWords)
TRules rules = {ExactMatch(id, street)};
auto const & results = request->Results();
TEST(MatchResults(rules, results), ());
TEST(ResultsMatch(results, rules), ());
auto const & info = results[0].GetRankingInfo();
TEST_EQUAL(info.m_nameScore, NAME_SCORE_FULL_MATCH, ());

View file

@ -0,0 +1,52 @@
#include "testing/testing.hpp"
#include "search/search_integration_tests/helpers.hpp"
#include "search/search_tests_support/test_results_matching.hpp"
#include "generator/generator_tests_support/test_feature.hpp"
using namespace generator::tests_support;
using namespace search::tests_support;
using namespace search;
namespace
{
// Fixture for ranker integration tests; adds nothing on top of SearchTest.
class RankerTest : public SearchTest {};
// Two streets whose names differ by two letters ("Мазурова" and
// "Машерова") each get a building numbered 14. The query spells
// "Мазурова" exactly, and the test expects both buildings to be
// returned with the "Мазурова" one ranked first.
UNIT_CLASS_TEST(RankerTest, ErrorsInStreets)
{
  // Street whose name is matched exactly by the query.
  TestStreet exactStreet(
      vector<m2::PointD>{m2::PointD(-0.001, -0.001), m2::PointD(0, 0), m2::PointD(0.001, 0.001)},
      "Мазурова", "ru");
  TestBuilding exactBuilding(m2::PointD(-0.001, -0.001), "", "14", exactStreet, "ru");

  // Street with a similar name; its building sits on a corner of the
  // viewport set below.
  TestStreet similarStreet(
      vector<m2::PointD>{m2::PointD(-0.001, 0.001), m2::PointD(0, 0), m2::PointD(0.001, -0.001)},
      "Машерова", "ru");
  TestBuilding similarBuilding(m2::PointD(0.001, 0.001), "", "14", similarStreet, "ru");

  // The insertion order is kept exactly as in the fixture definition.
  auto addFeatures = [&](TestMwmBuilder & mwmBuilder) {
    mwmBuilder.Add(exactStreet);
    mwmBuilder.Add(exactBuilding);
    mwmBuilder.Add(similarStreet);
    mwmBuilder.Add(similarBuilding);
  };
  auto countryId = BuildCountry("Belarus", addFeatures);

  SetViewport(m2::RectD(m2::PointD(0, 0), m2::PointD(0.001, 0.001)));

  {
    auto searchRequest = MakeRequest("Мазурова 14");
    auto const & found = searchRequest->Results();

    TRules expected = {ExactMatch(countryId, exactBuilding),
                       ExactMatch(countryId, similarBuilding)};

    // Both buildings must be present...
    TEST(ResultsMatch(found, expected), ());
    TEST_EQUAL(found.size(), 2, ());

    // ...and in this exact order.
    TEST(ResultsMatch({found[0]}, {expected[0]}), ());
    TEST(ResultsMatch({found[1]}, {expected[1]}), ());
  }
}
} // namespace