diff --git a/subway_structure.py b/subway_structure.py index ece0b5c..6edd230 100644 --- a/subway_structure.py +++ b/subway_structure.py @@ -1476,30 +1476,32 @@ class RouteMaster: class City: - def __init__(self, row, overground=False): + def __init__(self, city_data, overground=False): self.errors = [] self.warnings = [] self.notices = [] - self.name = row[1] - self.country = row[2] - self.continent = row[3] - if not row[0]: - self.error('City {} does not have an id'.format(self.name)) - self.id = int(row[0] or '0') + self.id = int(city_data["id"]) + self.name = city_data["name"] + self.country = city_data["country"] + self.continent = city_data["continent"] self.overground = overground if not overground: - self.num_stations = int(row[4]) - self.num_lines = int(row[5] or '0') - self.num_light_lines = int(row[6] or '0') - self.num_interchanges = int(row[7] or '0') + self.num_stations = int(city_data["num_stations"]) + self.num_lines = int(city_data["num_lines"] or '0') + self.num_light_lines = int(city_data["num_light_lines"] or '0') + self.num_interchanges = int(city_data["num_interchanges"] or '0') else: - self.num_tram_lines = int(row[4] or '0') - self.num_trolleybus_lines = int(row[5] or '0') - self.num_bus_lines = int(row[6] or '0') - self.num_other_lines = int(row[7] or '0') + self.num_tram_lines = int(city_data["num_tram_lines"] or '0') + self.num_trolleybus_lines = int(city_data["num_trolleybus_lines"] or '0') + self.num_bus_lines = int(city_data["num_bus_lines"] or '0') + self.num_other_lines = int(city_data["num_other_lines"] or '0') # Aquiring list of networks and modes - networks = None if len(row) <= 9 else row[9].split(':') + networks = ( + None + if not city_data["networks"] + else city_data["networks"].split(':') + ) if not networks or len(networks[-1]) == 0: self.networks = [] else: @@ -1515,7 +1517,7 @@ class City: self.modes = set([x.strip() for x in networks[0].split(',')]) # Reversing bbox so it is (xmin, ymin, xmax, ymax) - bbox = row[8].split(',') + bbox = city_data["bbox"].split(',') if len(bbox) == 4: self.bbox = [float(bbox[i]) for i in (1, 0, 3, 2)] else: @@ -2033,6 +2035,7 @@ def get_unused_entrances_geojson(elements): def download_cities(overground=False): + assert not overground, "Overground transit not implemented yet" url = ( 'https://docs.google.com/spreadsheets/d/{}/export?format=csv{}'.format( SPREADSHEET_ID, '&gid=1881416409' if overground else '' @@ -2046,16 +2049,33 @@ def download_cities(overground=False): ) ) data = response.read().decode('utf-8') - r = csv.reader(data.splitlines()) - next(r) # skipping the header + reader = csv.DictReader( + data.splitlines(), + fieldnames=( + "id", + "name", + "country", + "continent", + "num_stations", + "num_lines", + "num_light_lines", + "num_interchanges", + "bbox", + "networks", + ), + ) + + next(reader) # skipping the header names = set() cities = [] - for row in r: - if len(row) > 8 and row[8]: - cities.append(City(row, overground)) - if row[0].strip() in names: + for city_data in reader: + if city_data["id"] and city_data["bbox"]: + cities.append(City(city_data, overground)) + name = city_data["name"].strip() + if name in names: logging.warning( - 'Duplicate city name in the google spreadsheet: %s', row[0] + 'Duplicate city name in the google spreadsheet: %s', + city_data ) - names.add(row[0].strip()) + names.add(name) return cities diff --git a/tests/test_build_tracks.py b/tests/test_build_tracks.py index 997a404..782555f 100644 --- a/tests/test_build_tracks.py +++ b/tests/test_build_tracks.py @@ -20,22 +20,22 @@ from tests.sample_data import sample_networks class TestOneRouteTracks(unittest.TestCase): """Test tracks extending and truncating on one-route networks""" - STATION_COUNT_INDEX = 4 - CITY_TEMPLATE = [ - 1, # city id - "Null Island", # name - "World", # Country - "Africa", # continent - None, # station count. Would be taken from the sample network data under testing - 1, # subway line count - 0, # light rail line count - 0, # interchanges - "-179, -89, 179, 89", # bbox - ] + CITY_TEMPLATE = { + "id": 1, + "name": "Null Island", + "country": "World", + "continent": "Africa", + "num_stations": None, # Would be taken from the sample network data under testing + "num_lines": 1, + "num_light_lines": 0, + "num_interchanges": 0, + "bbox": "-179, -89, 179, 89", + "networks": "", + } def prepare_city_routes(self, network): city_data = self.CITY_TEMPLATE.copy() - city_data[self.STATION_COUNT_INDEX] = network["station_count"] + city_data["num_stations"] = network["station_count"] city = City(city_data) elements = load_xml(io.BytesIO(network["xml"].encode("utf-8"))) for el in elements: