diff --git a/subway_structure.py b/subway_structure.py
index 823aea6..bb38f85 100644
--- a/subway_structure.py
+++ b/subway_structure.py
@@ -673,7 +673,7 @@ class Route:
self.stops = [] # List of RouteStop
# Would be a list of (lon, lat) for the longest stretch. Can be empty.
self.tracks = None
- # Index of the fist stop that is located on/near the self.tracks
+ # Index of the first stop that is located on/near the self.tracks
self.first_stop_on_rails_index = None
# Index of the last stop that is located on/near the self.tracks
self.last_stop_on_rails_index = None
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..d6da466
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,13 @@
+To perform tests manually, run this command from the top directory
+of the repository:
+
+```bash
+python -m unittest discover tests
+```
+
+or simply
+
+```bash
+python -m unittest
+```
+
diff --git a/tests/assets/kuntsevskaya_centers.json b/tests/assets/kuntsevskaya_centers.json
deleted file mode 100644
index 36317ec..0000000
--- a/tests/assets/kuntsevskaya_centers.json
+++ /dev/null
@@ -1,28 +0,0 @@
-{
- "w38836456": {
- "lat": 55.73064775,
- "lon": 37.446065950000005
- },
- "w489951237": {
- "lat": 55.730760724999996,
- "lon": 37.44602055
- },
- "r7588527": {
- "lat": 55.73066371666667,
- "lon": 37.44604881666667
- },
- "r7588528": {
- "lat": 55.73075192499999,
- "lon": 37.44609837
- },
- "r7588561": {
- "lat": 55.73070782083333,
- "lon": 37.44607359333334
- },
- "r13426423": {
- "lat": 55.730760724999996,
- "lon": 37.44602055
- },
- "r100": null,
- "r101": null
-}
diff --git a/tests/assets/tiny_world.osm b/tests/assets/tiny_world.osm
new file mode 100644
index 0000000..6ee2096
--- /dev/null
+++ b/tests/assets/tiny_world.osm
@@ -0,0 +1,217 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/assets/tiny_world_gtfs.zip b/tests/assets/tiny_world_gtfs.zip
new file mode 100644
index 0000000..ef7a66a
Binary files /dev/null and b/tests/assets/tiny_world_gtfs.zip differ
diff --git a/tests/sample_data_for_build_tracks.py b/tests/sample_data_for_build_tracks.py
index ed1b589..4436db2 100644
--- a/tests/sample_data_for_build_tracks.py
+++ b/tests/sample_data_for_build_tracks.py
@@ -1,5 +1,6 @@
-sample_networks = {
- "Only 2 stations, no rails": {
+metro_samples = [
+ {
+ "name": "Only 2 stations, no rails",
"xml": """
@@ -37,7 +38,11 @@ sample_networks = {
""",
- "num_stations": 2,
+ "cities_info": [
+ {
+ "num_stations": 2,
+ },
+ ],
"tracks": [],
"extended_tracks": [
(0.0, 0.0),
@@ -55,7 +60,8 @@ sample_networks = {
"positions_on_rails": [],
},
},
- "Only 2 stations connected with rails": {
+ {
+ "name": "Only 2 stations connected with rails",
"xml": """
@@ -100,7 +106,11 @@ sample_networks = {
""",
- "num_stations": 2,
+ "cities_info": [
+ {
+ "num_stations": 2,
+ },
+ ],
"tracks": [
(0.0, 0.0),
(1.0, 0.0),
@@ -124,7 +134,8 @@ sample_networks = {
"positions_on_rails": [[0], [1]],
},
},
- "Only 6 stations, no rails": {
+ {
+ "name": "Only 6 stations, no rails",
"xml": """
@@ -190,7 +201,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [],
"extended_tracks": [
(0.0, 0.0),
@@ -212,7 +227,8 @@ sample_networks = {
"positions_on_rails": [],
},
},
- "One rail line connecting all stations": {
+ {
+ "name": "One rail line connecting all stations",
"xml": """
@@ -289,7 +305,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(0.0, 0.0),
(1.0, 0.0),
@@ -325,7 +345,8 @@ sample_networks = {
"positions_on_rails": [[0], [1], [2], [3], [4], [5]],
},
},
- "One rail line connecting all stations except the last": {
+ {
+ "name": "One rail line connecting all stations except the last",
"xml": """
@@ -401,7 +422,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(0.0, 0.0),
(1.0, 0.0),
@@ -435,7 +460,8 @@ sample_networks = {
"positions_on_rails": [[0], [1], [2], [3], [4]],
},
},
- "One rail line connecting all stations except the fist": {
+ {
+ "name": "One rail line connecting all stations except the first",
"xml": """
@@ -511,7 +537,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(1.0, 0.0),
(2.0, 0.0),
@@ -545,7 +575,11 @@ sample_networks = {
"positions_on_rails": [[0], [1], [2], [3], [4]],
},
},
- "One rail line connecting all stations except the fist and the last": {
+ {
+ "name": (
+ "One rail line connecting all stations "
+ "except the first and the last",
+ ),
"xml": """
@@ -620,7 +654,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(1.0, 0.0),
(2.0, 0.0),
@@ -652,7 +690,8 @@ sample_networks = {
"positions_on_rails": [[0], [1], [2], [3]],
},
},
- "One rail line connecting only 2 first stations": {
+ {
+ "name": "One rail line connecting only 2 first stations",
"xml": """
@@ -725,7 +764,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(0.0, 0.0),
(1.0, 0.0),
@@ -753,7 +796,8 @@ sample_networks = {
"positions_on_rails": [[0], [1]],
},
},
- "One rail line connecting only 2 last stations": {
+ {
+ "name": "One rail line connecting only 2 last stations",
"xml": """
@@ -826,7 +870,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(4.0, 0.0),
(5.0, 0.0),
@@ -854,7 +902,8 @@ sample_networks = {
"positions_on_rails": [[0], [1]],
},
},
- "One rail connecting all stations and protruding at both ends": {
+ {
+ "name": "One rail connecting all stations and protruding at both ends",
"xml": """
@@ -937,7 +986,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(-1.0, 0.0),
(0.0, 0.0),
@@ -977,10 +1030,11 @@ sample_networks = {
"positions_on_rails": [[1], [2], [3], [4], [5], [6]],
},
},
- (
- "Several rails with reversed order for backward route, "
- "connecting all stations and protruding at both ends"
- ): {
+ {
+ "name": (
+ "Several rails with reversed order for backward route, "
+ "connecting all stations and protruding at both ends"
+ ),
"xml": """
@@ -1069,7 +1123,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(-1.0, 0.0),
(0.0, 0.0),
@@ -1109,10 +1167,11 @@ sample_networks = {
"positions_on_rails": [[1], [2], [3], [4], [5], [6]],
},
},
- (
- "One rail laying near all stations requiring station projecting, "
- "protruding at both ends"
- ): {
+ {
+ "name": (
+ "One rail laying near all stations requiring station projecting, "
+ "protruding at both ends"
+ ),
"xml": """
@@ -1189,7 +1248,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(-1.0, 0.0),
(6.0, 0.0),
@@ -1227,7 +1290,8 @@ sample_networks = {
],
},
},
- "One rail laying near all stations except the first and last": {
+ {
+ "name": "One rail laying near all stations except the first and last",
"xml": """
@@ -1304,7 +1368,11 @@ sample_networks = {
""",
- "num_stations": 6,
+ "cities_info": [
+ {
+ "num_stations": 6,
+ },
+ ],
"tracks": [
(1.0, 0.0),
(4.0, 0.0),
@@ -1330,7 +1398,8 @@ sample_networks = {
"positions_on_rails": [[0], [1 / 3], [2 / 3], [1]],
},
},
- "Circle route without rails": {
+ {
+ "name": "Circle route without rails",
"xml": """
@@ -1377,7 +1446,11 @@ sample_networks = {
""",
- "num_stations": 4,
+ "cities_info": [
+ {
+ "num_stations": 4,
+ },
+ ],
"tracks": [],
"extended_tracks": [
(0.0, 0.0),
@@ -1398,7 +1471,8 @@ sample_networks = {
"positions_on_rails": [],
},
},
- "Circle route with closed rail line connecting all stations": {
+ {
+ "name": "Circle route with closed rail line connecting all stations",
"xml": """
@@ -1455,7 +1529,11 @@ sample_networks = {
""",
- "num_stations": 4,
+ "cities_info": [
+ {
+ "num_stations": 4,
+ },
+ ],
"tracks": [
(0.0, 0.0),
(0.0, 1.0),
@@ -1488,4 +1566,4 @@ sample_networks = {
"positions_on_rails": [[0, 4], [1], [2], [3], [0, 4]],
},
},
-}
+]
diff --git a/tests/assets/kuntsevskaya_transfer.osm b/tests/sample_data_for_center_calculation.py
similarity index 84%
rename from tests/assets/kuntsevskaya_transfer.osm
rename to tests/sample_data_for_center_calculation.py
index 48bf044..49ab368 100644
--- a/tests/assets/kuntsevskaya_transfer.osm
+++ b/tests/sample_data_for_center_calculation.py
@@ -1,4 +1,7 @@
-
+metro_samples = [
+ {
+ "name": "Transfer at Kuntsevskaya",
+ "xml": """
@@ -80,3 +83,16 @@
+""", # noqa: E501
+ "expected_centers": {
+ "w38836456": {"lat": 55.73064775, "lon": 37.446065950000005},
+ "w489951237": {"lat": 55.730760724999996, "lon": 37.44602055},
+ "r7588527": {"lat": 55.73066371666667, "lon": 37.44604881666667},
+ "r7588528": {"lat": 55.73075192499999, "lon": 37.44609837},
+ "r7588561": {"lat": 55.73070782083333, "lon": 37.44607359333334},
+ "r13426423": {"lat": 55.730760724999996, "lon": 37.44602055},
+ "r100": None,
+ "r101": None,
+ },
+ },
+]
diff --git a/tests/sample_data_for_error_messages.py b/tests/sample_data_for_error_messages.py
index 9d5c5fc..9bea1c7 100644
--- a/tests/sample_data_for_error_messages.py
+++ b/tests/sample_data_for_error_messages.py
@@ -1,5 +1,6 @@
-sample_networks = {
- "No errors": {
+metro_samples = [
+ {
+ "name": "No errors",
"xml": """
@@ -38,7 +39,11 @@ sample_networks = {
""",
- "num_stations": 2,
+ "cities_info": [
+ {
+ "num_stations": 2,
+ },
+ ],
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
@@ -46,7 +51,8 @@ sample_networks = {
"warnings": [],
"notices": [],
},
- "Bad station order": {
+ {
+ "name": "Bad station order",
"xml": """
@@ -99,7 +105,11 @@ sample_networks = {
""",
- "num_stations": 4,
+ "cities_info": [
+ {
+ "num_stations": 4,
+ },
+ ],
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
@@ -112,7 +122,8 @@ sample_networks = {
"warnings": [],
"notices": [],
},
- "Angle < 20 degrees": {
+ {
+ "name": "Angle < 20 degrees",
"xml": """
@@ -159,7 +170,11 @@ sample_networks = {
""",
- "num_stations": 3,
+ "cities_info": [
+ {
+ "num_stations": 3,
+ },
+ ],
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
@@ -172,7 +187,8 @@ sample_networks = {
"warnings": [],
"notices": [],
},
- "Angle between 20 and 45 degrees": {
+ {
+ "name": "Angle between 20 and 45 degrees",
"xml": """
@@ -219,7 +235,11 @@ sample_networks = {
""",
- "num_stations": 3,
+ "cities_info": [
+ {
+ "num_stations": 3,
+ },
+ ],
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
@@ -232,7 +252,8 @@ sample_networks = {
'is too narrow, 27 degrees (relation 2, "Backward")',
],
},
- "Stops unordered along tracks provided each angle > 45 degrees": {
+ {
+ "name": "Unordered stops provided each angle > 45 degrees",
"xml": """
@@ -300,7 +321,11 @@ sample_networks = {
""",
- "num_stations": 4,
+ "cities_info": [
+ {
+ "num_stations": 4,
+ },
+ ],
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
@@ -313,4 +338,4 @@ sample_networks = {
"warnings": [],
"notices": [],
},
-}
+]
diff --git a/tests/sample_data_for_outputs.py b/tests/sample_data_for_outputs.py
new file mode 100644
index 0000000..3c2a590
--- /dev/null
+++ b/tests/sample_data_for_outputs.py
@@ -0,0 +1,345 @@
+metro_samples = [
+ {
+ "name": "tiny_world",
+ "xml_file": """assets/tiny_world.osm""",
+ "cities_info": [
+ {
+ "id": 1,
+ "name": "Intersecting 2 metro lines",
+ "country": "World",
+ "continent": "Africa",
+ "num_stations": 6,
+ "num_lines": 2,
+ "num_light_lines": 0,
+ "num_interchanges": 1,
+ "bbox": "-179, -89, 179, 89",
+ "networks": "network-1",
+ },
+ {
+ "id": 2,
+ "name": "One light rail line",
+ "country": "World",
+ "continent": "Africa",
+ "num_stations": 2,
+ "num_lines": 0,
+ "num_light_lines": 1,
+ "num_interchanges": 0,
+ "bbox": "-179, -89, 179, 89",
+ "networks": "network-2",
+ },
+ ],
+ "gtfs_file": "assets/tiny_world_gtfs.zip",
+ "json_dump": """
+{
+ "stopareas": {
+ "n1": {
+ "id": "n1",
+ "center": [
+ 0,
+ 0
+ ],
+ "name": "Station 1",
+ "entrances": []
+ },
+ "r1": {
+ "id": "r1",
+ "center": [
+ 0.00470373068,
+ 0.0047037307
+ ],
+ "name": "Station 2",
+ "entrances": []
+ },
+ "r3": {
+ "id": "r3",
+ "center": [
+ 0.01012040581,
+ 0.0097589171
+ ],
+ "name": "Station 3",
+ "entrances": []
+ },
+ "n4": {
+ "id": "n4",
+ "center": [
+ 0,
+ 0.01
+ ],
+ "name": "Station 4",
+ "entrances": []
+ },
+ "r2": {
+ "id": "r2",
+ "center": [
+ 0.0047718624,
+ 0.00514739839
+ ],
+ "name": "Station 5",
+ "entrances": []
+ },
+ "n6": {
+ "id": "n6",
+ "center": [
+ 0.01,
+ 0
+ ],
+ "name": "Station 6",
+ "entrances": []
+ },
+ "r4": {
+ "id": "r4",
+ "center": [
+ 0.009716854315,
+ 0.010286367745
+ ],
+ "name": "Station 7",
+ "entrances": []
+ },
+ "r16": {
+ "id": "r16",
+ "center": [
+ 0.012405493905,
+ 0.014377764559999999
+ ],
+ "name": "Station 8",
+ "entrances": []
+ }
+ },
+ "networks": {
+ "Intersecting 2 metro lines": {
+ "id": 1,
+ "name": "Intersecting 2 metro lines",
+ "routes": [
+ {
+ "id": "r15",
+ "mode": "subway",
+ "ref": "1",
+ "name": "Blue Line",
+ "colour": "#0000ff",
+ "infill": null,
+ "itineraries": [
+ {
+ "id": "r7",
+ "tracks": [
+ [
+ 0,
+ 0
+ ],
+ [
+ 0.00470373068,
+ 0.0047037307
+ ],
+ [
+ 0.009939661455227341,
+ 0.009939661455455193
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "n1",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r1",
+ "distance": 741
+ },
+ {
+ "stoparea_id": "r3",
+ "distance": 1565
+ }
+ ]
+ },
+ {
+ "id": "r8",
+ "tracks": [
+ [
+ 0.009939661455227341,
+ 0.009939661455455193
+ ],
+ [
+ 0.00470373068,
+ 0.0047037307
+ ],
+ [
+ 0,
+ 0
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "r3",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r1",
+ "distance": 824
+ },
+ {
+ "stoparea_id": "n1",
+ "distance": 1565
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "id": "r14",
+ "mode": "subway",
+ "ref": "2",
+ "name": "Red Line",
+ "colour": "#ff0000",
+ "infill": null,
+ "itineraries": [
+ {
+ "id": "r12",
+ "tracks": [
+ [
+ 0,
+ 0.01
+ ],
+ [
+ 0.01,
+ 0
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "n4",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r2",
+ "distance": 758
+ },
+ {
+ "stoparea_id": "n6",
+ "distance": 1575
+ }
+ ]
+ },
+ {
+ "id": "r13",
+ "tracks": [
+ [
+ 0.01,
+ 0
+ ],
+ [
+ 0,
+ 0.01
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "n6",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r2",
+ "distance": 817
+ },
+ {
+ "stoparea_id": "n4",
+ "distance": 1575
+ }
+ ]
+ }
+ ]
+ }
+ ]
+ },
+ "One light rail line": {
+ "id": 2,
+ "name": "One light rail line",
+ "routes": [
+ {
+ "id": "r11",
+ "mode": "light_rail",
+ "ref": "LR",
+ "name": "LR Line",
+ "colour": "#a52a2a",
+ "infill": "#ffffff",
+ "itineraries": [
+ {
+ "id": "r9",
+ "tracks": [
+ [
+ 0.00976752835,
+ 0.01025306758
+ ],
+ [
+ 0.01245616794,
+ 0.01434446439
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "r4",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r16",
+ "distance": 545
+ }
+ ]
+ },
+ {
+ "id": "r10",
+ "tracks": [
+ [
+ 0.012321033122529725,
+ 0.014359650255679167
+ ],
+ [
+ 0.00966618028,
+ 0.01031966791
+ ]
+ ],
+ "start_time": null,
+ "end_time": null,
+ "interval": null,
+ "stops": [
+ {
+ "stoparea_id": "r16",
+ "distance": 0
+ },
+ {
+ "stoparea_id": "r4",
+ "distance": 538
+ }
+ ]
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "transfers": [
+ [
+ "r1",
+ "r2"
+ ],
+ [
+ "r3",
+ "r4"
+ ]
+ ]
+}
+""",
+ },
+]
diff --git a/tests/test_build_tracks.py b/tests/test_build_tracks.py
index 14ea86b..a1b6a6c 100644
--- a/tests/test_build_tracks.py
+++ b/tests/test_build_tracks.py
@@ -1,24 +1,13 @@
-"""
-To perform tests manually, run this command from the top directory
-of the repository:
-
-> python -m unittest discover tests
-
-or simply
-
-> python -m unittest
-"""
-
-
-from tests.sample_data_for_build_tracks import sample_networks
+from tests.sample_data_for_build_tracks import metro_samples
from tests.util import TestCase
class TestOneRouteTracks(TestCase):
"""Test tracks extending and truncating on one-route networks"""
- def prepare_city_routes(self, network) -> tuple:
- city = self.validate_city(network)
+ def prepare_city_routes(self, metro_sample: dict) -> tuple:
+ cities, transfers = self.prepare_cities(metro_sample)
+ city = cities[0]
self.assertTrue(city.is_good)
@@ -30,56 +19,56 @@ class TestOneRouteTracks(TestCase):
return fwd_route, bwd_route
- def _test_tracks_extending_for_network(self, network_data):
- fwd_route, bwd_route = self.prepare_city_routes(network_data)
+ def _test_tracks_extending_for_network(self, metro_sample: dict) -> None:
+ fwd_route, bwd_route = self.prepare_city_routes(metro_sample)
self.assertEqual(
fwd_route.tracks,
- network_data["tracks"],
+ metro_sample["tracks"],
"Wrong tracks",
)
extended_tracks = fwd_route.get_extended_tracks()
self.assertEqual(
extended_tracks,
- network_data["extended_tracks"],
+ metro_sample["extended_tracks"],
"Wrong tracks after extending",
)
self.assertEqual(
bwd_route.tracks,
- network_data["tracks"][::-1],
+ metro_sample["tracks"][::-1],
"Wrong backward tracks",
)
extended_tracks = bwd_route.get_extended_tracks()
self.assertEqual(
extended_tracks,
- network_data["extended_tracks"][::-1],
+ metro_sample["extended_tracks"][::-1],
"Wrong backward tracks after extending",
)
- def _test_tracks_truncating_for_network(self, network_data):
- fwd_route, bwd_route = self.prepare_city_routes(network_data)
+ def _test_tracks_truncating_for_network(self, metro_sample: dict) -> None:
+ fwd_route, bwd_route = self.prepare_city_routes(metro_sample)
truncated_tracks = fwd_route.get_truncated_tracks(fwd_route.tracks)
self.assertEqual(
truncated_tracks,
- network_data["truncated_tracks"],
+ metro_sample["truncated_tracks"],
"Wrong tracks after truncating",
)
truncated_tracks = bwd_route.get_truncated_tracks(bwd_route.tracks)
self.assertEqual(
truncated_tracks,
- network_data["truncated_tracks"][::-1],
+ metro_sample["truncated_tracks"][::-1],
"Wrong backward tracks after truncating",
)
- def _test_stop_positions_on_rails_for_network(self, network_data):
- fwd_route, bwd_route = self.prepare_city_routes(network_data)
+ def _test_stop_positions_on_rails_for_network(self, sample: dict) -> None:
+ fwd_route, bwd_route = self.prepare_city_routes(sample)
for route, route_label in zip(
(fwd_route, bwd_route), ("forward", "backward")
):
- route_data = network_data[route_label]
+ route_data = sample[route_label]
for attr in (
"first_stop_on_rails_index",
@@ -97,21 +86,27 @@ class TestOneRouteTracks(TestCase):
rs.positions_on_rails
for rs in route.stops[first_ind : last_ind + 1] # noqa E203
]
- self.assertListAlmostEqual(
+ self.assertSequenceAlmostEqual(
positions_on_rails, route_data["positions_on_rails"]
)
def test_tracks_extending(self) -> None:
- for network_name, network_data in sample_networks.items():
- with self.subTest(msg=network_name):
- self._test_tracks_extending_for_network(network_data)
+ for sample in metro_samples:
+ sample_name = sample["name"]
+ sample["cities_info"][0]["name"] = sample_name
+ with self.subTest(msg=sample_name):
+ self._test_tracks_extending_for_network(sample)
def test_tracks_truncating(self) -> None:
- for network_name, network_data in sample_networks.items():
- with self.subTest(msg=network_name):
- self._test_tracks_truncating_for_network(network_data)
+ for sample in metro_samples:
+ sample_name = sample["name"]
+ sample["cities_info"][0]["name"] = sample_name
+ with self.subTest(msg=sample_name):
+ self._test_tracks_truncating_for_network(sample)
def test_stop_position_on_rails(self) -> None:
- for network_name, network_data in sample_networks.items():
- with self.subTest(msg=network_name):
- self._test_stop_positions_on_rails_for_network(network_data)
+ for sample in metro_samples:
+ sample_name = sample["name"]
+ sample["cities_info"][0]["name"] = sample_name
+ with self.subTest(msg=sample_name):
+ self._test_stop_positions_on_rails_for_network(sample)
diff --git a/tests/test_center_calculation.py b/tests/test_center_calculation.py
index 4f01a3c..0e42360 100644
--- a/tests/test_center_calculation.py
+++ b/tests/test_center_calculation.py
@@ -1,28 +1,28 @@
-import json
-from pathlib import Path
+import io
from unittest import TestCase
from process_subways import calculate_centers
from subway_io import load_xml
+from tests.sample_data_for_center_calculation import metro_samples
class TestCenterCalculation(TestCase):
"""Test center calculation. Test data [should] contain among others
the following edge cases:
- - an empty relation. It's element should not obtain "center" key.
- - relation as member of relation, the child relation following the parent
- in the OSM XML file.
+ - an empty relation. Its element should not obtain "center" key.
+ - relation as member of another relation, the child relation following
+ the parent in the OSM XML.
- relation with incomplete members (broken references).
- relations with cyclic references.
"""
- ASSETS_PATH = Path(__file__).resolve().parent / "assets"
- OSM_DATA = str(ASSETS_PATH / "kuntsevskaya_transfer.osm")
- CORRECT_CENTERS = str(ASSETS_PATH / "kuntsevskaya_centers.json")
-
- def test__calculate_centers(self) -> None:
- elements = load_xml(self.OSM_DATA)
+ def test_calculate_centers(self) -> None:
+ for sample in metro_samples:
+ with self.subTest(msg=sample["name"]):
+ self._test_calculate_centers_for_sample(sample)
+ def _test_calculate_centers_for_sample(self, metro_sample: dict) -> None:
+ elements = load_xml(io.BytesIO(metro_sample["xml"].encode()))
calculate_centers(elements)
elements_dict = {
@@ -36,12 +36,11 @@ class TestCenterCalculation(TestCase):
if "center" in el
}
- with open(self.CORRECT_CENTERS) as f:
- correct_centers = json.load(f)
+ expected_centers = metro_sample["expected_centers"]
- self.assertTrue(set(calculated_centers).issubset(correct_centers))
+ self.assertTrue(set(calculated_centers).issubset(expected_centers))
- for k, correct_center in correct_centers.items():
+ for k, correct_center in expected_centers.items():
if correct_center is None:
self.assertNotIn("center", elements_dict[k])
else:
diff --git a/tests/test_error_messages.py b/tests/test_error_messages.py
index 12a5583..aee6f48 100644
--- a/tests/test_error_messages.py
+++ b/tests/test_error_messages.py
@@ -1,4 +1,4 @@
-from tests.sample_data_for_error_messages import sample_networks
+from tests.sample_data_for_error_messages import metro_samples
from tests.util import TestCase
@@ -7,16 +7,19 @@ class TestValidationMessages(TestCase):
on different types of errors in input OSM data.
"""
- def _test_validation_messages_for_network(self, network_data):
- city = self.validate_city(network_data)
+ def _test_validation_messages_for_network(
+ self, metro_sample: dict
+ ) -> None:
+ cities, transfers = self.prepare_cities(metro_sample)
+ city = cities[0]
for err_level in ("errors", "warnings", "notices"):
self.assertListEqual(
sorted(getattr(city, err_level)),
- sorted(network_data[err_level]),
+ sorted(metro_sample[err_level]),
)
def test_validation_messages(self) -> None:
- for network_name, network_data in sample_networks.items():
- with self.subTest(msg=network_name):
- self._test_validation_messages_for_network(network_data)
+ for sample in metro_samples:
+ with self.subTest(msg=sample["name"]):
+ self._test_validation_messages_for_network(sample)
diff --git a/tests/test_gtfs_processor.py b/tests/test_gtfs_processor.py
index 5a234e8..86d1cac 100644
--- a/tests/test_gtfs_processor.py
+++ b/tests/test_gtfs_processor.py
@@ -1,9 +1,13 @@
-from unittest import TestCase
+import codecs
+import csv
+from functools import partial
+from pathlib import Path
+from zipfile import ZipFile
-from processors.gtfs import (
- dict_to_row,
- GTFS_COLUMNS,
-)
+from processors._common import transit_to_dict
+from processors.gtfs import dict_to_row, GTFS_COLUMNS, transit_data_to_gtfs
+from tests.util import TestCase
+from tests.sample_data_for_outputs import metro_samples
class TestGTFS(TestCase):
@@ -94,3 +98,62 @@ class TestGTFS(TestCase):
self.assertListEqual(
dict_to_row(shape["shape_data"], "shapes"), shape["answer"]
)
+
+ def test__transit_data_to_gtfs(self) -> None:
+ for metro_sample in metro_samples:
+ cities, transfers = self.prepare_cities(metro_sample)
+ calculated_transit_data = transit_to_dict(cities, transfers)
+ calculated_gtfs_data = transit_data_to_gtfs(
+ calculated_transit_data
+ )
+
+ control_gtfs_data = self._readGtfs(
+ Path(__file__).resolve().parent / metro_sample["gtfs_file"]
+ )
+ self._compareGtfs(calculated_gtfs_data, control_gtfs_data)
+
+ @staticmethod
+ def _readGtfs(filepath: str) -> dict:
+ gtfs_data = dict()
+ with ZipFile(filepath) as zf:
+ for gtfs_feature in GTFS_COLUMNS:
+ with zf.open(f"{gtfs_feature}.txt") as f:
+ reader = csv.reader(codecs.iterdecode(f, "utf-8"))
+ next(reader) # read header
+ rows = list(reader)
+ gtfs_data[gtfs_feature] = rows
+ return gtfs_data
+
+ def _compareGtfs(
+ self, calculated_gtfs_data: dict, control_gtfs_data: dict
+ ) -> None:
+ for gtfs_feature in GTFS_COLUMNS:
+ calculated_rows = sorted(
+ map(
+ partial(dict_to_row, record_type=gtfs_feature),
+ calculated_gtfs_data[gtfs_feature],
+ )
+ )
+ control_rows = sorted(control_gtfs_data[gtfs_feature])
+
+ self.assertEqual(len(calculated_rows), len(control_rows))
+
+ for i, (calculated_row, control_row) in enumerate(
+ zip(calculated_rows, control_rows)
+ ):
+ self.assertEqual(
+ len(calculated_row),
+ len(control_row),
+ f"Different length of {i}-th row of {gtfs_feature}",
+ )
+ for calculated_value, control_value in zip(
+ calculated_row, control_row
+ ):
+ if calculated_value is None:
+ self.assertEqual(control_value, "", f"in {i}-th row")
+ else: # convert str to float/int/str
+ self.assertAlmostEqual(
+ calculated_value,
+ type(calculated_value)(control_value),
+ places=10,
+ )
diff --git a/tests/test_storage.py b/tests/test_storage.py
new file mode 100644
index 0000000..978529f
--- /dev/null
+++ b/tests/test_storage.py
@@ -0,0 +1,26 @@
+import json
+
+from processors._common import transit_to_dict
+from tests.sample_data_for_outputs import metro_samples
+from tests.util import TestCase, TestTransitDataMixin
+
+
+class TestStorage(TestCase, TestTransitDataMixin):
+ def test_storage(self) -> None:
+ for sample in metro_samples:
+ with self.subTest(msg=sample["name"]):
+ self._test_storage_for_sample(sample)
+
+ def _test_storage_for_sample(self, metro_sample: dict) -> None:
+ cities, transfers = self.prepare_cities(metro_sample)
+
+ calculated_transit_data = transit_to_dict(cities, transfers)
+
+ control_transit_data = json.loads(metro_sample["json_dump"])
+ control_transit_data["transfers"] = set(
+ map(tuple, control_transit_data["transfers"])
+ )
+
+ self.compare_transit_data(
+ calculated_transit_data, control_transit_data
+ )
diff --git a/tests/util.py b/tests/util.py
index efab8c2..56b1962 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,15 +1,23 @@
import io
+from collections.abc import Sequence, Mapping
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
from unittest import TestCase as unittestTestCase
+from process_subways import (
+ add_osm_elements_to_cities,
+ validate_cities,
+ calculate_centers,
+)
from subway_io import load_xml
-from subway_structure import City
+from subway_structure import City, find_transfers
class TestCase(unittestTestCase):
"""TestCase class for testing the Subway Validator"""
CITY_TEMPLATE = {
- "id": 1,
"name": "Null Island",
"country": "World",
"continent": "Africa",
@@ -21,29 +29,184 @@ class TestCase(unittestTestCase):
"num_interchanges": 0,
}
- def validate_city(self, network) -> City:
- city_data = self.CITY_TEMPLATE.copy()
- for attr in self.CITY_TEMPLATE.keys():
- if attr in network:
- city_data[attr] = network[attr]
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls.city_class = City
- city = City(city_data)
- elements = load_xml(io.BytesIO(network["xml"].encode("utf-8")))
- for el in elements:
- city.add(el)
- city.extract_routes()
- city.validate()
- return city
+ def prepare_cities(self, metro_sample: dict) -> tuple:
+ """Load cities from file/string, validate them and return cities
+ and transfers.
+ """
- def assertListAlmostEqual(self, list1, list2, places=10) -> None:
- if not (isinstance(list1, list) and isinstance(list2, list)):
- raise RuntimeError(
- f"Not lists passed to the '{self.__class__.__name__}."
- "assertListAlmostEqual' method"
- )
- self.assertEqual(len(list1), len(list2))
- for a, b in zip(list1, list2):
- if isinstance(a, list) and isinstance(b, list):
- self.assertListAlmostEqual(a, b, places)
+ def assign_unique_id(city_info: dict, cities_info: list[dict]) -> None:
+ """city_info - newly added city, cities_info - already added
+ cities. Check city id uniqueness / assign unique id to the city.
+ """
+ occupied_ids = set(c["id"] for c in cities_info)
+ if "id" in city_info:
+ if city_info["id"] in occupied_ids:
+ raise RuntimeError("Not unique city ids in test data")
else:
- self.assertAlmostEqual(a, b, places)
+ city_info["id"] = max(occupied_ids, default=1) + 1
+
+ cities_given_info = metro_sample["cities_info"]
+ cities_info = list()
+ for city_given_info in cities_given_info:
+ city_info = self.CITY_TEMPLATE.copy()
+ for attr in city_given_info.keys():
+ city_info[attr] = city_given_info[attr]
+ assign_unique_id(city_info, cities_info)
+ cities_info.append(city_info)
+
+ if len(set(ci["name"] for ci in cities_info)) < len(cities_info):
+ raise RuntimeError("Not unique city names in test data")
+
+ cities = list(map(self.city_class, cities_info))
+ if "xml" in metro_sample:
+ xml_file = io.BytesIO(metro_sample["xml"].encode())
+ else:
+ xml_file = (
+ Path(__file__).resolve().parent / metro_sample["xml_file"]
+ )
+ elements = load_xml(xml_file)
+ calculate_centers(elements)
+ add_osm_elements_to_cities(elements, cities)
+ validate_cities(cities)
+ transfers = find_transfers(elements, cities)
+ return cities, transfers
+
+ def _assertAnyAlmostEqual(
+ self,
+ first: Any,
+ second: Any,
+ places: int = 10,
+ ignore_keys: set = None,
+ ) -> None:
+ """Dispatcher method to other "...AlmostEqual" methods
+ depending on argument types.
+ """
+ if isinstance(first, Mapping):
+ self.assertMappingAlmostEqual(first, second, places, ignore_keys)
+ elif isinstance(first, Sequence) and not isinstance(
+ first, (str, bytes)
+ ):
+ self.assertSequenceAlmostEqual(first, second, places, ignore_keys)
+ else:
+ self.assertAlmostEqual(first, second, places)
+
+ def assertSequenceAlmostEqual(
+ self,
+ seq1: Sequence,
+ seq2: Sequence,
+ places: int = 10,
+ ignore_keys: set = None,
+ ) -> None:
+ """Compare two sequences, items of numeric types being compared
+ approximately, containers being approx-compared recursively.
+
+ :param: seq1 a sequence of values of any types, including collections
+ :param: seq2 a sequence of values of any types, including collections
+ :param: places number of fractional digits (passed to
+ assertAlmostEqual() method of parent class)
+ :param: ignore_keys a set of strs with keys in dictionaries
+ that should be ignored during recursive comparison
+ :return: None
+ """
+ if not (isinstance(seq1, Sequence) and isinstance(seq2, Sequence)):
+ raise RuntimeError(
+ f"Not a sequence passed to the '{self.__class__.__name__}."
+ "assertSequenceAlmostEqual' method"
+ )
+ self.assertEqual(len(seq1), len(seq2))
+ for a, b in zip(seq1, seq2):
+ self._assertAnyAlmostEqual(a, b, places, ignore_keys)
+
+ def assertMappingAlmostEqual(
+ self,
+ d1: Mapping,
+ d2: Mapping,
+ places: int = 10,
+ ignore_keys: set = None,
+ ) -> None:
+ """Compare dictionaries recursively, numeric values being compared
+ approximately.
+
+ :param: d1 a mapping of arbitrary key/value types,
+ including collections
+ :param: d1 a mapping of arbitrary key/value types,
+ including collections
+ :param: places number of fractional digits (passed to
+ assertAlmostEqual() method of parent class)
+ :param: ignore_keys a set of strs with keys in dictionaries
+ that should be ignored during recursive comparison
+ :return: None
+ """
+ if not (isinstance(d1, Mapping) and isinstance(d2, Mapping)):
+ raise RuntimeError(
+ f"Not a dictionary passed to the '{self.__class__.__name__}."
+ "assertMappingAlmostEqual' method"
+ )
+
+ d1_keys = set(d1.keys())
+ d2_keys = set(d2.keys())
+ if ignore_keys:
+ d1_keys -= ignore_keys
+ d2_keys -= ignore_keys
+ self.assertSetEqual(d1_keys, d2_keys)
+ for k in d1_keys:
+ v1 = d1[k]
+ v2 = d2[k]
+ self._assertAnyAlmostEqual(v1, v2, places, ignore_keys)
+
+
+class TestTransitDataMixin:
+ def compare_transit_data(self, td1: dict, td2: dict) -> None:
+ """Compare transit data td1 and td2 remembering that:
+ - arrays that represent sets ("routes", "itineraries", "entrances")
+ should be compared without order;
+ - all floating-point values (coordinates) should be compared
+ approximately.
+ """
+ self.assertMappingAlmostEqual(
+ td1,
+ td2,
+ ignore_keys={"stopareas", "routes", "itineraries"},
+ )
+
+ networks1 = td1["networks"]
+ networks2 = td2["networks"]
+
+ id_cmp = itemgetter("id")
+
+ for network_name, network_data1 in networks1.items():
+ network_data2 = networks2[network_name]
+ routes1 = sorted(network_data1["routes"], key=id_cmp)
+ routes2 = sorted(network_data2["routes"], key=id_cmp)
+ self.assertEqual(len(routes1), len(routes2))
+ for r1, r2 in zip(routes1, routes2):
+ self.assertMappingAlmostEqual(
+ r1, r2, ignore_keys={"itineraries"}
+ )
+ its1 = sorted(r1["itineraries"], key=id_cmp)
+ its2 = sorted(r2["itineraries"], key=id_cmp)
+ self.assertEqual(len(its1), len(its2))
+ for it1, it2 in zip(its1, its2):
+ self.assertMappingAlmostEqual(it1, it2)
+
+ transfers1 = td1["transfers"]
+ transfers2 = td2["transfers"]
+ self.assertSetEqual(transfers1, transfers2)
+
+ stopareas1 = td1["stopareas"]
+ stopareas2 = td2["stopareas"]
+ self.assertMappingAlmostEqual(
+ stopareas1, stopareas2, ignore_keys={"entrances"}
+ )
+
+ for sa_id, sa1_data in stopareas1.items():
+ sa2_data = stopareas2[sa_id]
+ entrances1 = sorted(sa1_data["entrances"], key=id_cmp)
+ entrances2 = sorted(sa2_data["entrances"], key=id_cmp)
+ self.assertEqual(len(entrances1), len(entrances2))
+ for e1, e2 in zip(entrances1, entrances2):
+ self.assertMappingAlmostEqual(e1, e2)