Add type aliases, declarations and annotations

This commit is contained in:
Alexey Zakharenkov 2024-02-27 14:59:51 +03:00
parent 28f4c0d139
commit c2f2956da1
8 changed files with 249 additions and 215 deletions

View file

@ -152,7 +152,7 @@ CSS_COLOURS = {
}
def normalize_colour(c):
def normalize_colour(c: str | None) -> str | None:
if not c:
return None
c = c.strip().lower()

View file

@ -25,8 +25,10 @@ from subway_structure import (
CriticalValidationError,
find_transfers,
get_unused_subway_entrances_geojson,
LonLat,
MODES_OVERGROUND,
MODES_RAPID,
OsmElementT,
)
DEFAULT_SPREADSHEET_ID = "1SEW1-NiNOnA2qDwievcxYV1FOaQl1mb1fdeyqAxHu3k"
@ -36,8 +38,6 @@ DEFAULT_CITIES_INFO_URL = (
)
BAD_MARK = "[bad]"
Point = tuple[float, float]
def compose_overpass_request(
overground: bool, bboxes: list[list[float]]
@ -68,7 +68,7 @@ def compose_overpass_request(
def overpass_request(
overground: bool, overpass_api: str, bboxes: list[list[float]]
) -> list[dict]:
) -> list[OsmElementT]:
query = compose_overpass_request(overground, bboxes)
url = f"{overpass_api}?data={urllib.parse.quote(query)}"
response = urllib.request.urlopen(url, timeout=1000)
@ -79,7 +79,7 @@ def overpass_request(
def multi_overpass(
overground: bool, overpass_api: str, bboxes: list[list[float]]
) -> list[dict]:
) -> list[OsmElementT]:
SLICE_SIZE = 10
INTERREQUEST_WAIT = 5 # in seconds
result = []
@ -96,8 +96,8 @@ def slugify(name: str) -> str:
def get_way_center(
element: dict, node_centers: dict[int, Point]
) -> Point | None:
element: OsmElementT, node_centers: dict[int, LonLat]
) -> LonLat | None:
"""
:param element: dict describing OSM element
:param node_centers: osm_id => (lat, lon)
@ -107,7 +107,7 @@ def get_way_center(
# If elements have been queried via overpass-api with
# 'out center;' clause then ways already have 'center' attribute
if "center" in element:
return element["center"]["lat"], element["center"]["lon"]
return element["center"]["lon"], element["center"]["lat"]
if "nodes" not in element:
return None
@ -131,22 +131,22 @@ def get_way_center(
count += 1
if count == 0:
return None
element["center"] = {"lat": center[0] / count, "lon": center[1] / count}
return element["center"]["lat"], element["center"]["lon"]
element["center"] = {"lat": center[1] / count, "lon": center[0] / count}
return element["center"]["lon"], element["center"]["lat"]
def get_relation_center(
element: dict,
node_centers: dict[int, Point],
way_centers: dict[int, Point],
relation_centers: dict[int, Point],
element: OsmElementT,
node_centers: dict[int, LonLat],
way_centers: dict[int, LonLat],
relation_centers: dict[int, LonLat],
ignore_unlocalized_child_relations: bool = False,
) -> Point | None:
) -> LonLat | None:
"""
:param element: dict describing OSM element
:param node_centers: osm_id => (lat, lon)
:param way_centers: osm_id => (lat, lon)
:param relation_centers: osm_id => (lat, lon)
:param node_centers: osm_id => LonLat
:param way_centers: osm_id => LonLat
:param relation_centers: osm_id => LonLat
:param ignore_unlocalized_child_relations: if a member that is a relation
has no center, skip it and calculate center based on member nodes,
ways and other, "localized" (with known centers), relations
@ -159,7 +159,7 @@ def get_relation_center(
# of other relations (e.g., route_master, stop_area_group or
# stop_area with only members that are multipolygons)
if "center" in element:
return element["center"]["lat"], element["center"]["lon"]
return element["center"]["lon"], element["center"]["lat"]
center = [0, 0]
count = 0
@ -186,25 +186,25 @@ def get_relation_center(
count += 1
if count == 0:
return None
element["center"] = {"lat": center[0] / count, "lon": center[1] / count}
return element["center"]["lat"], element["center"]["lon"]
element["center"] = {"lat": center[1] / count, "lon": center[0] / count}
return element["center"]["lon"], element["center"]["lat"]
def calculate_centers(elements: list[dict]) -> None:
def calculate_centers(elements: list[OsmElementT]) -> None:
"""Adds 'center' key to each way/relation in elements,
except for empty ways or relations.
Relies on nodes-ways-relations order in the elements list.
"""
nodes: dict[int, Point] = {} # id => (lat, lon)
ways: dict[int, Point] = {} # id => (lat, lon)
relations: dict[int, Point] = {} # id => (lat, lon)
nodes: dict[int, LonLat] = {} # id => LonLat
ways: dict[int, LonLat] = {} # id => approx center LonLat
relations: dict[int, LonLat] = {} # id => approx center LonLat
unlocalized_relations = [] # 'unlocalized' means the center of the
# relation has not been calculated yet
unlocalized_relations: list[OsmElementT] = [] # 'unlocalized' means
# the center of the relation has not been calculated yet
for el in elements:
if el["type"] == "node":
nodes[el["id"]] = (el["lat"], el["lon"])
nodes[el["id"]] = (el["lon"], el["lat"])
elif el["type"] == "way":
if center := get_way_center(el, nodes):
ways[el["id"]] = center
@ -216,7 +216,7 @@ def calculate_centers(elements: list[dict]) -> None:
def iterate_relation_centers_calculation(
ignore_unlocalized_child_relations: bool,
) -> list[dict]:
) -> list[OsmElementT]:
unlocalized_relations_upd = []
for rel in unlocalized_relations:
if center := get_relation_center(
@ -244,7 +244,7 @@ def calculate_centers(elements: list[dict]) -> None:
def add_osm_elements_to_cities(
osm_elements: list[dict], cities: list[City]
osm_elements: list[OsmElementT], cities: list[City]
) -> None:
for el in osm_elements:
for c in cities:

View file

@ -1,6 +1,4 @@
from typing import List, Set
from subway_structure import City, el_center, StopArea
from subway_structure import City, el_center, TransfersT
DEFAULT_INTERVAL = 2.5 * 60 # seconds
KMPH_TO_MPS = 1 / 3.6 # km/h to m/s conversion multiplier
@ -8,14 +6,12 @@ SPEED_ON_TRANSFER = 3.5 * KMPH_TO_MPS # m/s
TRANSFER_PENALTY = 30 # seconds
def format_colour(colour):
def format_colour(colour: str | None) -> str | None:
"""Truncate leading # sign."""
return colour[1:] if colour else None
def transit_to_dict(
cities: List[City], transfers: List[Set[StopArea]]
) -> dict:
def transit_to_dict(cities: list[City], transfers: TransfersT) -> dict:
"""Get data for good cities as a dictionary."""
data = {
"stopareas": {}, # stoparea id => stoparea data

View file

@ -3,7 +3,6 @@ from functools import partial
from io import BytesIO, StringIO
from itertools import permutations
from tarfile import TarFile, TarInfo
from typing import List, Optional, Set
from zipfile import ZipFile
from ._common import (
@ -16,7 +15,7 @@ from ._common import (
from subway_structure import (
City,
distance,
StopArea,
TransfersT,
)
@ -133,13 +132,13 @@ GTFS_COLUMNS = {
}
def round_coords(coords_tuple):
def round_coords(coords_tuple: tuple) -> tuple:
return tuple(
map(lambda coord: round(coord, COORDINATE_PRECISION), coords_tuple)
)
def transit_data_to_gtfs(data):
def transit_data_to_gtfs(data: dict) -> dict:
# Keys correspond GTFS file names
gtfs_data = {key: [] for key in GTFS_COLUMNS.keys()}
@ -313,14 +312,14 @@ def transit_data_to_gtfs(data):
def process(
cities: List[City],
transfers: List[Set[StopArea]],
cities: list[City],
transfers: TransfersT,
filename: str,
cache_path: str,
):
cache_path: str | None,
) -> None:
"""Generate all output and save to file.
:param cities: List of City instances
:param transfers: List of sets of StopArea.id
:param cities: list of City instances
:param transfers: all collected transfers in the world
:param filename: Path to file to save the result
:param cache_path: Path to json-file with good cities cache or None.
"""
@ -344,9 +343,7 @@ def dict_to_row(dict_data: dict, record_type: str) -> list:
]
def make_gtfs(
filename: str, gtfs_data: dict, fmt: Optional[str] = None
) -> None:
def make_gtfs(filename: str, gtfs_data: dict, fmt: str | None = None) -> None:
if not fmt:
fmt = "tar" if filename.endswith(".tar") else "zip"

View file

@ -2,13 +2,19 @@ import json
import logging
import os
from collections import defaultdict
from collections.abc import Callable
from typing import Any, TypeAlias
from subway_structure import (
City,
DISPLACEMENT_TOLERANCE,
distance,
el_center,
IdT,
LonLat,
OsmElementT,
Station,
StopArea,
TransfersT,
)
from ._common import (
@ -19,14 +25,16 @@ from ._common import (
TRANSFER_PENALTY,
)
OSM_TYPES = {"n": (0, "node"), "w": (2, "way"), "r": (3, "relation")}
ENTRANCE_PENALTY = 60 # seconds
SPEED_TO_ENTRANCE = 5 * KMPH_TO_MPS # m/s
SPEED_ON_LINE = 40 * KMPH_TO_MPS # m/s
# (stoparea1_uid, stoparea2_uid) -> seconds; stoparea1_uid < stoparea2_uid
TransferTimesT: TypeAlias = dict[tuple[int, int], int]
def uid(elid, typ=None):
def uid(elid: IdT, typ: str | None = None) -> int:
t = elid[0]
osm_id = int(elid[1:])
if not typ:
@ -39,24 +47,24 @@ def uid(elid, typ=None):
class DummyCache:
"""This class may be used when you need to omit all cache processing"""
def __init__(self, cache_path, cities):
def __init__(self, cache_path: str, cities: list[City]) -> None:
pass
def __getattr__(self, name):
def __getattr__(self, name: str) -> Callable[..., None]:
"""This results in that a call to any method effectively does nothing
and does not generate exceptions."""
def method(*args, **kwargs):
def method(*args, **kwargs) -> None:
return None
return method
def if_object_is_used(method):
def if_object_is_used(method: Callable) -> Callable:
"""Decorator to skip method execution under certain condition.
Relies on "is_used" object property."""
def inner(self, *args, **kwargs):
def inner(self, *args, **kwargs) -> Any:
if not self.is_used:
return
return method(self, *args, **kwargs)
@ -65,7 +73,7 @@ def if_object_is_used(method):
class MapsmeCache:
def __init__(self, cache_path, cities):
def __init__(self, cache_path: str, cities: list[City]) -> None:
if not cache_path:
# Cache is not used,
# all actions with cache must be silently skipped
@ -90,7 +98,7 @@ class MapsmeCache:
self.city_dict = {c.name: c for c in cities}
self.good_city_names = {c.name for c in cities if c.is_good}
def _is_cached_city_usable(self, city):
def _is_cached_city_usable(self, city: City) -> bool:
"""Check if cached stations still exist in osm data and
not moved far away.
"""
@ -105,8 +113,9 @@ class MapsmeCache:
):
return False
station_coords = el_center(city_station)
cached_station_coords = tuple(
cached_stoparea[coord] for coord in ("lon", "lat")
cached_station_coords = (
cached_stoparea["lon"],
cached_stoparea["lat"],
)
displacement = distance(station_coords, cached_station_coords)
if displacement > DISPLACEMENT_TOLERANCE:
@ -115,7 +124,9 @@ class MapsmeCache:
return True
@if_object_is_used
def provide_stops_and_networks(self, stops, networks):
def provide_stops_and_networks(
self, stops: dict, networks: list[dict]
) -> None:
"""Put stops and networks for bad cities into containers
passed as arguments."""
for city in self.city_dict.values():
@ -128,7 +139,7 @@ class MapsmeCache:
self.recovered_city_names.add(city.name)
@if_object_is_used
def provide_transfers(self, transfers):
def provide_transfers(self, transfers: TransferTimesT) -> None:
"""Add transfers from usable cached cities to 'transfers' dict
passed as argument."""
for city_name in self.recovered_city_names:
@ -138,7 +149,7 @@ class MapsmeCache:
transfers[(stop1_uid, stop2_uid)] = transfer_time
@if_object_is_used
def initialize_good_city(self, city_name, network):
def initialize_good_city(self, city_name: str, network: dict) -> None:
"""Create/replace one cache element with new data container.
This should be done for each good city."""
self.cache[city_name] = {
@ -149,20 +160,22 @@ class MapsmeCache:
}
@if_object_is_used
def link_stop_with_city(self, stoparea_id, city_name):
def link_stop_with_city(self, stoparea_id: IdT, city_name: str) -> None:
"""Remember that some stop_area is used in a city."""
stoparea_uid = uid(stoparea_id)
self.stop_cities[stoparea_uid].add(city_name)
@if_object_is_used
def add_stop(self, stoparea_id, st):
def add_stop(self, stoparea_id: IdT, st: dict) -> None:
"""Add stoparea to the cache of each city the stoparea is in."""
stoparea_uid = uid(stoparea_id)
for city_name in self.stop_cities[stoparea_uid]:
self.cache[city_name]["stops"][stoparea_id] = st
@if_object_is_used
def add_transfer(self, stoparea1_uid, stoparea2_uid, transfer_time):
def add_transfer(
self, stoparea1_uid: int, stoparea2_uid: int, transfer_time: int
) -> None:
"""If a transfer is inside a good city, add it to the city's cache."""
for city_name in (
self.good_city_names
@ -174,7 +187,7 @@ class MapsmeCache:
)
@if_object_is_used
def save(self):
def save(self) -> None:
try:
with open(self.cache_path, "w", encoding="utf-8") as f:
json.dump(self.cache, f, ensure_ascii=False)
@ -191,7 +204,9 @@ def transit_data_to_mapsme(
:param cache_path: Path to json-file with good cities cache or None.
"""
def find_exits_for_platform(center, nodes):
def find_exits_for_platform(
center: LonLat, nodes: list[OsmElementT]
) -> list[OsmElementT]:
exits = []
min_distance = None
for n in nodes:
@ -212,8 +227,8 @@ def transit_data_to_mapsme(
cache = MapsmeCache(cache_path, cities)
stop_areas = {} # stoparea el_id -> StopArea instance
stops = {} # stoparea el_id -> stop jsonified data
stop_areas: dict[IdT, StopArea] = {}
stops: dict[IdT, dict] = {} # stoparea el_id -> stop jsonified data
networks = []
good_cities = [c for c in cities if c.is_good]
platform_nodes = {}
@ -362,9 +377,7 @@ def transit_data_to_mapsme(
stops[stop_id] = st
cache.add_stop(stop_id, st)
pairwise_transfers = (
{}
) # (stoparea1_uid, stoparea2_uid) -> time; uid1 < uid2
pairwise_transfers: TransferTimesT = {}
for stoparea_id_set in transfers:
stoparea_ids = list(stoparea_id_set)
for i_first in range(len(stoparea_ids) - 1):
@ -388,14 +401,14 @@ def transit_data_to_mapsme(
cache.provide_transfers(pairwise_transfers)
cache.save()
pairwise_transfers = [
pairwise_transfers_list = [
(stop1_uid, stop2_uid, transfer_time)
for (stop1_uid, stop2_uid), transfer_time in pairwise_transfers.items()
]
result = {
"stops": list(stops.values()),
"transfers": pairwise_transfers,
"transfers": pairwise_transfers_list,
"networks": networks,
}
return result
@ -406,10 +419,10 @@ def process(
transfers: TransfersT,
filename: str,
cache_path: str | None,
):
) -> None:
"""Generate all output and save to file.
:param cities: List of City instances
:param transfers: List of sets of StopArea.id
:param cities: list of City instances
:param transfers: all collected transfers in the world
:param filename: Path to file to save the result
:param cache_path: Path to json-file with good cities cache or None.
"""

View file

@ -1,15 +1,18 @@
import json
import logging
from collections import OrderedDict
from typing import Any, TextIO
from subway_structure import City, OsmElementT, StopArea
def load_xml(f):
def load_xml(f: TextIO | str) -> list[OsmElementT]:
try:
from lxml import etree
except ImportError:
import xml.etree.ElementTree as etree
elements = []
elements: list[OsmElementT] = []
for event, element in etree.iterparse(f):
if element.tag in ("node", "way", "relation"):
@ -49,7 +52,7 @@ _YAML_SPECIAL_CHARACTERS = "!&*{}[],#|>@`'\""
_YAML_SPECIAL_SEQUENCES = ("- ", ": ", "? ")
def _get_yaml_compatible_string(scalar):
def _get_yaml_compatible_string(scalar: Any) -> str:
"""Enclose string in single quotes in some cases"""
string = str(scalar)
if string and (
@ -62,8 +65,8 @@ def _get_yaml_compatible_string(scalar):
return string
def dump_yaml(city, f):
def write_yaml(data, f, indent=""):
def dump_yaml(city: City, f: TextIO) -> None:
def write_yaml(data: dict, f: TextIO, indent: str = "") -> None:
if isinstance(data, (set, list)):
f.write("\n")
for i in data:
@ -138,10 +141,10 @@ def dump_yaml(city, f):
write_yaml(result, f)
def make_geojson(city, include_tracks_geometry=True):
transfers = set()
def make_geojson(city: City, include_tracks_geometry: bool = True) -> dict:
stopareas_in_transfers: set[StopArea] = set()
for t in city.transfers:
transfers.update(t)
stopareas_in_transfers.update(t)
features = []
stopareas = set()
stops = set()
@ -196,7 +199,7 @@ def make_geojson(city, include_tracks_geometry=True):
"name": stoparea.name,
"marker-size": "small",
"marker-color": "#ff2600"
if stoparea in transfers
if stoparea in stopareas_in_transfers
else "#797979",
},
}
@ -204,7 +207,7 @@ def make_geojson(city, include_tracks_geometry=True):
return {"type": "FeatureCollection", "features": features}
def _dumps_route_id(route_id):
def _dumps_route_id(route_id: tuple[str | None, str | None]) -> str:
"""Argument is a route_id that depends on route colour and ref. Name can
be taken from route_master or can be route's own, we don't take it into
consideration. Some of route attributes can be None. The function makes
@ -212,13 +215,13 @@ def _dumps_route_id(route_id):
return json.dumps(route_id, ensure_ascii=False)
def _loads_route_id(route_id_dump):
def _loads_route_id(route_id_dump: str) -> tuple[str | None, str | None]:
"""Argument is a json-encoded identifier of a route.
Return a tuple (colour, ref)."""
return tuple(json.loads(route_id_dump))
def read_recovery_data(path):
def read_recovery_data(path: str) -> dict:
"""Recovery data is a json with data from previous transport builds.
It helps to recover cities from some errors, e.g. by resorting
shuffled stations in routes."""
@ -246,11 +249,15 @@ def read_recovery_data(path):
return data
def write_recovery_data(path, current_data, cities):
def write_recovery_data(
path: str, current_data: dict, cities: list[City]
) -> None:
"""Updates recovery data with good cities data and writes to file."""
def make_city_recovery_data(city):
routes = {}
def make_city_recovery_data(
city: City,
) -> dict[tuple[str | None, str | None], list[dict]]:
routes: dict[tuple(str | None, str | None), list[dict]] = {}
for route in city:
# Recovery is based primarily on route/station names/refs.
# If route's ref/colour changes, the route won't be used.

View file

@ -3,9 +3,9 @@ from __future__ import annotations
import math
import re
from collections import Counter, defaultdict
from collections.abc import Collection, Iterator
from collections.abc import Callable, Collection, Iterator
from itertools import chain, islice
from typing import TypeVar
from typing import TypeAlias, TypeVar
from css_colours import normalize_colour
@ -47,13 +47,18 @@ used_entrances = set()
START_END_TIMES_RE = re.compile(r".*?(\d{2}):(\d{2})-(\d{2}):(\d{2}).*")
IdT = str # Type of feature ids
TransferT = set[IdT] # A transfer is a set of StopArea IDs
TransfersT = Collection[TransferT]
OsmElementT: TypeAlias = dict
IdT: TypeAlias = str # Type of feature ids
TransferT: TypeAlias = set[IdT] # A transfer is a set of StopArea IDs
TransfersT: TypeAlias = list[TransferT]
LonLat: TypeAlias = tuple[float, float]
RailT: TypeAlias = list[LonLat]
T = TypeVar("T")
def get_start_end_times(opening_hours):
def get_start_end_times(
opening_hours: str,
) -> tuple[tuple[int, int], tuple[int, int]] | tuple[None, None]:
"""Very simplified method to parse OSM opening_hours tag.
We simply take the first HH:MM-HH:MM substring which is the most probable
opening hours interval for the most of the weekdays.
@ -67,7 +72,7 @@ def get_start_end_times(opening_hours):
return start_time, end_time
def osm_interval_to_seconds(interval_str):
def osm_interval_to_seconds(interval_str: str) -> int | None:
"""Convert to int an OSM value for 'interval'/'headway' tag
which may be in these formats:
HH:MM:SS,
@ -97,7 +102,7 @@ class CriticalValidationError(Exception):
that prevents further validation of a city."""
def el_id(el):
def el_id(el: OsmElementT) -> IdT | None:
if not el:
return None
if "type" not in el:
@ -105,7 +110,7 @@ def el_id(el):
return el["type"][0] + str(el.get("id", el.get("ref", "")))
def el_center(el):
def el_center(el: OsmElementT) -> LonLat | None:
if not el:
return None
if "lat" in el:
@ -115,7 +120,7 @@ def el_center(el):
return None
def distance(p1, p2):
def distance(p1: LonLat, p2: LonLat) -> float:
if p1 is None or p2 is None:
raise Exception(
"One of arguments to distance({}, {}) is None".format(p1, p2)
@ -127,14 +132,14 @@ def distance(p1, p2):
return 6378137 * math.sqrt(dx * dx + dy * dy)
def is_near(p1, p2):
def is_near(p1: LonLat, p2: LonLat) -> bool:
return (
p1[0] - 1e-8 <= p2[0] <= p1[0] + 1e-8
and p1[1] - 1e-8 <= p2[1] <= p1[1] + 1e-8
)
def project_on_segment(p, p1, p2):
def project_on_segment(p: LonLat, p1: LonLat, p2: LonLat) -> float | None:
"""Given three points, return u - the position of projection of
point p onto segment p1p2 regarding point p1 and (p2-p1) direction vector
"""
@ -148,7 +153,7 @@ def project_on_segment(p, p1, p2):
return u
def project_on_line(p, line):
def project_on_line(p: LonLat, line: RailT) -> dict:
result = {
# In the first approximation, position on rails is the index of the
# closest vertex of line to the point p. Fractional value means that
@ -212,7 +217,9 @@ def project_on_line(p, line):
return result
def find_segment(p, line, start_vertex=0):
def find_segment(
p: LonLat, line: RailT, start_vertex: int = 0
) -> tuple[int, float] | tuple[None, None]:
"""Returns index of a segment and a position inside it."""
EPS = 1e-9
for seg in range(start_vertex, len(line) - 1):
@ -237,7 +244,9 @@ def find_segment(p, line, start_vertex=0):
return None, None
def distance_on_line(p1, p2, line, start_vertex=0):
def distance_on_line(
p1: LonLat, p2: LonLat, line: RailT, start_vertex: int = 0
) -> tuple[float, int] | None:
"""Calculates distance via line between projections
of points p1 and p2. Returns a TUPLE of (d, vertex):
d is the distance and vertex is the number of the second
@ -270,7 +279,7 @@ def distance_on_line(p1, p2, line, start_vertex=0):
return d, seg2 % line_len
def angle_between(p1, c, p2):
def angle_between(p1: LonLat, c: LonLat, p2: LonLat) -> float:
a = round(
abs(
math.degrees(
@ -282,7 +291,7 @@ def angle_between(p1, c, p2):
return a if a <= 180 else 360 - a
def format_elid_list(ids):
def format_elid_list(ids: Collection[IdT]) -> str:
msg = ", ".join(sorted(ids)[:20])
if len(ids) > 20:
msg += ", ..."
@ -291,14 +300,14 @@ def format_elid_list(ids):
class Station:
@staticmethod
def get_modes(el: dict) -> set[str]:
def get_modes(el: OsmElementT) -> set[str]:
modes = {m for m in ALL_MODES if el["tags"].get(m) == "yes"}
if mode := el["tags"].get("station"):
modes.add(mode)
return modes
@staticmethod
def is_station(el, modes):
def is_station(el: OsmElementT, modes: set[str]) -> bool:
# public_transport=station is too ambiguous and unspecific to use,
# so we expect for it to be backed by railway=station.
if (
@ -316,7 +325,7 @@ class Station:
return False
return True
def __init__(self, el, city):
def __init__(self, el: OsmElementT, city: City) -> None:
"""Call this with a railway=station node."""
if not Station.is_station(el, city.modes):
raise Exception(
@ -324,8 +333,8 @@ class Station:
"Got: {}".format(el)
)
self.id = el_id(el)
self.element = el
self.id: IdT = el_id(el)
self.element: OsmElementT = el
self.modes = Station.get_modes(el)
self.name = el["tags"].get("name", "?")
self.int_name = el["tags"].get(
@ -340,7 +349,7 @@ class Station:
if self.center is None:
raise Exception("Could not find center of {}".format(el))
def __repr__(self):
def __repr__(self) -> str:
return "Station(id={}, modes={}, name={}, center={})".format(
self.id, ",".join(self.modes), self.name, self.center
)
@ -348,7 +357,7 @@ class Station:
class StopArea:
@staticmethod
def is_stop(el):
def is_stop(el: OsmElementT) -> bool:
if "tags" not in el:
return False
if el["tags"].get("railway") == "stop":
@ -358,7 +367,7 @@ class StopArea:
return False
@staticmethod
def is_platform(el):
def is_platform(el: OsmElementT) -> bool:
if "tags" not in el:
return False
if el["tags"].get("railway") in ("platform", "platform_edge"):
@ -368,19 +377,22 @@ class StopArea:
return False
@staticmethod
def is_track(el):
def is_track(el: OsmElementT) -> bool:
if el["type"] != "way" or "tags" not in el:
return False
return el["tags"].get("railway") in RAILWAY_TYPES
def __init__(
self, station: Station, city: City, stop_area: StopArea | None = None
self,
station: Station,
city: City,
stop_area: OsmElementT | None = None,
) -> None:
"""Call this with a Station object."""
self.element = stop_area or station.element
self.id = el_id(self.element)
self.station = station
self.element: OsmElementT = stop_area or station.element
self.id: IdT = el_id(self.element)
self.station: Station = station
self.stops = set() # set of el_ids of stop_positions
self.platforms = set() # set of el_ids of platforms
self.exits = set() # el_id of subway_entrance/train_station_entrance
@ -440,7 +452,7 @@ class StopArea:
self.center[i] /= len(self.stops) + len(self.platforms)
def _process_members(
self, station: Station, city: City, stop_area: dict
self, station: Station, city: City, stop_area: OsmElementT
) -> None:
# If we have a stop area, add all elements from it
tracks_detected = False
@ -503,7 +515,7 @@ class StopArea:
if etag != "entrance":
self.exits.add(entrance_id)
def get_elements(self):
def get_elements(self) -> set[IdT]:
result = {self.id, self.station.id}
result.update(self.entrances)
result.update(self.exits)
@ -511,7 +523,7 @@ class StopArea:
result.update(self.platforms)
return result
def __repr__(self):
def __repr__(self) -> str:
return (
f"StopArea(id={self.id}, name={self.name}, station={self.station},"
f" transfer={self.transfer}, center={self.center})"
@ -519,9 +531,9 @@ class StopArea:
class RouteStop:
def __init__(self, stoparea):
self.stoparea = stoparea
self.stop = None # Stop position (lon, lat), possibly projected
def __init__(self, stoparea: StopArea) -> None:
self.stoparea: StopArea = stoparea
self.stop: LonLat = None # Stop position, possibly projected
self.distance = 0 # In meters from the start of the route
self.platform_entry = None # Platform el_id
self.platform_exit = None # Platform el_id
@ -533,11 +545,13 @@ class RouteStop:
self.seen_station = False
@property
def seen_platform(self):
def seen_platform(self) -> bool:
return self.seen_platform_entry or self.seen_platform_exit
@staticmethod
def get_actual_role(el, role, modes):
def get_actual_role(
el: OsmElementT, role: str, modes: set[str]
) -> str | None:
if StopArea.is_stop(el):
return "stop"
elif StopArea.is_platform(el):
@ -549,7 +563,7 @@ class RouteStop:
return "stop"
return None
def add(self, member, relation, city):
def add(self, member: dict, relation: OsmElementT, city: City) -> None:
el = city.elements[el_id(member)]
role = member["role"]
@ -616,7 +630,7 @@ class RouteStop:
relation,
)
def __repr__(self):
def __repr__(self) -> str:
return (
"RouteStop(stop={}, pl_entry={}, pl_exit={}, stoparea={})".format(
self.stop,
@ -628,10 +642,10 @@ class RouteStop:
class Route:
"""Corresponds to OSM "type=route" relation"""
"""The longest route for a city with a unique ref."""
@staticmethod
def is_route(el, modes):
def is_route(el: OsmElementT, modes: set[str]) -> bool:
if (
el["type"] != "relation"
or el.get("tags", {}).get("type") != "route"
@ -649,14 +663,14 @@ class Route:
return True
@staticmethod
def get_network(relation):
def get_network(relation: OsmElementT) -> str | None:
for k in ("network:metro", "network", "operator"):
if k in relation["tags"]:
return relation["tags"][k]
return None
@staticmethod
def get_interval(tags):
def get_interval(tags: dict) -> int | None:
v = None
for k in ("interval", "headway"):
if k in tags:
@ -681,16 +695,16 @@ class Route:
def __init__(
self,
relation: dict,
relation: OsmElementT,
city: City,
master: dict | None = None,
master: OsmElementT | None = None,
) -> None:
assert Route.is_route(
relation, city.modes
), f"The relation does not seem to be a route: {relation}"
self.city = city
self.element = relation
self.id = el_id(relation)
self.element: OsmElementT = relation
self.id: IdT = el_id(relation)
self.ref = None
self.name = None
@ -702,7 +716,7 @@ class Route:
self.start_time = None
self.end_time = None
self.is_circular = False
self.stops = [] # List of RouteStop
self.stops: list[RouteStop] = []
# Would be a list of (lon, lat) for the longest stretch. Can be empty.
self.tracks = None
# Index of the first stop that is located on/near the self.tracks
@ -714,10 +728,10 @@ class Route:
stop_position_elements = self.process_stop_members()
self.process_tracks(stop_position_elements)
def build_longest_line(self):
line_nodes = set()
last_track = []
track = []
def build_longest_line(self) -> tuple[list[IdT], set[IdT]]:
line_nodes: set[IdT] = set()
last_track: list[IdT] = []
track: list[IdT] = []
warned_about_holes = False
for m in self.element["members"]:
el = self.city.elements.get(el_id(m), None)
@ -726,7 +740,7 @@ class Route:
if "nodes" not in el or len(el["nodes"]) < 2:
self.city.error("Cannot find nodes in a railway", el)
continue
nodes = ["n{}".format(n) for n in el["nodes"]]
nodes: list[IdT] = ["n{}".format(n) for n in el["nodes"]]
if m["role"] == "backward":
nodes.reverse()
line_nodes.update(nodes)
@ -773,10 +787,10 @@ class Route:
]
return last_track, line_nodes
def get_stop_projections(self):
def get_stop_projections(self) -> tuple[list[dict], Callable[[int], bool]]:
projected = [project_on_line(x.stop, self.tracks) for x in self.stops]
def stop_near_tracks_criterion(stop_index: int):
def stop_near_tracks_criterion(stop_index: int) -> bool:
return (
projected[stop_index]["projected_point"] is not None
and distance(
@ -788,14 +802,14 @@ class Route:
return projected, stop_near_tracks_criterion
def project_stops_on_line(self):
def project_stops_on_line(self) -> dict:
projected, stop_near_tracks_criterion = self.get_stop_projections()
projected_stops_data = {
"first_stop_on_rails_index": None,
"last_stop_on_rails_index": None,
"stops_on_longest_line": [], # list [{'route_stop': RouteStop,
# 'coords': (lon, lat),
# 'coords': LonLat,
# 'positions_on_rails': [] }
}
first_index = 0
@ -848,7 +862,7 @@ class Route:
projected_stops_data["stops_on_longest_line"].append(stop_data)
return projected_stops_data
def calculate_distances(self):
def calculate_distances(self) -> None:
dist = 0
vertex = 0
for i, stop in enumerate(self.stops):
@ -870,7 +884,7 @@ class Route:
dist += round(direct)
stop.distance = dist
def process_tags(self, master):
def process_tags(self, master: OsmElementT) -> None:
relation = self.element
master_tags = {} if not master else master["tags"]
if "ref" not in relation["tags"] and "ref" not in master_tags:
@ -918,12 +932,12 @@ class Route:
relation,
)
def process_stop_members(self):
stations = set() # temporary for recording stations
def process_stop_members(self) -> list[OsmElementT]:
stations: set[StopArea] = set() # temporary for recording stations
seen_stops = False
seen_platforms = False
repeat_pos = None
stop_position_elements = []
stop_position_elements: list[OsmElementT] = []
for m in self.element["members"]:
if "inactive" in m["role"]:
continue
@ -1072,7 +1086,9 @@ class Route:
)
return stop_position_elements
def process_tracks(self, stop_position_elements: list[dict]) -> None:
def process_tracks(
self, stop_position_elements: list[OsmElementT]
) -> None:
tracks, line_nodes = self.build_longest_line()
for stop_el in stop_position_elements:
@ -1130,7 +1146,7 @@ class Route:
if stop_coords := stop_data["coords"]:
route_stop.stop = stop_coords
def get_extended_tracks(self):
def get_extended_tracks(self) -> RailT:
"""Amend tracks with points of leading/trailing self.stops
that were not projected onto the longest tracks line.
Return a new array.
@ -1153,7 +1169,7 @@ class Route:
)
return tracks
def get_truncated_tracks(self, tracks):
def get_truncated_tracks(self, tracks: RailT) -> RailT:
"""Truncate leading/trailing segments of `tracks` param
that are beyond the first and last stop locations.
Return a new array.
@ -1194,12 +1210,12 @@ class Route:
and self.last_stop_on_rails_index == len(self) - 1
)
def get_tracks_geometry(self):
def get_tracks_geometry(self) -> RailT:
tracks = self.get_extended_tracks()
tracks = self.get_truncated_tracks(tracks)
return tracks
def check_stops_order_by_angle(self) -> tuple[list, list]:
def check_stops_order_by_angle(self) -> tuple[list[str], list[str]]:
disorder_warnings = []
disorder_errors = []
for i, route_stop in enumerate(
@ -1222,7 +1238,9 @@ class Route:
disorder_warnings.append(msg)
return disorder_warnings, disorder_errors
def check_stops_order_on_tracks_direct(self, stop_sequence) -> str | None:
def check_stops_order_on_tracks_direct(
self, stop_sequence: Iterator[dict]
) -> str | None:
"""Checks stops order on tracks, following stop_sequence
in direct order only.
:param stop_sequence: list of dict{'route_stop', 'positions_on_rails',
@ -1253,7 +1271,9 @@ class Route:
)
max_position_on_rails = positions_on_rails[suitable_occurrence]
def check_stops_order_on_tracks(self, projected_stops_data) -> str | None:
def check_stops_order_on_tracks(
self, projected_stops_data: dict
) -> str | None:
"""Checks stops order on tracks, trying direct and reversed
order of stops in the stop_sequence.
:param projected_stops_data: info about RouteStops that belong to the
@ -1280,7 +1300,9 @@ class Route:
return error_message
def check_stops_order(self, projected_stops_data):
def check_stops_order(
self, projected_stops_data: dict
) -> tuple[list[str], list[str]]:
(
angle_disorder_warnings,
angle_disorder_errors,
@ -1294,7 +1316,9 @@ class Route:
disorder_errors.append(disorder_on_tracks_error)
return disorder_warnings, disorder_errors
def check_and_recover_stops_order(self, projected_stops_data: dict):
def check_and_recover_stops_order(
self, projected_stops_data: dict
) -> None:
"""
:param projected_stops_data: may change if we need to reverse tracks
"""
@ -1319,7 +1343,7 @@ class Route:
for msg in disorder_errors:
self.city.error(msg, self.element)
def try_resort_stops(self):
def try_resort_stops(self) -> bool:
"""Precondition: self.city.recovery_data is not None.
Return success of station order recovering."""
self_stops = {} # station name => RouteStop
@ -1388,7 +1412,7 @@ class Route:
]
return True
def get_end_transfers(self) -> tuple[str, str]:
def get_end_transfers(self) -> tuple[IdT, IdT]:
"""Using transfer ids because a train can arrive at different
stations within a transfer. But disregard transfer that may give
an impression of a circular route (for example,
@ -1406,7 +1430,7 @@ class Route:
)
)
def get_transfers_sequence(self) -> list[str]:
def get_transfers_sequence(self) -> list[IdT]:
"""Return a list of stoparea or transfer (if not None) ids."""
transfer_seq = [
stop.stoparea.transfer or stop.stoparea.id for stop in self
@ -1418,16 +1442,16 @@ class Route:
transfer_seq[0], transfer_seq[-1] = self.get_end_transfers()
return transfer_seq
def __len__(self):
def __len__(self) -> int:
return len(self.stops)
def __getitem__(self, i):
def __getitem__(self, i) -> RouteStop:
return self.stops[i]
def __iter__(self):
def __iter__(self) -> Iterator[RouteStop]:
return iter(self.stops)
def __repr__(self):
def __repr__(self) -> str:
return (
"Route(id={}, mode={}, ref={}, name={}, network={}, interval={}, "
"circular={}, num_stops={}, line_length={} m, from={}, to={}"
@ -1447,11 +1471,11 @@ class Route:
class RouteMaster:
def __init__(self, city: City, master: dict = None) -> None:
def __init__(self, city: City, master: OsmElementT = None) -> None:
self.city = city
self.routes = []
self.best = None
self.id = el_id(master)
self.best: Route = None
self.id: IdT = el_id(master)
self.has_master = master is not None
self.interval_from_master = False
if master:
@ -1871,16 +1895,16 @@ class RouteMaster:
stops_that_dont_match,
)
def __len__(self):
def __len__(self) -> int:
return len(self.routes)
def __getitem__(self, i):
def __getitem__(self, i) -> Route:
return self.routes[i]
def __iter__(self):
def __iter__(self) -> Iterator[Route]:
return iter(self.routes)
def __repr__(self):
def __repr__(self) -> str:
return (
f"RouteMaster(id={self.id}, mode={self.mode}, ref={self.ref}, "
f"name={self.name}, network={self.network}, "
@ -1891,11 +1915,11 @@ class RouteMaster:
class City:
route_class = Route
def __init__(self, city_data, overground=False):
def __init__(self, city_data: dict, overground: bool = False) -> None:
self.validate_called = False
self.errors = []
self.warnings = []
self.notices = []
self.errors: list[str] = []
self.warnings: list[str] = []
self.notices: list[str] = []
self.id = None
self.try_fill_int_attribute(city_data, "id")
self.name = city_data["name"]
@ -1940,16 +1964,14 @@ class City:
else:
self.bbox = None
self.elements = {} # Dict el_id → el
self.stations = defaultdict(list) # Dict el_id → list of StopAreas
self.routes = {} # Dict route_master_ref → RouteMaster
self.masters = {} # Dict el_id of route → route_master
self.stop_areas = defaultdict(
list
) # El_id → list of stop_area elements it belongs to
self.transfers: TransfersT = [] # List of sets of stop areas
self.station_ids = set() # Set of stations' uid
self.stops_and_platforms = set() # Set of stops and platforms el_id
self.elements: dict[IdT, OsmElementT] = {}
self.stations: dict[IdT, list[StopArea]] = defaultdict(list)
self.routes: dict[str, RouteMaster] = {} # keys are route_master refs
self.masters: dict[IdT, OsmElementT] = {} # Route id → master element
self.stop_areas: [IdT, list[OsmElementT]] = defaultdict(list)
self.transfers: list[set[StopArea]] = []
self.station_ids: set[IdT] = set()
self.stops_and_platforms: set[IdT] = set()
self.recovery_data = None
def try_fill_int_attribute(
@ -1980,7 +2002,7 @@ class City:
setattr(self, attr, attr_int)
@staticmethod
def log_message(message, el):
def log_message(message: str, el: OsmElementT) -> str:
if el:
tags = el.get("tags", {})
message += ' ({} {}, "{}")'.format(
@ -1990,24 +2012,24 @@ class City:
)
return message
def notice(self, message, el=None):
def notice(self, message: str, el: OsmElementT | None = None) -> None:
"""This type of message may point to a potential problem."""
msg = City.log_message(message, el)
self.notices.append(msg)
def warn(self, message, el=None):
def warn(self, message: str, el: OsmElementT | None = None) -> None:
"""A warning is definitely a problem but is doesn't prevent
from building a routing file and doesn't invalidate the city.
"""
msg = City.log_message(message, el)
self.warnings.append(msg)
def error(self, message, el=None):
def error(self, message: str, el: OsmElementT | None = None) -> None:
"""Error is a critical problem that invalidates the city."""
msg = City.log_message(message, el)
self.errors.append(msg)
def contains(self, el):
def contains(self, el: OsmElementT) -> bool:
center = el_center(el)
if center:
return (
@ -2016,7 +2038,7 @@ class City:
)
return False
def add(self, el):
def add(self, el: OsmElementT) -> None:
if el["type"] == "relation" and "members" not in el:
return
@ -2052,8 +2074,8 @@ class City:
else:
stop_areas.append(el)
def make_transfer(self, stoparea_group: dict) -> None:
transfer = set()
def make_transfer(self, stoparea_group: OsmElementT) -> None:
transfer: set[StopArea] = set()
for m in stoparea_group["members"]:
k = el_id(m)
el = self.elements.get(k)
@ -2195,7 +2217,7 @@ class City:
if len(inner_transfer) > 1
]
def __iter__(self):
def __iter__(self) -> Iterator[RouteMaster]:
return iter(self.routes.values())
def stopareas(self) -> Iterator[StopArea]:
@ -2207,7 +2229,7 @@ class City:
yielded_stopareas.add(stoparea)
@property
def is_good(self):
def is_good(self) -> bool:
if not (self.errors or self.validate_called):
raise RuntimeError(
"You mustn't refer to City.is_good property before calling "
@ -2215,7 +2237,7 @@ class City:
)
return len(self.errors) == 0
def get_validation_result(self):
def get_validation_result(self) -> dict:
result = {
"name": self.name,
"country": self.country,
@ -2260,7 +2282,7 @@ class City:
result["notices"] = self.notices
return result
def count_unused_entrances(self):
def count_unused_entrances(self) -> None:
global used_entrances
stop_areas = set()
for el in self.elements.values():
@ -2299,7 +2321,7 @@ class City:
f"relations: {format_elid_list(not_in_sa)}"
)
def validate_lines(self):
def validate_lines(self) -> None:
self.found_light_lines = len(
[x for x in self.routes.values() if x.mode != "subway"]
)
@ -2317,7 +2339,7 @@ class City:
)
)
def validate_overground_lines(self):
def validate_overground_lines(self) -> None:
self.found_tram_lines = len(
[x for x in self.routes.values() if x.mode == "tram"]
)
@ -2344,7 +2366,7 @@ class City:
),
)
def validate(self):
def validate(self) -> None:
networks = Counter()
self.found_stations = 0
unused_stations = set(self.station_ids)
@ -2421,7 +2443,7 @@ class City:
def find_transfers(
elements: list[dict], cities: Collection[City]
elements: list[OsmElementT], cities: Collection[City]
) -> TransfersT:
"""As for now, two Cities may contain the same stoparea, but those
StopArea instances would have different python id. So we don't store
@ -2457,7 +2479,7 @@ def find_transfers(
return transfers
def get_unused_subway_entrances_geojson(elements: list[dict]) -> dict:
def get_unused_subway_entrances_geojson(elements: list[OsmElementT]) -> dict:
global used_entrances
features = []
for el in elements:

View file

@ -7,7 +7,7 @@ import json
import os
import re
from collections import defaultdict
from typing import Any, Optional
from typing import Any
from process_subways import DEFAULT_SPREADSHEET_ID
from v2h_templates import (
@ -22,8 +22,7 @@ from v2h_templates import (
class CityData:
def __init__(self, city: Optional[str] = None) -> None:
self.city = city is not None
def __init__(self, city: dict | None = None) -> None:
self.data = {
"good_cities": 0,
"total_cities": 1 if city else 0,
@ -93,7 +92,7 @@ class CityData:
return s
def tmpl(s: str, data: Optional[CityData] = None, **kwargs) -> str:
def tmpl(s: str, data: CityData | None = None, **kwargs) -> str:
if data:
s = data.format(s)
if kwargs: