Code review.

This commit is contained in:
Sergey Magidovich 2016-07-12 16:21:46 +03:00
parent 88d5775c23
commit 61babee342
7 changed files with 96 additions and 46 deletions

View file

@ -209,7 +209,8 @@ void BookingDataset::BuildFeatures(function<void(OsmElement *)> const & fn) cons
if (!hotel.houseNumber.empty())
e.AddTag("addr:housenumber", hotel.houseNumber);
// TODO(mgsergio): Add a comment or use enum.
// Matching booking.com hotel types to OpenStreetMap values.
// Booking types are listed in the closed API docs.
switch (hotel.type)
{
case 19:

View file

@ -17,7 +17,7 @@ DEFINE_string(osm_file_name, "", "Input .o5m file");
DEFINE_string(booking_data, "", "Path to booking data in .tsv format");
DEFINE_string(sample_data, "", "Sample output path");
DEFINE_uint64(selection_size, 1000, "Selection size");
DEFINE_uint64(random_seed, minstd_rand::default_seed, "Seed for random shuffle");
DEFINE_uint64(seed, minstd_rand::default_seed, "Seed for random shuffle");
using namespace generator;
@ -60,7 +60,7 @@ int main(int argc, char * argv[])
vector<size_t> elementIndexes(elements.size());
iota(elementIndexes.begin(), elementIndexes.end(), 0);
shuffle(elementIndexes.begin(), elementIndexes.end(), minstd_rand(FLAGS_random_seed));
shuffle(elementIndexes.begin(), elementIndexes.end(), minstd_rand(FLAGS_seed));
if (FLAGS_selection_size < elementIndexes.size())
elementIndexes.resize(FLAGS_selection_size);
@ -82,7 +82,11 @@ int main(int argc, char * argv[])
outStream << "# ------------------------------------------" << fixed << setprecision(6)
<< endl;
outStream << (matched ? 'y' : 'n') << " \t" << i << "\t " << j
<< " distance: " << distanceMeters << " score: " << score.GetMatchingScore() << endl;
<< "\tdistance: " << distanceMeters
<< "\tdistance score: " << score.m_linearNormDistanceScore
<< "\tname score: " << score.m_nameSimilarityScore
<< "\tresult score: " << score.GetMatchingScore()
<< endl;
outStream << "# " << e << endl;
outStream << "# " << hotel << endl;
outStream << "# URL: https://www.openstreetmap.org/?mlat=" << hotel.lat

View file

@ -2,12 +2,16 @@
#include "generator/booking_dataset.hpp"
#include "indexer/search_string_utils.hpp"
#include "indexer/search_delimiters.hpp"
#include "indexer/search_string_utils.hpp"
#include "geometry/distance_on_sphere.hpp"
#include "base/collection_cast.hpp"
#include "base/stl_iterator.hpp"
#include "std/algorithm.hpp"
#include "std/vector.hpp"
namespace generator
{
@ -16,49 +20,82 @@ namespace booking_scoring
namespace
{
// Calculated with tools/python/booking_hotels_quality.py.
double constexpr kOptimalThreshold = 0.151001;
double constexpr kOptimalThreshold = 0.317324;
template <typename T, typename U>
struct decay_equiv :
std::is_same<typename std::decay<T>::type, U>::type
{};
set<strings::UniString> StringToSetOfWords(string const & str)
using WeightedBagOfWords = vector<pair<strings::UniString, double>>;
vector<strings::UniString> StringToSetOfWords(string const & str)
{
vector<strings::UniString> result;
search::NormalizeAndTokenizeString(str, result, search::Delimiters{});
return my::collection_cast<set>(result);
sort(begin(result), end(result));
return result;
}
// TODO(mgsergio): Update existing one in base or wherever...
// Or just use one from boost.
struct CounterIterator
WeightedBagOfWords MakeWeightedBagOfWords(vector<strings::UniString> const & words)
{
template<typename T, typename = typename enable_if<!decay_equiv<T, CounterIterator>::value>::type>
CounterIterator & operator=(T const &) { ++m_count; return *this; }
CounterIterator & operator++() { return *this; }
CounterIterator & operator++(int) { return *this; }
CounterIterator & operator*() { return *this; }
uint32_t Count() const { return m_count; }
// TODO(mgsergio): Calculate tf-idf score for every word.
auto constexpr kTfIdfScorePlaceholder = 1;
uint32_t m_count = 0;
};
double StringSimilarityScore(string const & a, string const & b)
{
auto const aWords = StringToSetOfWords(a);
auto const bWords = StringToSetOfWords(b);
auto const intersectionCard = set_intersection(begin(aWords), end(aWords),
begin(bWords), end(bWords),
CounterIterator()).Count();
auto const aLikeBScore = static_cast<double>(intersectionCard) / aWords.size();
auto const bLikeAScore = static_cast<double>(intersectionCard) / bWords.size();
return aLikeBScore * bLikeAScore;
WeightedBagOfWords result;
for (auto i = 0; i < words.size(); ++i)
{
result.emplace_back(words[i], kTfIdfScorePlaceholder);
while (i + 1 < words.size() && words[i] == words[i + 1])
{
result.back().second += kTfIdfScorePlaceholder; // TODO(mgsergio): tf-idf score for result.back().first;
++i;
}
}
return result;
}
double GetLinearNormDistanceScrore(double distance)
// Computes the dot product of two weighted bags of words.
// Both bags are walked in lockstep, the way a merge of two sorted ranges is:
// a word present in both bags contributes the product of its two weights,
// a word present in only one bag contributes nothing.
// NOTE: the traversal assumes each bag is ordered by word (ascending) and
// holds every word at most once; neither property is checked here.
double WeightedBagsDotProduct(WeightedBagOfWords const & lhs, WeightedBagOfWords const & rhs)
{
  double dotProduct = 0.0;

  auto const lhsEnd = end(lhs);
  auto const rhsEnd = end(rhs);
  for (auto l = begin(lhs), r = begin(rhs); l != lhsEnd && r != rhsEnd;)
  {
    if (l->first == r->first)
    {
      // Shared word: accumulate and move past it in both bags.
      dotProduct += l->second * r->second;
      ++l;
      ++r;
      continue;
    }

    // Words differ: advance whichever side holds the smaller one.
    if (l->first < r->first)
      ++l;
    else
      ++r;
  }

  return dotProduct;
}
// Returns the cosine between two weighted bags of words treated as vectors:
//   dot(lhs, rhs) / (|lhs| * |rhs|),
// where |x| is sqrt(dot(x, x)).
// A zero dot product yields 0.0 right away, so the division is never reached
// in that case (e.g. when one of the bags is empty).
double WeightedBagOfWordsCos(WeightedBagOfWords const & lhs, WeightedBagOfWords const & rhs)
{
  auto const dotProduct = WeightedBagsDotProduct(lhs, rhs);
  if (dotProduct == 0.0)
    return 0.0;

  auto const lhsNorm = sqrt(WeightedBagsDotProduct(lhs, lhs));
  auto const rhsNorm = sqrt(WeightedBagsDotProduct(rhs, rhs));
  return dotProduct / (lhsNorm * rhsNorm);
}
double GetLinearNormDistanceScore(double distance)
{
distance = my::clamp(distance, 0, BookingDataset::kDistanceLimitInMeters);
return 1.0 - distance / BookingDataset::kDistanceLimitInMeters;
@ -66,7 +103,15 @@ double GetLinearNormDistanceScrore(double distance)
double GetNameSimilarityScore(string const & booking_name, string const & osm_name)
{
return StringSimilarityScore(booking_name, osm_name);
auto const aws = MakeWeightedBagOfWords(StringToSetOfWords(booking_name));
auto const bws = MakeWeightedBagOfWords(StringToSetOfWords(osm_name));
if (aws.empty() && bws.empty())
return 1.0;
if (aws.empty() || bws.empty())
return 0.0;
return WeightedBagOfWordsCos(aws, bws);
}
} // namespace
@ -85,11 +130,10 @@ BookingMatchScore Match(BookingDataset::Hotel const & h, OsmElement const & e)
BookingMatchScore score;
auto const distance = ms::DistanceOnEarth(e.lat, e.lon, h.lat, h.lon);
score.m_linearNormDistanceScore = GetLinearNormDistanceScrore(distance);
score.m_linearNormDistanceScore = GetLinearNormDistanceScore(distance);
string osmHotelName;
score.m_nameSimilarityScore = e.GetTag("name", osmHotelName)
? GetNameSimilarityScore(h.name, osmHotelName) : 0;
// TODO(mgsergio): Check all translations and use the best one.
score.m_nameSimilarityScore = GetNameSimilarityScore(h.name, e.GetTag("name"));
return score;
}

View file

@ -1,7 +1,8 @@
#pragma once
#include "generator/booking_dataset.hpp"
#include "generator/osm_element.hpp"
struct OsmElement;
namespace generator
{

View file

@ -121,7 +121,7 @@ string OsmElement::ToString(string const & shift) const
return ss.str();
}
bool OsmElement::GetTag(string const & key, string & value) const
string OsmElement::GetTag(string const & key) const
{
auto const it = find_if(begin(m_tags), end(m_tags), [&key](Tag const & tag)
{
@ -129,10 +129,9 @@ bool OsmElement::GetTag(string const & key, string & value) const
});
if (it == end(m_tags))
return false;
return {};
value = it->value;
return true;
return it->value;
}
string DebugPrint(OsmElement const & e)

View file

@ -153,7 +153,7 @@ struct OsmElement
AddTag(k, v);
}
bool GetTag(string const & key, string & value) const;
string GetTag(string const & key) const;
};
string DebugPrint(OsmElement const & e);

View file

@ -14,6 +14,7 @@ import os
import pickle
import time
import urllib2
import re
# init logging
logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(message)s')
@ -21,7 +22,7 @@ logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(
def load_binary_list(path):
"""
Loads referance binary classifier output.
Loads reference binary classifier output.
"""
bits = []
with open(path, 'r') as fd:
@ -41,7 +42,7 @@ def load_score_list(path):
for line in fd:
if (not line.strip()) or line[0] == '#':
continue
scores.append(float(line[line.rfind(':')+2:]))
scores.append(float(re.search(r'result score: (\d*\.\d+)', line).group(1)))
return scores