diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 2c5434e..b735261 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8==6.0.0 black==23.1.0 + pip install flake8==6.0.0 black==23.1.0 shapely==2.0.1 if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/make_all_metro_poly.py b/make_all_metro_poly.py index 00281a7..e8450a2 100644 --- a/make_all_metro_poly.py +++ b/make_all_metro_poly.py @@ -1,18 +1,21 @@ import argparse -import shapely.geometry -import shapely.ops +from shapely import unary_union +from shapely.geometry import MultiPolygon, Polygon from process_subways import DEFAULT_CITIES_INFO_URL, get_cities_info def make_disjoint_metro_polygons(cities_info_url: str) -> None: + """Make disjoint polygon from cities bboxes and write them + in *.poly format to stdout. + """ cities_info = get_cities_info(cities_info_url) polygons = [] for ci in cities_info: bbox = tuple(map(float, ci["bbox"].split(","))) - polygon = shapely.geometry.Polygon( + polygon = Polygon( [ (bbox[0], bbox[1]), (bbox[0], bbox[3]), @@ -22,14 +25,17 @@ def make_disjoint_metro_polygons(cities_info_url: str) -> None: ) polygons.append(polygon) - union = shapely.ops.unary_union(polygons) + union = unary_union(polygons) + + if union.geom_type == "Polygon": + union = MultiPolygon([union]) print("all metro") - for i, polygon in enumerate(union, start=1): + for i, polygon in enumerate(union.geoms, start=1): assert len(polygon.interiors) == 0 print(i) - for point in polygon.exterior.coords: - print(" {lon} {lat}".format(lon=point[0], lat=point[1])) + for lon, lat in polygon.exterior.coords: + print(f" {lon} {lat}") print("END") print("END") diff --git a/scripts/process_subways.sh b/scripts/process_subways.sh index 42c4af6..a27f283 100755 --- a/scripts/process_subways.sh +++ b/scripts/process_subways.sh @@ -91,7 +91,7 @@ function check_poly() { if [ -z "${POLY-}" -o ! -f "${POLY-}" ]; then POLY=${POLY:-$(mktemp "$TMPDIR/all-metro.XXXXXXXX.poly")} if [ -n "$("$PYTHON" -c "import shapely" 2>&1)" ]; then - "$PYTHON" -m pip install shapely==1.7.1 + "$PYTHON" -m pip install shapely==2.0.1 fi "$PYTHON" "$SUBWAYS_PATH"/make_all_metro_poly.py \ ${CITIES_INFO_URL:+--cities-info-url "$CITIES_INFO_URL"} > "$POLY" diff --git a/tests/assets/cities_info_1city.csv b/tests/assets/cities_info_1city.csv new file mode 100644 index 0000000..c2b5b95 --- /dev/null +++ b/tests/assets/cities_info_1city.csv @@ -0,0 +1,2 @@ +#,City,Country,Region,Stations,Subway Lines,Light Rail +Monorail,Interchanges,"BBox (lon, lat)",Networks (opt.),Approved,Comment,Source +291,Moscow,Russia,Europe,351,14,3,68,"37.1667,55.3869,38.2626,56.0136","subway,train:Московский метрополитен;МЦК;МЦД" diff --git a/tests/assets/cities_info_2cities.csv b/tests/assets/cities_info_2cities.csv new file mode 100644 index 0000000..efd2c7d --- /dev/null +++ b/tests/assets/cities_info_2cities.csv @@ -0,0 +1,3 @@ +#,City,Country,Region,Stations,Subway Lines,Light Rail +Monorail,Interchanges,"BBox (lon, lat)",Networks (opt.),Approved,Comment,Source +313,London,UK,Europe,750,11,23,54,"-0.9747,51.1186,0.3315,51.8459","subway,train,light_rail:London Underground;London Overground;Docklands Light Railway;London Trams;Crossrail",,,https://tfl.gov.uk/maps/track/tube +291,Moscow,Russia,Europe,351,14,3,68,"37.1667,55.3869,38.2626,56.0136","subway,train:Московский метрополитен;МЦК;МЦД" diff --git a/tests/assets/networks_with_bad_values.csv b/tests/assets/cities_info_with_bad_values.csv similarity index 100% rename from tests/assets/networks_with_bad_values.csv rename to tests/assets/cities_info_with_bad_values.csv diff --git a/tests/test_make_all_metro_poly.py b/tests/test_make_all_metro_poly.py new file mode 100644 index 0000000..dac8dae --- /dev/null +++ b/tests/test_make_all_metro_poly.py @@ -0,0 +1,108 @@ +import contextlib +import io +import os +from unittest import TestCase + +from make_all_metro_poly import make_disjoint_metro_polygons + + +cases = [ + { + "csv_file": "cities_info_1city.csv", + "expected_stdout": """all metro +1 + 37.1667 55.3869 + 37.1667 56.0136 + 38.2626 56.0136 + 38.2626 55.3869 + 37.1667 55.3869 +END +END +""", + "shape_line_ranges": [ + { + "start": 2, + "end": 6, + }, + ], + }, + { + "csv_file": "cities_info_2cities.csv", + "expected_stdout": """all metro +1 + -0.9747 51.8459 + 0.3315 51.8459 + 0.3315 51.1186 + -0.9747 51.1186 + -0.9747 51.8459 +END +2 + 37.1667 56.0136 + 38.2626 56.0136 + 38.2626 55.3869 + 37.1667 55.3869 + 37.1667 56.0136 +END +END +""", + "shape_line_ranges": [ + { + "start": 2, + "end": 6, + }, + { + "start": 9, + "end": 13, + }, + ], + }, +] + + +class TestMakeAllMetroPoly(TestCase): + def test_make_disjoint_metro_polygons(self) -> None: + for case in cases: + with self.subTest(msg=case["csv_file"]): + file_url = ( + f"file://{os.getcwd()}/tests/assets/{case['csv_file']}" + ) + stream = io.StringIO() + with contextlib.redirect_stdout(stream): + make_disjoint_metro_polygons(file_url) + generated_poly = stream.getvalue() + expected_poly = case["expected_stdout"] + + # Since shapely may produce multipolygon with different order + # of polygons in it and different vertex order in a polygon, + # we should compare polygons/vertexes as sets. + + generated_poly_lines = generated_poly.split("\n") + expected_poly_lines = expected_poly.split("\n") + self.assertSetEqual( + set(expected_poly_lines), set(generated_poly_lines) + ) + + line_ranges = case["shape_line_ranges"] + + # Check that polygons are closed + for line_range in line_ranges: + self.assertEqual( + generated_poly_lines[line_range["start"]], + generated_poly_lines[line_range["end"]], + ) + + generated_points = [ + sorted( + generated_poly_lines[r["start"] : r["end"]] # noqa 203 + ) + for r in line_ranges + ] + expected_points = [ + sorted( + expected_poly_lines[r["start"] : r["end"]] # noqa 203 + ) + for r in line_ranges + ] + expected_points.sort() + generated_points.sort() + self.assertListEqual(expected_points, generated_points) diff --git a/tests/test_prepare_cities.py b/tests/test_prepare_cities.py index e74505f..63ddce6 100644 --- a/tests/test_prepare_cities.py +++ b/tests/test_prepare_cities.py @@ -10,7 +10,7 @@ class TestPrepareCities(TestCase): csv_path = ( Path(inspect.getfile(self.__class__)).parent / "assets" - / "networks_with_bad_values.csv" + / "cities_info_with_bad_values.csv" ) cities = prepare_cities(cities_info_url=f"file://{csv_path}")