osm_conflate/conflate/osm.py
2019-06-04 11:15:21 +03:00

391 lines
16 KiB
Python

import logging
import requests
import re
from .data import OSMPoint
from . import etree
OVERPASS_SERVER = 'https://overpass-api.de/api/'
ALT_OVERPASS_SERVER = 'https://overpass.kumi.systems/api/'
OSM_API_SERVER = 'https://api.openstreetmap.org/api/0.6/'
BBOX_PADDING = 0.003 # in degrees, ~330 m default
class OsmDownloader:
def __init__(self, profile):
self.profile = profile
def set_overpass(self, server='alt'):
global OVERPASS_SERVER
if server == 'alt':
OVERPASS_SERVER = ALT_OVERPASS_SERVER
else:
OVERPASS_SERVER = server
def construct_overpass_query(self, bboxes):
"""Constructs an Overpass API query from the "query" list in the profile.
(k, v) turns into [k=v], (k,) into [k], (k, None) into [!k], (k, "~v") into [k~v]."""
tags = self.profile.get(
'query', required="a list of tuples. E.g. [('amenity', 'cafe'), ('name', '~Mc.*lds')]")
tag_strs = []
if isinstance(tags, str):
tag_strs = [tags]
else:
if not isinstance(tags[0], str) and isinstance(tags[0][0], str):
tags = [tags]
for tags_q in tags:
if isinstance(tags_q, str):
tag_strs.append(tags_q)
continue
tag_str = ''
for t in tags_q:
if len(t) == 1:
q = '"{}"'.format(t[0])
elif t[1] is None or len(t[1]) == 0:
q = '"!{}"'.format(t[0])
elif t[1][0] == '~':
q = '"{}"~"{}",i'.format(t[0], t[1][1:])
elif len(t) > 2:
q = '"{}"~"^({})$"'.format(t[0], '|'.join(t[1:]))
else:
q = '"{}"="{}"'.format(t[0], t[1])
tag_str += '[' + q + ']'
tag_strs.append(tag_str)
if self.profile.get('no_dataset_id', False):
ref = None
else:
ref = 'nwr["ref:' + self.profile.get(
'dataset_id', required='A fairly unique id of the dataset to query OSM') + '"]'
timeout = self.profile.get('overpass_timeout', 120)
query = '[out:xml]{};('.format('' if timeout is None else '[timeout:{}]'.format(timeout))
for bbox in bboxes:
bbox_str = '' if bbox is None else '(' + ','.join([str(x) for x in bbox]) + ')'
for tag_str in tag_strs:
query += 'nwr' + tag_str + bbox_str + ';'
if ref is not None:
if not self.profile.get('bounded_update', False):
query += ref + ';'
else:
for bbox in bboxes:
bbox_str = '' if bbox is None else '(' + ','.join(
[str(x) for x in bbox]) + ')'
query += ref + bbox_str + ';'
query += '); out meta qt center;'
return query
def get_bbox(self, points):
"""Plain iterates over the dataset and returns the bounding box
that encloses it."""
padding = self.profile.get('bbox_padding', BBOX_PADDING)
bbox = [90.0, 180.0, -90.0, -180.0]
for p in points:
bbox[0] = min(bbox[0], p.lat - padding)
bbox[1] = min(bbox[1], p.lon - padding)
bbox[2] = max(bbox[2], p.lat + padding)
bbox[3] = max(bbox[3], p.lon + padding)
return bbox
def split_into_bboxes(self, points):
"""
Splits the dataset into multiple bboxes to lower load on the overpass api.
Returns a list of tuples (minlat, minlon, maxlat, maxlon).
"""
max_bboxes = self.profile.get('max_request_boxes', 4)
if max_bboxes <= 1 or len(points) <= 1:
return [self.get_bbox(points)]
# coord, alt coord, total w/h to the left/bottom, total w/h to the right/top
lons = sorted([[d.lon, d.lat, 0, 0] for d in points])
lats = sorted([[d.lat, d.lon, 0, 0] for d in points])
def update_side_dimensions(ar):
"""For each point, calculates the maximum and
minimum bound for all points left and right."""
fwd_top = fwd_bottom = ar[0][1]
back_top = back_bottom = ar[-1][1]
for i in range(len(ar)):
fwd_top = max(fwd_top, ar[i][1])
fwd_bottom = min(fwd_bottom, ar[i][1])
ar[i][2] = fwd_top - fwd_bottom
back_top = max(back_top, ar[-i-1][1])
back_bottom = min(back_bottom, ar[-i-1][1])
ar[-i-1][3] = back_top - back_bottom
def find_max_gap(ar, h):
"""Select an interval between points, which would give
the maximum area if split there."""
max_id = None
max_gap = 0
for i in range(len(ar) - 1):
# "Extra" variables are for area to the left and right
# that would be freed after splitting.
extra_left = (ar[i][0]-ar[0][0]) * (h-ar[i][2])
extra_right = (ar[-1][0]-ar[i+1][0]) * (h-ar[i+1][3])
# Gap is the area of the column between points i and i+1
# plus extra areas to the left and right.
gap = (ar[i+1][0] - ar[i][0]) * h + extra_left + extra_right
if gap > max_gap:
max_id = i
max_gap = gap
return max_id, max_gap
def get_bbox(b, pad=0):
"""Returns a list of [min_lat, min_lon, max_lat, max_lon] for a box."""
return [b[2][0][0]-pad, b[3][0][0]-pad, b[2][-1][0]+pad, b[3][-1][0]+pad]
def split(box, point_array, point_id):
"""Split the box over axis point_array at point point_id...point_id+1.
Modifies the box in-place and returns a new box."""
alt_array = 5 - point_array # 3->2, 2->3
points = box[point_array][point_id+1:]
del box[point_array][point_id+1:]
alt = {True: [], False: []} # True means point is in new box
for p in box[alt_array]:
alt[(p[1], p[0]) >= (points[0][0], points[0][1])].append(p)
new_box = [None] * 4
new_box[point_array] = points
new_box[alt_array] = alt[True]
box[alt_array] = alt[False]
for i in range(2):
box[i] = box[i+2][-1][0] - box[i+2][0][0]
new_box[i] = new_box[i+2][-1][0] - new_box[i+2][0][0]
return new_box
# height, width, lats, lons
boxes = [[lats[-1][0]-lats[0][0], lons[-1][0]-lons[0][0], lats, lons]]
initial_area = boxes[0][0] * boxes[0][1]
while len(boxes) < max_bboxes and len(boxes) <= len(points):
candidate_box = None
area = 0
point_id = None
point_array = None
for box in boxes:
for ar in (2, 3):
# Find a box and an axis for splitting that would decrease the area the most
update_side_dimensions(box[ar])
max_id, max_area = find_max_gap(box[ar], box[3-ar])
if max_area > area:
area = max_area
candidate_box = box
point_id = max_id
point_array = ar
if area * 100 < initial_area:
# Stop splitting when the area decrease is less than 1%
break
logging.debug('Splitting bbox %s at %s %s..%s; area decrease %s%%',
get_bbox(candidate_box),
'longs' if point_array == 3 else 'lats',
candidate_box[point_array][point_id][0],
candidate_box[point_array][point_id+1][0],
round(100*area/initial_area))
boxes.append(split(candidate_box, point_array, point_id))
padding = self.profile.get('bbox_padding', BBOX_PADDING)
return [get_bbox(b, padding) for b in boxes]
def get_categories(self, tags):
def match_query(tags, query):
for tag in query:
if len(tag) == 1:
return tag[0] in tags
else:
value = tags.get(tag[0], None)
if tag[1] is None or tag[1] == '':
return value is None
if value is None:
return False
found = False
for t2 in tag[1:]:
if t2[0] == '~':
if re.search(t2[1:], value):
found = True
elif t2[0] == '!':
if t2[1:].lower() in value.lower():
found = True
elif t2 == value:
found = True
if found:
break
if not found:
return False
return True
def tags_to_query(tags):
return [(k, v) for k, v in tags.items()]
result = set()
qualifies = self.profile.get('qualifies', args=tags)
if qualifies is not None:
if qualifies:
result.add(None)
return result
# First check default query
query = self.profile.get('query', None)
if query is not None:
if isinstance(query, str):
result.add(None)
else:
if isinstance(query[0][0], str):
query = [query]
for q in query:
if match_query(tags, q):
result.add(None)
break
# Then check each category if we got these
categories = self.profile.get('categories', {})
for name, params in categories.items():
if 'tags' not in params and 'query' not in params:
raise ValueError('No tags and query attributes for category "{}"'.format(name))
if match_query(tags, params.get('query', tags_to_query(params.get('tags')))):
result.add(name)
return result
def calc_boxes(self, dataset_points):
profile_bbox = self.profile.get('bbox', True)
if not profile_bbox:
bboxes = [None]
elif hasattr(profile_bbox, '__len__') and len(profile_bbox) == 4:
bboxes = [profile_bbox]
else:
bboxes = self.split_into_bboxes(dataset_points)
return bboxes
def download(self, bboxes=None):
"""Constructs an Overpass API query and requests objects
to match from a server."""
if not bboxes:
pbbox = self.profile.get('bbox', True)
if pbbox and hasattr(pbbox, '__len__') and len(pbbox) == 4:
bboxes = [pbbox]
else:
bboxes = [None]
query = self.construct_overpass_query(bboxes)
logging.debug('Overpass query: %s', query)
r = requests.get(OVERPASS_SERVER + 'interpreter', {'data': query})
if r.encoding is None:
r.encoding = 'utf-8'
if r.status_code != 200:
logging.error('Failed to download data from Overpass API: %s', r.status_code)
if 'rate_limited' in r.text:
r = requests.get(OVERPASS_SERVER + 'status')
logging.warning('Seems like you are rate limited. API status:\n%s', r.text)
else:
logging.error('Error message: %s', r.text)
raise IOError()
if 'runtime error: ' in r.text:
m = re.search(r'runtime error: ([^<]+)', r.text)
error = 'unknown' if not m else m.group(1)
if 'Query timed out' in error:
logging.error(
'Query timed out, try increasing the "overpass_timeout" profile variable')
else:
logging.error('Runtime error: %s', error)
raise IOError()
return self.parse_xml(r.content)
def parse_xml(self, fileobj):
"""Parses an OSM XML file into the "osmdata" field. For ways and relations,
finds the center. Drops objects that do not match the overpass query tags
(see "check_against_profile_tags" method)."""
if isinstance(fileobj, bytes):
xml = etree.fromstring(fileobj)
else:
xml = etree.parse(fileobj).getroot()
nodes = {}
for nd in xml.findall('node'):
nodes[nd.get('id')] = (float(nd.get('lat')), float(nd.get('lon')))
ways = {}
for way in xml.findall('way'):
center = way.find('center')
if center is not None:
ways[way.get('id')] = [float(center.get('lat')), float(center.get('lon'))]
else:
logging.debug('Way %s does not have a center', way.get('id'))
coord = [0, 0]
count = 0
for nd in way.findall('nd'):
if nd.get('ref') in nodes:
count += 1
for i in range(len(coord)):
coord[i] += nodes[nd.get('ref')][i]
ways[way.get('id')] = [coord[0] / count, coord[1] / count]
# For calculating weight of OSM objects
weight_fn = self.profile.get_raw('weight')
osmdata = {}
for el in xml:
tags = {}
for tag in el.findall('tag'):
tags[tag.get('k')] = tag.get('v')
categories = self.get_categories(tags)
if categories is False or categories is None or len(categories) == 0:
continue
if el.tag == 'node':
coord = nodes[el.get('id')]
members = None
elif el.tag == 'way':
coord = ways[el.get('id')]
members = [nd.get('ref') for nd in el.findall('nd')]
elif el.tag == 'relation':
center = el.find('center')
if center is not None:
coord = [float(center.get('lat')), float(center.get('lon'))]
else:
logging.debug('Relation %s does not have a center', el.get('id'))
coord = [0, 0]
count = 0
for m in el.findall('member'):
if m.get('type') == 'node' and m.get('ref') in nodes:
count += 1
for i in range(len(coord)):
coord[i] += nodes[m.get('ref')][i]
elif m.get('type') == 'way' and m.get('ref') in ways:
count += 1
for i in range(len(coord)):
coord[i] += ways[m.get('ref')][i]
if count > 0:
coord = [coord[0] / count, coord[1] / count]
members = [
(m.get('type'), m.get('ref'), m.get('role'))
for m in el.findall('member')
]
else:
continue
if not coord or coord == [0, 0]:
continue
pt = OSMPoint(
el.tag, int(el.get('id')), int(el.get('version')),
coord[0], coord[1], tags, categories)
pt.members = members
if pt.is_poi():
if callable(weight_fn):
weight = weight_fn(pt)
if weight:
if abs(weight) > 3:
pt.dist_offset = weight
else:
pt.dist_offset = weight * self.profile.max_distance
osmdata[pt.id] = pt
return osmdata
def check_moveability(changes):
to_check = [x for x in changes if x['properties']['osm_type'] == 'node' and
x['properties']['action'] == 'modify']
logging.info('Checking moveability of %s modified nodes', len(to_check))
for c in to_check:
p = c['properties']
p['can_move'] = False
r = requests.get('{}node/{}/ways'.format(OSM_API_SERVER, p['osm_id']))
if r.status_code == 200:
xml = etree.fromstring(r.content)
p['can_move'] = xml.find('way') is None