diff --git a/tools/python/booking/__main__.py b/tools/python/booking/__main__.py
index eb62827f14..8537fb90dc 100644
--- a/tools/python/booking/__main__.py
+++ b/tools/python/booking/__main__.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
 
 from .api.booking_api import LIMIT_REQUESTS_PER_MINUTE
 from .download_hotels import download
+from .download_test_data import download_test_data
 
 
 def process_options():
@@ -25,6 +26,8 @@ def process_options():
                         help="Name and destination for output file")
     parser.add_argument("--country_code", default=None, action="append",
                         help="Download hotels of this country.")
+    parser.add_argument("--download_test_dataset", default=False,
+                        help="Download dataset for tests.")
 
     options = parser.parse_args()
     return options
@@ -44,8 +47,12 @@ def main():
     logging.basicConfig(level=logging.DEBUG, filename=logfile,
                         format="%(thread)d [%(asctime)s] %(levelname)s: %(message)s")
     with tqdm(disable=not options.verbose) as progress_bar:
-        download(options.country_code, options.user, options.password,
-                 options.output, options.threads_count, progress_bar)
+        if options.download_test_dataset:
+            download_test_data(options.country_code, options.user, options.password,
+                               options.output, options.threads_count, progress_bar)
+        else:
+            download(options.country_code, options.user, options.password,
+                     options.output, options.threads_count, progress_bar)
 
 
 main()
diff --git a/tools/python/booking/api/booking_api.py b/tools/python/booking/api/booking_api.py
index 4a854ac657..fa990e7aa6 100644
--- a/tools/python/booking/api/booking_api.py
+++ b/tools/python/booking/api/booking_api.py
@@ -17,7 +17,8 @@ MINMAX_LIMIT_WAIT_AFTER_429_ERROR_SECONDS = (30, 120)
 class BookingApi:
     ENDPOINTS = {
         "countries": "list",
-        "hotels": "list"
+        "hotels": "list",
+        "districts": "list"
     }
 
     def __init__(self, login, password, version):
@@ -116,4 +117,4 @@ class BookingListApi:
     def _set_endpoints(self):
         for endpoint in BookingApi.ENDPOINTS:
             if BookingApi.ENDPOINTS[endpoint] == "list":
-                setattr(self, endpoint, partial(self.call_endpoint, endpoint))
\ No newline at end of file
+                setattr(self, endpoint, partial(self.call_endpoint, endpoint))
diff --git a/tools/python/booking/download_test_data.py b/tools/python/booking/download_test_data.py
new file mode 100755
index 0000000000..393a7305ac
--- /dev/null
+++ b/tools/python/booking/download_test_data.py
@@ -0,0 +1,112 @@
+import logging
+import statistics
+from functools import partial
+from multiprocessing.pool import ThreadPool
+
+import math
+from tqdm import tqdm
+
+from .api.booking_api import BookingApi, BookingListApi
+
+SUPPORTED_LANGUAGES = ("en", "ru")
+
+
+class BookingGen:
+    """Downloads the hotels of one country and formats them as TSV lines."""
+
+    def __init__(self, api, country, districtNames):
+        """Store the inputs and download the hotels of ``country`` right away.
+
+        :param api: client whose ``hotels`` method is called to fetch the hotels.
+        :param country: mapping providing the "country" and "name" keys.
+        :param districtNames: mapping from district id to district name.
+        """
+        self.api = api
+        self.country_code = country["country"]
+        self.country_name = country["name"]
+        self.districtNames = districtNames
+        logging.info(f"Download[{self.country_code}]: {self.country_name}")
+
+        extras = ["hotel_info"]
+        self.hotels = self._download_hotels(extras=extras)
+
+    def generate_tsv_rows(self, sep="\t"):
+        """Return a generator yielding one ``sep``-joined line per hotel."""
+        return (self._create_tsv_hotel_line(hotel, sep) for hotel in self.hotels)
+
+    @staticmethod
+    def _format_string(s):
+        """Strip ``s``; turn tabs and newlines into spaces, drop carriage returns."""
+        s = s.strip()
+        for x in (("\t", " "), ("\n", " "), ("\r", "")):
+            s = s.replace(*x)
+        return s
+
+    def _download_hotels(self, **params):
+        """Call ``api.hotels`` for this country, forwarding ``params``."""
+        return self.api.hotels(country_ids=self.country_code, **params)
+
+    def _create_tsv_hotel_line(self, hotel, sep="\t"):
+        """Join hotel id, address, zip, city, district and country name with ``sep``."""
+        hotel_data = hotel["hotel_data"]
+        # A district id without a known name is written as the literal "None".
+        district = self.districtNames.get(hotel_data["district_id"], "None")
+        row = (
+            hotel["hotel_id"],
+            hotel_data["address"],
+            hotel_data["zip"],
+            hotel_data["city"],
+            district,
+            self.country_name,
+        )
+        return sep.join(BookingGen._format_string(str(x)) for x in row)
+
+
+def download_hotels_by_country(api, districtNames, country):
+    """Download the hotels of ``country`` and return their TSV lines as a list."""
+    generator = BookingGen(api, country, districtNames)
+    rows = list(generator.generate_tsv_rows())
+    logging.info(f"For {country['name']} {len(rows)} lines were generated.")
+    return rows
+
+
+def download_test_data(country_code, user, password, path, threads_count,
+                       progress_bar=None):
+    """Download the hotels test dataset and write it to ``path`` as TSV lines.
+
+    :param country_code: country codes to keep, or None to keep every country.
+    :param user: login passed to ``BookingApi``.
+    :param password: password passed to ``BookingApi``.
+    :param path: file the lines are written to.
+    :param threads_count: size of the thread pool; each task handles one country.
+    :param progress_bar: bar advanced once per finished country; a disabled
+        tqdm bar is created when omitted.
+    """
+    if progress_bar is None:
+        # Built per call: a tqdm instance in the signature would be created once
+        # at import time and shared by all calls.
+        progress_bar = tqdm(disable=True)
+    logging.info("Starting test dataset download.")
+    api = BookingApi(user, password, "2.4")
+    list_api = BookingListApi(api)
+    districts = list_api.districts(languages="en")
+    district_names = {}
+    for district in districts:
+        for translation in district["translations"]:
+            if translation["language"] == "en":
+                district_names[district["district_id"]] = translation["name"]
+    countries = list_api.countries(languages="en")
+    if country_code is not None:
+        countries = [c for c in countries if c["country"] in country_code]
+    logging.info(f"There are {len(countries)} countries.")
+    progress_bar.desc = "Countries"
+    progress_bar.total = len(countries)
+    worker = partial(download_hotels_by_country, list_api, district_names)
+    # UTF-8 is set explicitly so the output does not depend on the platform's
+    # default text encoding.
+    with open(path, "w", encoding="utf-8") as f:
+        with ThreadPool(threads_count) as pool:
+            for lines in pool.imap_unordered(worker, countries):
+                f.writelines(f"{x}\n" for x in lines)
+                progress_bar.update()
+    logging.info(f"Hotels test dataset saved to {path}.")