forked from organicmaps/organicmaps
[search] Introduce ResultType feature for ranker.
This commit is contained in:
parent
dfbbfee9e3
commit
8ef7c2f767
5 changed files with 143 additions and 2 deletions
|
@ -82,6 +82,8 @@ namespace feature
|
|||
|
||||
size_t Size() const { return m_size; }
|
||||
bool Empty() const { return (m_size == 0); }
|
||||
Types::const_iterator cbegin() const { return m_types.cbegin(); }
|
||||
Types::const_iterator cend() const { return m_types.cbegin() + m_size; }
|
||||
Types::const_iterator begin() const { return m_types.cbegin(); }
|
||||
Types::const_iterator end() const { return m_types.cbegin() + m_size; }
|
||||
Types::iterator begin() { return m_types.begin(); }
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "indexer/brands_holder.hpp"
|
||||
#include "indexer/data_source.hpp"
|
||||
#include "indexer/feature_algo.hpp"
|
||||
#include "indexer/feature_data.hpp"
|
||||
#include "indexer/feature_utils.hpp"
|
||||
#include "indexer/ftypes_matcher.hpp"
|
||||
#include "indexer/search_string_utils.hpp"
|
||||
|
@ -353,6 +354,8 @@ class RankerResultMaker
|
|||
info.m_popularity = preInfo.m_popularity;
|
||||
info.m_rating = preInfo.m_rating;
|
||||
info.m_type = preInfo.m_type;
|
||||
if (info.m_type == Model::TYPE_POI)
|
||||
info.m_resultType = GetResultType(feature::TypesHolder(ft));
|
||||
info.m_allTokensUsed = preInfo.m_allTokensUsed;
|
||||
info.m_numTokens = m_params.GetNumTokens();
|
||||
info.m_exactMatch = preInfo.m_exactMatch;
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
#include "search/ranking_info.hpp"
|
||||
|
||||
#include "search/utils.hpp"
|
||||
|
||||
#include "ugc/types.hpp"
|
||||
|
||||
#include "indexer/classificator.hpp"
|
||||
#include "indexer/search_string_utils.hpp"
|
||||
|
||||
#include "base/assert.hpp"
|
||||
#include "base/stl_helpers.hpp"
|
||||
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
|
@ -50,6 +54,15 @@ double constexpr kType[Model::TYPE_COUNT] = {
|
|||
0.0233254 /* State */,
|
||||
0.1679389 /* Country */
|
||||
};
|
||||
double constexpr kResultType[base::Underlying(ResultType::Count)] = {
|
||||
0.0338794 /* TransportMajor */,
|
||||
0.0216298 /* TransportLocal */,
|
||||
0.0064977 /* Eat */,
|
||||
-0.0275763 /* Hotel */,
|
||||
0.0358858 /* Attraction */,
|
||||
-0.0195234 /* Service */,
|
||||
-0.0128952 /* General */
|
||||
};
|
||||
|
||||
// Coeffs sanity checks.
|
||||
static_assert(kHasName >= 0, "");
|
||||
|
@ -102,6 +115,42 @@ void PrintParse(ostringstream & oss, array<TokenRange, Model::TYPE_COUNT> const
|
|||
}
|
||||
oss << "]";
|
||||
}
|
||||
|
||||
class IsServiceTypeChecker
|
||||
{
|
||||
public:
|
||||
IsServiceTypeChecker()
|
||||
{
|
||||
vector<string> const oneLevelTypes = {
|
||||
"barrier",
|
||||
"power",
|
||||
"traffic_calming"
|
||||
};
|
||||
|
||||
vector<vector<string>> const twoLevelTypes = {};
|
||||
|
||||
for (auto const t : oneLevelTypes)
|
||||
m_oneLevelTypes.push_back(classif().GetTypeByPath({t}));
|
||||
for (auto const t : twoLevelTypes)
|
||||
m_twoLevelTypes.push_back(classif().GetTypeByPath(t));
|
||||
}
|
||||
|
||||
bool operator()(feature::TypesHolder const & th) const
|
||||
{
|
||||
auto findType = [](vector<uint32_t> const & v, uint32_t t, uint8_t level) {
|
||||
ftype::TruncValue(t, level);
|
||||
return find(v.begin(), v.end(), t) != v.end();
|
||||
};
|
||||
|
||||
return base::AnyOf(th, [&](auto t) {
|
||||
return findType(m_oneLevelTypes, t, 1) || findType(m_twoLevelTypes, t, 2);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
vector<uint32_t> m_oneLevelTypes;
|
||||
vector<uint32_t> m_twoLevelTypes;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
|
@ -118,6 +167,7 @@ void RankingInfo::PrintCSVHeader(ostream & os)
|
|||
<< ",ErrorsMade"
|
||||
<< ",MatchedFraction"
|
||||
<< ",SearchType"
|
||||
<< ",ResultType"
|
||||
<< ",PureCats"
|
||||
<< ",FalseCats"
|
||||
<< ",AllTokensUsed"
|
||||
|
@ -142,6 +192,7 @@ string DebugPrint(RankingInfo const & info)
|
|||
os << ", m_numTokens:" << info.m_numTokens;
|
||||
os << ", m_matchedFraction:" << info.m_matchedFraction;
|
||||
os << ", m_type:" << DebugPrint(info.m_type);
|
||||
os << ", m_resultType:" << DebugPrint(info.m_resultType);
|
||||
os << ", m_pureCats:" << info.m_pureCats;
|
||||
os << ", m_falseCats:" << info.m_falseCats;
|
||||
os << ", m_allTokensUsed:" << info.m_allTokensUsed;
|
||||
|
@ -163,6 +214,7 @@ void RankingInfo::ToCSV(ostream & os) const
|
|||
os << GetErrorsMadePerToken() << ",";
|
||||
os << m_matchedFraction << ",";
|
||||
os << DebugPrint(m_type) << ",";
|
||||
os << DebugPrint(m_resultType) << ",";
|
||||
os << m_pureCats << ",";
|
||||
os << m_falseCats << ",";
|
||||
os << (m_allTokensUsed ? 1 : 0) << ",";
|
||||
|
@ -203,6 +255,8 @@ double RankingInfo::GetLinearModelRank() const
|
|||
result += kRating * rating;
|
||||
result += m_falseCats * kFalseCats;
|
||||
result += kType[m_type];
|
||||
if (m_type == Model::TYPE_POI)
|
||||
result += kResultType[base::Underlying(m_resultType)];
|
||||
result += kNameScore[nameScore];
|
||||
result += kErrorsMade * GetErrorsMadePerToken();
|
||||
result += kMatchedFraction * m_matchedFraction;
|
||||
|
@ -236,4 +290,51 @@ double RankingInfo::GetErrorsMadePerToken() const
|
|||
CHECK_GREATER(m_numTokens, 0, ());
|
||||
return static_cast<double>(m_errorsMade.m_errorsMade) / static_cast<double>(m_numTokens);
|
||||
}
|
||||
|
||||
ResultType GetResultType(feature::TypesHolder const & th)
|
||||
{
|
||||
if (ftypes::IsEatChecker::Instance()(th))
|
||||
return ResultType::Eat;
|
||||
if (ftypes::IsHotelChecker::Instance()(th))
|
||||
return ResultType::Hotel;
|
||||
if (ftypes::IsRailwayStationChecker::Instance()(th) ||
|
||||
ftypes::IsSubwayStationChecker::Instance()(th) || ftypes::IsAirportChecker::Instance()(th))
|
||||
{
|
||||
return ResultType::TransportMajor;
|
||||
}
|
||||
if (ftypes::IsPublicTransportStopChecker::Instance()(th))
|
||||
return ResultType::TransportLocal;
|
||||
|
||||
// We have several lists for attractions: short list in search categories for @tourism and long
|
||||
// list in ftypes::AttractionsChecker. We have highway-pedestrian, place-square, historic-tomb,
|
||||
// landuse-cemetery, amenity-townhall etc in long list and logic of long list is "if this object
|
||||
// has high popularity and/or wiki description probably it is attraction". It's better to use
|
||||
// short list here.
|
||||
auto static const attractionTypes =
|
||||
search::GetCategoryTypes("sights", "en", GetDefaultCategories());
|
||||
if (base::AnyOf(attractionTypes, [&th](auto t) { return th.Has(t); }))
|
||||
return ResultType::Attraction;
|
||||
|
||||
static const IsServiceTypeChecker isServiceTypeChecker;
|
||||
if (isServiceTypeChecker(th))
|
||||
return ResultType::Service;
|
||||
|
||||
return ResultType::General;
|
||||
}
|
||||
|
||||
string DebugPrint(ResultType type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case ResultType::TransportMajor: return "TransportMajor";
|
||||
case ResultType::TransportLocal: return "TransportLocal";
|
||||
case ResultType::Eat: return "Eat";
|
||||
case ResultType::Hotel: return "Hotel";
|
||||
case ResultType::Attraction: return "Attraction";
|
||||
case ResultType::Service: return "Service";
|
||||
case ResultType::General: return "General";
|
||||
case ResultType::Count: return "Count";
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
} // namespace search
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
#include "search/pre_ranking_info.hpp"
|
||||
#include "search/ranking_utils.hpp"
|
||||
|
||||
#include "indexer/feature_data.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
@ -15,6 +17,25 @@ class FeatureType;
|
|||
|
||||
namespace search
|
||||
{
|
||||
enum class ResultType : uint8_t
|
||||
{
|
||||
// Railway/subway stations, airports
|
||||
TransportMajor,
|
||||
// Bus/tram stops
|
||||
TransportLocal,
|
||||
// Cafes, restaurants, bars
|
||||
Eat,
|
||||
// Hotels
|
||||
Hotel,
|
||||
// Attractions
|
||||
Attraction,
|
||||
// Service types: power lines and substations, barrier-fence, etc.
|
||||
Service,
|
||||
// All other POIs
|
||||
General,
|
||||
Count
|
||||
};
|
||||
|
||||
struct RankingInfo
|
||||
{
|
||||
static double const kMaxDistMeters;
|
||||
|
@ -72,6 +93,9 @@ struct RankingInfo
|
|||
// Search type for the feature.
|
||||
Model::Type m_type = Model::TYPE_COUNT;
|
||||
|
||||
// Type (food/transport/attraction/etc) for POI results for non-categorial requests.
|
||||
ResultType m_resultType = ResultType::Count;
|
||||
|
||||
// True if all of the tokens that the feature was matched by
|
||||
// correspond to this feature's categories.
|
||||
bool m_pureCats = false;
|
||||
|
@ -88,5 +112,8 @@ struct RankingInfo
|
|||
bool m_hasName = false;
|
||||
};
|
||||
|
||||
ResultType GetResultType(feature::TypesHolder const & th);
|
||||
|
||||
std::string DebugPrint(RankingInfo const & info);
|
||||
std::string DebugPrint(ResultType type);
|
||||
} // namespace search
|
||||
|
|
|
@ -20,8 +20,9 @@ MAX_POPULARITY = 255.0
|
|||
RELEVANCES = {'Harmful': -3, 'Irrelevant': 0, 'Relevant': 1, 'Vital': 3}
|
||||
NAME_SCORES = ['Zero', 'Substring', 'Prefix', 'Full Match']
|
||||
SEARCH_TYPES = ['POI', 'Building', 'Street', 'Unclassified', 'Village', 'City', 'State', 'Country']
|
||||
RESULT_TYPES = ['TransportMajor', 'TransportLocal', 'Eat', 'Hotel', 'Attraction', 'Service', 'General']
|
||||
FEATURES = ['DistanceToPivot', 'Rank', 'Popularity', 'Rating', 'FalseCats', 'ErrorsMade', 'MatchedFraction',
|
||||
'AllTokensUsed', 'ExactCountryOrCapital'] + NAME_SCORES + SEARCH_TYPES
|
||||
'AllTokensUsed', 'ExactCountryOrCapital'] + NAME_SCORES + SEARCH_TYPES + RESULT_TYPES
|
||||
|
||||
BOOTSTRAP_ITERATIONS = 10000
|
||||
|
||||
|
@ -62,6 +63,10 @@ def normalize_data(data):
|
|||
for st in SEARCH_TYPES:
|
||||
data[st] = data['SearchType'].apply(lambda v: int(st == v))
|
||||
|
||||
# Adds dummy variables to data for RESULT_TYPES.
|
||||
for rt in RESULT_TYPES:
|
||||
data[rt] = data['ResultType'].apply(lambda v: int(rt == v))
|
||||
|
||||
|
||||
def compute_ndcg(relevances):
|
||||
"""
|
||||
|
@ -215,17 +220,20 @@ def cpp_output(features, ws):
|
|||
Prints feature-coeff pairs in the C++-compatible format.
|
||||
"""
|
||||
|
||||
ns, st = [], []
|
||||
ns, st, rt = [], [], []
|
||||
|
||||
for f, w in zip(features, ws):
|
||||
if f in NAME_SCORES:
|
||||
ns.append((f, w))
|
||||
elif f in SEARCH_TYPES:
|
||||
st.append((f, w))
|
||||
elif f in RESULT_TYPES:
|
||||
rt.append((f, w))
|
||||
else:
|
||||
print_const(f, w)
|
||||
print_array('kNameScore', 'NameScore::NAME_SCORE_COUNT', ns)
|
||||
print_array('kType', 'Model::TYPE_COUNT', st)
|
||||
print_array('kResultType', 'base::Underlying(ResultType::Count)', rt)
|
||||
|
||||
|
||||
def show_bootstrap_statistics(clf, X, y, features):
|
||||
|
|
Loading…
Add table
Reference in a new issue