[search][tools] Review fixes for booking test dataset download script.

This commit is contained in:
tatiana-yan 2019-08-28 11:48:09 +03:00 committed by mpimenov
parent afce4d7ab0
commit 0eba1a010e
2 changed files with 27 additions and 12 deletions

View file

@ -110,7 +110,10 @@ map<string, string> ParseAddressDataset(string const & filename)
map<string, string> result;
ifstream data(filename);
for (string line; getline(data, line);)
string line;
// Skip header.
getline(data, line);
while (getline(data, line);)
{
vector<string> fields;
strings::ParseCSVRow(line, '\t', fields);

View file

@ -12,11 +12,11 @@ SUPPORTED_LANGUAGES = ("en", "ru")
class BookingGen:
def __init__(self, api, country, districtNames):
def __init__(self, api, country, district_names):
self.api = api
self.country_code = country["country"]
self.country_name = country["name"]
self.districtNames = districtNames
self.district_names = district_names
logging.info(f"Download[{self.country_code}]: {self.country_name}")
extras = ["hotel_info"]
@ -39,8 +39,8 @@ class BookingGen:
hotel_data = hotel["hotel_data"]
location = hotel_data["location"]
district = "None"
if hotel_data["district_id"] in self.districtNames:
district = self.districtNames[hotel_data["district_id"]]
if hotel_data["district_id"] in self.district_names:
district = self.district_names[hotel_data["district_id"]]
row = (
hotel["hotel_id"],
hotel_data["address"],
@ -51,9 +51,20 @@ class BookingGen:
)
return sep.join(BookingGen._format_string(str(x)) for x in row)
def create_tsv_header(sep="\t"):
    """Return the header line of the hotels dataset.

    The column titles are joined with ``sep`` (a tab by default). No trailing
    newline is appended; the caller adds it when writing the file.
    """
    columns = (
        "Hotel ID",
        "Address",
        "ZIP",
        "City",
        "District",
        "Country",
    )
    # The titles are already strings, so the tuple can be joined directly
    # without wrapping it in a generator expression.
    return sep.join(columns)
def download_hotels_by_country(api, districtNames, country):
generator = BookingGen(api, country, districtNames)
def download_hotels_by_country(api, district_names, country):
    """Collect the TSV rows for a single country and log how many there are.

    A ``BookingGen`` is created for ``country`` and the result of its
    ``generate_tsv_rows()`` call is materialized into a list, which is
    returned to the caller.
    """
    tsv_rows = list(BookingGen(api, country, district_names).generate_tsv_rows())
    logging.info(f"For {country['name']} {len(tsv_rows)} lines were generated.")
    return tsv_rows
@ -65,11 +76,11 @@ def download_test_data(country_code, user, password, path, threads_count,
api = BookingApi(user, password, "2.4")
list_api = BookingListApi(api)
districts = list_api.districts(languages="en")
districtNames = {}
district_names = {}
for district in districts:
for translation in district['translations']:
if translation['language'] == 'en':
districtNames[district['district_id']] = translation['name']
for translation in district['translations']:
if translation['language'] == 'en':
district_names[district['district_id']] = translation['name']
countries = list_api.countries(languages="en")
if country_code is not None:
countries = list(filter(lambda x: x["country"] in country_code, countries))
@ -77,8 +88,9 @@ def download_test_data(country_code, user, password, path, threads_count,
progress_bar.desc = "Countries"
progress_bar.total = len(countries)
with open(path, "w") as f:
f.write(create_tsv_header() + "\n")
with ThreadPool(threads_count) as pool:
for lines in pool.imap_unordered(partial(download_hotels_by_country, list_api, districtNames),
for lines in pool.imap_unordered(partial(download_hotels_by_country, list_api, district_names),
countries):
f.writelines([f"{x}\n" for x in lines])
progress_bar.update()