Use csv.DictReader instead of csv.reader to load city data

This commit is contained in:
Alexey Zakharenkov 2022-07-07 16:50:00 +03:00 committed by Alexey Zakharenkov
parent a684370eb6
commit 0c8821b850
2 changed files with 58 additions and 38 deletions

View file

@ -1476,30 +1476,32 @@ class RouteMaster:
class City:
def __init__(self, row, overground=False):
def __init__(self, city_data, overground=False):
self.errors = []
self.warnings = []
self.notices = []
self.name = row[1]
self.country = row[2]
self.continent = row[3]
if not row[0]:
self.error('City {} does not have an id'.format(self.name))
self.id = int(row[0] or '0')
self.id = int(city_data["id"])
self.name = city_data["name"]
self.country = city_data["country"]
self.continent = city_data["continent"]
self.overground = overground
if not overground:
self.num_stations = int(row[4])
self.num_lines = int(row[5] or '0')
self.num_light_lines = int(row[6] or '0')
self.num_interchanges = int(row[7] or '0')
self.num_stations = int(city_data["num_stations"])
self.num_lines = int(city_data["num_lines"] or '0')
self.num_light_lines = int(city_data["num_light_lines"] or '0')
self.num_interchanges = int(city_data["num_interchanges"] or '0')
else:
self.num_tram_lines = int(row[4] or '0')
self.num_trolleybus_lines = int(row[5] or '0')
self.num_bus_lines = int(row[6] or '0')
self.num_other_lines = int(row[7] or '0')
self.num_tram_lines = int(city_data["num_tram_lines"] or '0')
self.num_trolleybus_lines = int(city_data["num_trolleybus_lines"] or '0')
self.num_bus_lines = int(city_data["num_bus_lines"] or '0')
self.num_other_lines = int(city_data["num_other_lines"] or '0')
# Aquiring list of networks and modes
networks = None if len(row) <= 9 else row[9].split(':')
networks = (
None
if not city_data["networks"]
else city_data["networks"].split(':')
)
if not networks or len(networks[-1]) == 0:
self.networks = []
else:
@ -1515,7 +1517,7 @@ class City:
self.modes = set([x.strip() for x in networks[0].split(',')])
# Reversing bbox so it is (xmin, ymin, xmax, ymax)
bbox = row[8].split(',')
bbox = city_data["bbox"].split(',')
if len(bbox) == 4:
self.bbox = [float(bbox[i]) for i in (1, 0, 3, 2)]
else:
@ -2033,6 +2035,7 @@ def get_unused_entrances_geojson(elements):
def download_cities(overground=False):
assert not overground, "Overground transit not implemented yet"
url = (
'https://docs.google.com/spreadsheets/d/{}/export?format=csv{}'.format(
SPREADSHEET_ID, '&gid=1881416409' if overground else ''
@ -2046,16 +2049,33 @@ def download_cities(overground=False):
)
)
data = response.read().decode('utf-8')
r = csv.reader(data.splitlines())
next(r) # skipping the header
reader = csv.DictReader(
data.splitlines(),
fieldnames=(
"id",
"name",
"country",
"continent",
"num_stations",
"num_lines",
"num_light_lines",
"num_interchanges",
"bbox",
"networks",
),
)
next(reader) # skipping the header
names = set()
cities = []
for row in r:
if len(row) > 8 and row[8]:
cities.append(City(row, overground))
if row[0].strip() in names:
for city_data in reader:
if city_data["id"] and city_data["bbox"]:
cities.append(City(city_data, overground))
name = city_data["name"].strip()
if name in names:
logging.warning(
'Duplicate city name in the google spreadsheet: %s', row[0]
'Duplicate city name in the google spreadsheet: %s',
city_data
)
names.add(row[0].strip())
names.add(name)
return cities

View file

@ -20,22 +20,22 @@ from tests.sample_data import sample_networks
class TestOneRouteTracks(unittest.TestCase):
"""Test tracks extending and truncating on one-route networks"""
STATION_COUNT_INDEX = 4
CITY_TEMPLATE = [
1, # city id
"Null Island", # name
"World", # Country
"Africa", # continent
None, # station count. Would be taken from the sample network data under testing
1, # subway line count
0, # light rail line count
0, # interchanges
"-179, -89, 179, 89", # bbox
]
CITY_TEMPLATE = {
"id": 1,
"name": "Null Island",
"country": "World",
"continent": "Africa",
"num_stations": None, # Would be taken from the sample network data under testing
"num_lines": 1,
"num_light_lines": 0,
"num_interchanges": 0,
"bbox": "-179, -89, 179, 89",
"networks": "",
}
def prepare_city_routes(self, network):
city_data = self.CITY_TEMPLATE.copy()
city_data[self.STATION_COUNT_INDEX] = network["station_count"]
city_data["num_stations"] = network["station_count"]
city = City(city_data)
elements = load_xml(io.BytesIO(network["xml"].encode("utf-8")))
for el in elements: