From 36d947047e226c19c17343630721a8f95f495707 Mon Sep 17 00:00:00 2001 From: Alexey Zakharenkov Date: Tue, 18 Oct 2022 17:42:12 +0300 Subject: [PATCH] Defer CSV-complying data transformation to write-to-CSV phase --- .github/workflows/python-app.yml | 4 +- README.md | 2 +- processors/gtfs.py | 132 ++++++++++++++----------------- tests/test_gtfs_processor.py | 44 ++++++++++- 4 files changed, 102 insertions(+), 80 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 343fd4a..fa8b992 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -19,10 +19,10 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.6 + - name: Set up Python 3.8 uses: actions/setup-python@v3 with: - python-version: "3.6" + python-version: "3.8" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 9149825..ae5f2d8 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ a city's bbox has been extended. A single city or a country with few metro networks can be validated much faster if you allow the `process_subway.py` to fetch data from Overpass API. Here are the steps: -1. Python3 interpreter required (3.6+) +1. Python3 interpreter required (3.8+) 2. Clone the repo ``` git clone https://github.com/alexey-zakharenkov/subways.git subways_validator diff --git a/processors/gtfs.py b/processors/gtfs.py index 1ae5ee7..9094651 100644 --- a/processors/gtfs.py +++ b/processors/gtfs.py @@ -2,6 +2,7 @@ import csv import io import zipfile +from functools import partial from itertools import permutations from ._common import ( @@ -129,19 +130,6 @@ GTFS_COLUMNS = { } -def dict_to_row(dict_data, record_type): - """Given object stored in a dict and an array of columns, - returns a row to use in CSV. - """ - row = [] - for column in GTFS_COLUMNS[record_type]: - value = dict_data.get(column) - if value is None: - value = "" - row.append(value) - return row - - def round_coords(coords_tuple): return tuple( map(lambda coord: round(coord, COORDINATE_PRECISION), coords_tuple) @@ -162,21 +150,18 @@ def process(cities, transfers, filename, cache_path): gtfs_data = {key: [] for key in GTFS_COLUMNS.keys()} gtfs_data["calendar"].append( - dict_to_row( - { - "service_id": "always", - "monday": 1, - "tuesday": 1, - "wednesday": 1, - "thursday": 1, - "friday": 1, - "saturday": 1, - "sunday": 1, - "start_date": "19700101", - "end_date": "30000101", - }, - "calendar", - ) + { + "service_id": "always", + "monday": 1, + "tuesday": 1, + "wednesday": 1, + "thursday": 1, + "friday": 1, + "saturday": 1, + "sunday": 1, + "start_date": "19700101", + "end_date": "30000101", + } ) all_stops = {} # stop (stop area center or station) el_id -> stop data @@ -262,7 +247,7 @@ def process(cities, transfers, filename, cache_path): # agency, routes, trips, stop_times, frequencies, shapes for city in good_cities: agency = {"agency_id": city.id, "agency_name": city.name} - gtfs_data["agency"].append(dict_to_row(agency, "agency")) + gtfs_data["agency"].append(agency) for city_route in city: route = { @@ -273,7 +258,7 @@ def process(cities, transfers, filename, cache_path): "route_long_name": city_route.name, "route_color": format_colour(city_route.colour), } - gtfs_data["routes"].append(dict_to_row(route, "routes")) + gtfs_data["routes"].append(route) for variant in city_route: shape_id = variant.id[1:] # truncate leading 'r' @@ -283,7 +268,7 @@ def process(cities, transfers, filename, cache_path): "service_id": "always", "shape_id": shape_id, } - gtfs_data["trips"].append(dict_to_row(trip, "trips")) + gtfs_data["trips"].append(trip) tracks = variant.get_extended_tracks() tracks = variant.get_truncated_tracks(tracks) @@ -291,16 +276,13 @@ def process(cities, transfers, filename, cache_path): for i, (lon, lat) in enumerate(tracks): lon, lat = round_coords((lon, lat)) gtfs_data["shapes"].append( - dict_to_row( - { - "shape_id": shape_id, - "trip_id": variant.id, - "shape_pt_lat": lat, - "shape_pt_lon": lon, - "shape_pt_sequence": i, - }, - "shapes", - ) + { + "shape_id": shape_id, + "trip_id": variant.id, + "shape_pt_lat": lat, + "shape_pt_lon": lon, + "shape_pt_sequence": i, + } ) start_time = variant.start_time or DEFAULT_TRIP_START_TIME @@ -311,37 +293,29 @@ def process(cities, transfers, filename, cache_path): end_time = f"{end_time[0]:02d}:{end_time[1]:02d}:00" gtfs_data["frequencies"].append( - dict_to_row( - { - "trip_id": variant.id, - "start_time": start_time, - "end_time": end_time, - "headway_secs": variant.interval - or DEFAULT_INTERVAL, - }, - "frequencies", - ) + { + "trip_id": variant.id, + "start_time": start_time, + "end_time": end_time, + "headway_secs": variant.interval + or DEFAULT_INTERVAL, + } ) for stop_sequence, route_stop in enumerate(variant): gtfs_platform_id = add_stop_gtfs(route_stop, city) gtfs_data["stop_times"].append( - dict_to_row( - { - "trip_id": variant.id, - "stop_sequence": stop_sequence, - "shape_dist_traveled": route_stop.distance, - "stop_id": gtfs_platform_id, - }, - "stop_times", - ) + { + "trip_id": variant.id, + "stop_sequence": stop_sequence, + "shape_dist_traveled": route_stop.distance, + "stop_id": gtfs_platform_id, + } ) # stops - gtfs_data["stops"].extend( - map(lambda row: dict_to_row(row, "stops"), all_stops.values()) - ) + gtfs_data["stops"].extend(all_stops.values()) # transfers for stoparea_set in transfers: @@ -358,20 +332,27 @@ def process(cities, transfers, filename, cache_path): ) for id1, id2 in permutations((stop1_id, stop2_id)): gtfs_data["transfers"].append( - dict_to_row( - { - "from_stop_id": id1, - "to_stop_id": id2, - "transfer_type": 0, - "min_transfer_time": transfer_time, - }, - "transfers", - ) + { + "from_stop_id": id1, + "to_stop_id": id2, + "transfer_type": 0, + "min_transfer_time": transfer_time, + } ) make_gtfs(filename, gtfs_data) +def dict_to_row(dict_data: dict, record_type: str) -> list: + """Given object stored in a dict and an array of columns, + return a row to use in CSV. + """ + return [ + "" if (v := dict_data.get(column)) is None else v + for column in GTFS_COLUMNS[record_type] + ] + + def make_gtfs(filename, gtfs_data): if not filename.lower().endswith("zip"): filename = f"{filename}.zip" @@ -381,5 +362,10 @@ def make_gtfs(filename, gtfs_data): with io.StringIO(newline="") as string_io: writer = csv.writer(string_io, delimiter=",") writer.writerow(columns) - writer.writerows(gtfs_data[gtfs_feature]) + writer.writerows( + map( + partial(dict_to_row, record_type=gtfs_feature), + gtfs_data[gtfs_feature] + ) + ) zf.writestr(f"{gtfs_feature}.txt", string_io.getvalue()) diff --git a/tests/test_gtfs_processor.py b/tests/test_gtfs_processor.py index df0e2b1..ca206f9 100644 --- a/tests/test_gtfs_processor.py +++ b/tests/test_gtfs_processor.py @@ -1,4 +1,4 @@ -import unittest +from unittest import TestCase from processors.gtfs import ( dict_to_row, @@ -6,10 +6,10 @@ from processors.gtfs import ( ) -class TestGTFS(unittest.TestCase): +class TestGTFS(TestCase): """Test processors/gtfs.py""" - def test_dict_to_row(self): + def test__dict_to_row__Nones_and_absent_keys(self) -> None: """Test that absent or None values in a GTFS feature item are converted by dict_to_row() function to empty strings in right amount. @@ -55,6 +55,42 @@ class TestGTFS(unittest.TestCase): for test_trip in test_trips: with self.subTest(msg=test_trip["description"]): - self.assertEqual( + self.assertListEqual( dict_to_row(test_trip["trip_data"], "trips"), answer ) + + def test__dict_to_row__numeric_values(self) -> None: + """Test that zero numeric values remain zeros in dict_to_row() function, + and not empty strings or None. + """ + + shapes = [ + { + "description": "Numeric non-zeroes", + "shape_data": { + "shape_id": 1, + "shape_pt_lat": 55.3242425, + "shape_pt_lon": -179.23242, + "shape_pt_sequence": 133, + "shape_dist_traveled": 1.2345, + }, + "answer": [1, 55.3242425, -179.23242, 133, 1.2345], + }, + { + "description": "Numeric zeroes and None keys", + "shape_data": { + "shape_id": 0, + "shape_pt_lat": 0.0, + "shape_pt_lon": 0, + "shape_pt_sequence": 0, + "shape_dist_traveled": None, + }, + "answer": [0, 0.0, 0, 0, ""], + }, + ] + + for shape in shapes: + with self.subTest(shape["description"]): + self.assertListEqual( + dict_to_row(shape["shape_data"], "shapes"), shape["answer"] + )