#!/usr/bin/env python3
import argparse
import logging
import requests
import sys
import kdtree
from io import BytesIO
import json # for profiles
import re # for profiles
try:
from lxml import etree
except ImportError:
import xml.etree.ElementTree as etree

OVERPASS_SERVER = 'http://overpass-api.de/api/'
BBOX_PADDING = 0.1  # in degrees
MAX_DISTANCE = 0.001  # how far an object can be from a dataset point to count as a match; 0.001 degrees is ~110 m


class SourcePoint:
"""A common class for points. Has an id, latitude and longitude,
and a dict of tags."""
def __init__(self, pid, lat, lon, tags=None):
self.id = str(pid)
self.lat = lat
self.lon = lon
self.tags = {} if tags is None else tags

    def __len__(self):
        return 2

    def __getitem__(self, i):
if i == 0:
return self.lat
elif i == 1:
return self.lon
        else:
            # IndexError keeps the sequence protocol intact for indexing past the end
            raise IndexError('A SourcePoint has only lat and lon')

    def __eq__(self, other):
        return self.id == other.id

    def __hash__(self):
        return hash(self.id)
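
# Because of __len__ and __getitem__ above, a SourcePoint behaves like a
# (lat, lon) pair, so the kdtree library can index points directly, e.g.:
#   tree = kdtree.create([SourcePoint(1, 55.75, 37.62)])  # illustrative values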


class OSMPoint(SourcePoint):
    """An OSM point is a SourcePoint with a few extra fields:
    version, members (for ways and relations), and an action.
    The id is compound, created from the object type and id."""
def __init__(self, ptype, pid, version, lat, lon, tags=None):
super().__init__('{}{}'.format(ptype[0], pid), lat, lon, tags)
self.osm_type = ptype
self.osm_id = pid
self.version = version
self.members = None
self.action = None

    def to_xml(self):
"""Produces an XML out of the point data. Disregards the "action" field."""
el = etree.Element(self.osm_type, id=str(self.osm_id), version=str(self.version))
for tag, value in self.tags.items():
etree.SubElement(el, 'tag', k=tag, v=value)
if self.osm_type == 'node':
el.set('lat', str(self.lat))
el.set('lon', str(self.lon))
elif self.osm_type == 'way':
for node_id in self.members:
etree.SubElement(el, 'nd', ref=str(node_id))
elif self.osm_type == 'relation':
for member in self.members:
m = etree.SubElement(el, 'member')
for i, n in enumerate(('type', 'ref', 'role')):
m.set(n, str(member[i]))
return el


class ProfileException(Exception):
"""An exception class for the Profile instance."""
def __init__(self, attr, desc):
super().__init__('Field missing in profile: {} ({})'.format(attr, desc))


class Profile:
    """A wrapper for a profile.
    A profile is a python script that sets a few local variables.
    These variables become properties of the profile, accessible with
    the "get" method. If a value is a function, it is called, possibly
    with arguments. You can compile a list of all supported variables
    by grepping through this code, or by looking at a few example
    profiles. If a variable is required, you will be notified of that.
    """
def __init__(self, fileobj):
s = fileobj.read().replace('\r', '')
self.profile = {}
exec(s, globals(), self.profile)

    def has(self, attr):
        return attr in self.profile

    def get(self, attr, default=None, required=None, args=None):
if attr in self.profile:
value = self.profile[attr]
if callable(value):
if args is None:
return value()
else:
return value(*args)
else:
return value
if required is not None:
raise ProfileException(attr, required)
return default
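

# A minimal sketch of a profile script (hypothetical values; every variable
# below is read via Profile.get somewhere in this file, and any of them may
# also be defined as a function):
#
#   source = 'Example Data Portal'
#   dataset_id = 'example_cafes'
#   query = [('amenity', 'cafe')]
#   master_tags = ['name', 'opening_hours']
#   download_url = 'http://example.com/cafes.json'
#   max_distance = 0.002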


class OsmConflator:
"""The main class for the conflator.
It receives a dataset, after which one must call either
"download_osm" or "parse_osm" methods. Then it is ready to match:
call the "match" method and get results with "to_osc".
"""
def __init__(self, profile, dataset):
self.dataset = {p.id: p for p in dataset}
self.osmdata = {}
self.matched = []
self.profile = profile
if self.profile.get('no_dataset_id', False):
self.ref = None
else:
self.ref = 'ref:' + self.profile.get('dataset_id', required='A fairly unique id of the dataset to query OSM')

    def construct_overpass_query(self, bbox=None):
"""Constructs an Overpass API query from the "query" list in the profile.
(k, v) turns into [k=v], (k,) into [k], (k, None) into [!k], (k, "~v") into [k~v]."""
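        # For example, query = [('amenity', 'cafe'), ('name', '~Mc.*')] with
        # a bbox and no ref tag yields (a sketch; real output has no newlines):
        #   [out:json][timeout:300];(
        #     node["amenity"="cafe"]["name"~"Mc.*"](55.0,37.0,56.0,38.0);
        #     way["amenity"="cafe"]["name"~"Mc.*"](55.0,37.0,56.0,38.0);
        #     relation["amenity"="cafe"]["name"~"Mc.*"](55.0,37.0,56.0,38.0);
        #   ); out meta center;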
tags = self.profile.get('query', required="a list of tuples. E.g. [('amenity', 'cafe'), ('name', '~Mc.*lds')]")
tag_str = ''
for t in tags:
if len(t) == 1:
q = '"{}"'.format(t[0])
elif t[1] is None or len(t[1]) == 0:
q = '"!{}"'.format(t[0])
elif t[1][0] == '~':
q = '"{}"~"{}"'.format(t[0], t[1][1:])
else:
q = '"{}"="{}"'.format(t[0], t[1])
tag_str += '[' + q + ']'
query = '[out:json][timeout:300];('
bbox_str = '' if bbox is None else '(' + ','.join([str(x) for x in bbox]) + ')'
for t in ('node', 'way', 'relation'):
query += t + tag_str + bbox_str + ';'
if self.ref is not None:
query += t + '["' + self.ref + '"];'
query += '); out meta center;'
return query

    def get_dataset_bbox(self):
        """Iterates over the dataset and returns the bounding box
        that encloses it, padded by BBOX_PADDING on each side."""
bbox = [90.0, 180.0, -90.0, -180.0]
for p in self.dataset.values():
bbox[0] = min(bbox[0], p.lat - BBOX_PADDING)
bbox[1] = min(bbox[1], p.lon - BBOX_PADDING)
bbox[2] = max(bbox[2], p.lat + BBOX_PADDING)
bbox[3] = max(bbox[3], p.lon + BBOX_PADDING)
return bbox

    def split_into_bboxes(self):
"""
Splits the dataset into multiple bboxes to lower load on the overpass api.
Returns a list of tuples (minlat, minlon, maxlat, maxlon).
Not implemented for now, returns the single big bbox. Not sure if needed.
"""
# TODO
return [self.get_dataset_bbox()]

    def check_against_profile_tags(self, tags):
        """Checks whether a tags dict conforms to the profile:
        either the "qualifies" function or the "query" tag list."""
        qualifies = self.profile.get('qualifies', args=(tags,))
        if qualifies is not None:
            return qualifies
        query = self.profile.get('query', None)
        if query is not None:
            for tag in query:
                if len(tag) >= 2 and (tag[1] is None or len(tag[1]) == 0):
                    # A (key, None) tuple means the tag must be absent
                    if tag[0] in tags:
                        return False
                elif tag[0] not in tags:
                    return False
                elif len(tag) >= 2 and tag[1][0] != '~' and tag[1] != tags[tag[0]]:
                    return False
        return True

    def download_osm(self):
"""Constructs an Overpass API query and requests objects
to match from a server."""
profile_bbox = self.profile.get('bbox', True)
if not profile_bbox:
bboxes = [None]
elif hasattr(profile_bbox, '__len__') and len(profile_bbox) == 4:
bboxes = [profile_bbox]
else:
bboxes = self.split_into_bboxes()
for b in bboxes:
query = self.construct_overpass_query(b)
logging.debug('Overpass query: %s', query)
r = requests.get(OVERPASS_SERVER + 'interpreter', {'data': query})
if r.status_code != 200:
raise IOError('Failed to download data from Overpass API: {} {}\nQuery: {}'.format(r.status_code, r.text, query))
for el in r.json()['elements']:
if 'tags' not in el:
continue
if 'center' in el:
for ll in ('lat', 'lon'):
el[ll] = el['center'][ll]
if self.check_against_profile_tags(el['tags']):
pt = OSMPoint(el['type'], el['id'], el['version'], el['lat'], el['lon'], el['tags'])
if 'nodes' in el:
pt.members = el['nodes']
elif 'members' in el:
pt.members = [(x['type'], x['ref'], x['role']) for x in el['members']]
self.osmdata[pt.id] = pt

    def parse_osm(self, fileobj):
"""Parses an OSM XML file into the "osmdata" field. For ways and relations,
finds the center. Drops objects that do not match the overpass query tags
(see "check_against_profile_tags" method)."""
xml = etree.parse(fileobj).getroot()
nodes = {}
for nd in xml.findall('node'):
nodes[nd.get('id')] = (float(nd.get('lat')), float(nd.get('lon')))
ways = {}
for way in xml.findall('way'):
coord = [0, 0]
count = 0
            for nd in way.findall('nd'):
                if nd.get('ref') in nodes:  # <nd> elements reference nodes by "ref", not "id"
                    count += 1
                    for i in range(len(coord)):
                        coord[i] += nodes[nd.get('ref')][i]
            if count > 0:
                ways[way.get('id')] = [coord[0] / count, coord[1] / count]
for el in xml:
tags = {}
for tag in el.findall('tag'):
tags[tag.get('k')] = tag.get('v')
if not self.check_against_profile_tags(tags):
continue
if el.tag == 'node':
coord = nodes[el.get('id')]
members = None
            elif el.tag == 'way':
                if el.get('id') not in ways:
                    continue  # the center could not be computed from this file
                coord = ways[el.get('id')]
                members = [nd.get('ref') for nd in el.findall('nd')]
elif el.tag == 'relation':
coord = [0, 0]
count = 0
for m in el.findall('member'):
if m.get('type') == 'node' and m.get('ref') in nodes:
count += 1
for i in range(len(coord)):
coord[i] += nodes[m.get('ref')][i]
elif m.get('type') == 'way' and m.get('ref') in ways:
count += 1
for i in range(len(coord)):
coord[i] += ways[m.get('ref')][i]
                if count == 0:
                    continue  # no resolvable members, cannot compute a center
                coord = [coord[0] / count, coord[1] / count]
                members = [(m.get('type'), m.get('ref'), m.get('role')) for m in el.findall('member')]
            else:
                continue  # skip elements that are not nodes, ways or relations
pt = OSMPoint(el.tag, el.get('id'), el.get('version'), coord[0], coord[1], tags)
pt.members = members
self.osmdata[pt.id] = pt

    def register_match(self, dataset_key, osmdata_key, retag=None):
        """Registers a match between a dataset point and an OSM object,
        or a lone point on either side: creates, updates, retags or
        deletes an OSMPoint and appends it to the "matched" list."""
if osmdata_key is not None:
p = self.osmdata[osmdata_key]
del self.osmdata[osmdata_key]
else:
p = None
if dataset_key is not None:
sp = self.dataset[dataset_key]
del self.dataset[dataset_key]
if p is None:
p = OSMPoint('node', -1-len(self.matched), 1, sp.lat, sp.lon, sp.tags)
p.action = 'create'
else:
master_tags = self.profile.get('master_tags', required='a set of authoritative tags that replace OSM values')
changed = False
for k, v in sp.tags.items():
if k not in p.tags or (k in master_tags and p.tags[k] != v):
p.tags[k] = v
changed = True
if changed:
p.action = 'modify'
# If not, action is None and we're not including this object into the osmChange
source = self.profile.get('source', required='value of "source" tag for uploaded OSM objects')
p.tags['source'] = source
if self.ref is not None:
p.tags[self.ref] = sp.id
elif retag:
for k, v in retag.items():
if v is not None:
p.tags[k] = v
elif k in p.tags:
del p.tags[k]
p.action = 'modify'
else:
p.action = 'delete'
if p is not None and p.action is not None:
self.matched.append(p)

    def match_dataset_points_smart(self):
        """Smart matching for dataset <-> OSM points.
        We find the shortest link between a dataset and an OSM point,
        match these two and remove both from their dicts, then find
        the next shortest link and so on, until the length of a link
        exceeds "max_distance".
        Currently the worst case complexity is around O(n^2*log^2 n),
        but given the small number of objects to match, and that
        the average case complexity is ~O(n*log^2 n), this is fine.
        """
if not self.osmdata:
return
# KDTree distance is squared, so we square the max_distance
max_distance = pow(self.profile.get('max_distance', MAX_DISTANCE), 2)
osm_kd = kdtree.create(list(self.osmdata.values()))
count_matched = 0
dist = []
for sp, v in self.dataset.items():
osm_point, distance = osm_kd.search_nn(v)
if osm_point is not None and distance <= max_distance:
dist.append((distance, sp, osm_point.data))
needs_sorting = True
while dist:
if needs_sorting:
dist.sort(key=lambda x: x[0])
needs_sorting = False
count_matched += 1
osm_point = dist[0][2]
self.register_match(dist[0][1], osm_point.id)
osm_kd = osm_kd.remove(osm_point)
del dist[0]
for i in range(len(dist)-1, -1, -1):
if dist[i][2] == osm_point:
nearest = osm_kd.search_nn(self.dataset[dist[i][1]])
if nearest and nearest[1] <= max_distance:
new_point, distance = nearest
dist[i] = (distance, dist[i][1], new_point.data)
needs_sorting = i == 0 or distance < dist[0][0]
else:
del dist[i]
needs_sorting = i == 0
logging.info('Matched %s points', count_matched)

    def match(self):
        """Matches each OSM object with a SourcePoint, or marks it as obsolete.
        The resulting list of OSMPoints is written to the "matched" field."""
if self.ref is not None:
# First match all objects with ref:whatever tag set
count_ref = 0
for k, p in list(self.osmdata.items()):
if self.ref in p.tags:
if p.tags[self.ref] in self.dataset:
count_ref += 1
self.register_match(p.tags[self.ref], k)
logging.info('Updated %s OSM objects with %s tag', count_ref, self.ref)
# Then find matches for unmatched dataset points
self.match_dataset_points_smart()
# Add unmatched dataset points
logging.info('Adding %s unmatched dataset points', len(self.dataset))
for k in list(self.dataset.keys()):
self.register_match(k, None)
# And finally delete some or all of the remaining osm objects
if len(self.osmdata) > 0:
count_deleted = 0
count_retagged = 0
delete_unmatched = self.profile.get('delete_unmatched', False)
retag = self.profile.get('tag_unmatched')
for k, p in list(self.osmdata.items()):
if self.ref is not None and self.ref in p.tags:
# When ref:whatever is present, we can delete that object safely
count_deleted += 1
self.register_match(None, k)
elif delete_unmatched or retag:
if retag:
count_retagged += 1
else:
count_deleted += 1
self.register_match(None, k, retag=retag)
logging.info('Deleted %s and retagged %s unmatched objects from OSM', count_deleted, count_retagged)

    def to_osc(self):
        """Returns a string with the osmChange XML document."""
osc = etree.Element('osmChange', version='0.6', generator='OSM Conflator')
for osmel in self.matched:
if osmel.action is not None:
el = osmel.to_xml()
etree.SubElement(osc, osmel.action).append(el)
return "<?xml version='1.0' encoding='utf-8'?>\n" + etree.tostring(osc, encoding='utf-8').decode('utf-8')


def read_dataset(profile, fileobj):
    """A helper function that calls the "dataset" function in the profile.
    If fileobj is not specified, tries to download the dataset from the URL
    in the "download_url" profile variable."""
if not fileobj:
url = profile.get('download_url')
if url is None:
logging.error('No download_url specified in the profile, please provide a dataset file with --source')
return None
r = requests.get(url)
if r.status_code != 200:
logging.error('Could not download source data: %s %s', r.status_code, r.text)
return None
if len(r.content) == 0:
logging.error('Empty response from %s', url)
return None
fileobj = BytesIO(r.content)
    if not profile.has('dataset'):
        # The default option is to parse the source as JSON
        try:
            data = []
            for item in json.load(fileobj):
                data.append(SourcePoint(item['id'], item['lat'], item['lon'], item['tags']))
            return data
        except Exception:
            logging.error('Failed to parse the source as JSON')
            return None
    return profile.get('dataset', args=(fileobj,), required='returns a list of SourcePoints with the dataset')
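

# When the profile has no "dataset" function, the source must be a JSON
# list of points, e.g. (a hypothetical sample):
#   [{"id": "1", "lat": 55.75, "lon": 37.62,
#     "tags": {"amenity": "cafe", "name": "Example"}}]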


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='''
OSM Conflator.
Reads a profile with source data and conflates it with OpenStreetMap data.
Produces an osmChange file ready to be uploaded.''')
parser.add_argument('profile', type=argparse.FileType('r'), help='Name of a profile to use')
parser.add_argument('-o', '--osc', type=argparse.FileType('w'), default=sys.stdout, help='Output osmChange file name')
parser.add_argument('-i', '--source', type=argparse.FileType('rb'), help='Source file to pass to the profile dataset() function')
parser.add_argument('--osm', type=argparse.FileType('r'), help='Instead of querying Overpass API, use this unpacked osm file')
parser.add_argument('--verbose', '-v', action='count', help='Display info messages, use -vv for debugging')
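    # Example invocation (hypothetical file names):
    #   ./conflate.py profiles/cafes.py -i cafes.json -o cafes.osc -vv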
options = parser.parse_args()
    if options.verbose and options.verbose >= 2:
log_level = logging.DEBUG
elif options.verbose == 1:
log_level = logging.INFO
else:
log_level = logging.WARNING
logging.basicConfig(level=log_level, format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
logging.getLogger("requests").setLevel(logging.WARNING)
    logging.debug('Loading profile %s', options.profile.name)
profile = Profile(options.profile)
dataset = read_dataset(profile, options.source)
if not dataset:
logging.error('Empty source dataset')
sys.exit(2)
logging.info('Read %s items from the dataset', len(dataset))
conflator = OsmConflator(profile, dataset)
if options.osm:
conflator.parse_osm(options.osm)
else:
conflator.download_osm()
    logging.info('Loaded %s objects from OSM', len(conflator.osmdata))
conflator.match()
diff = conflator.to_osc()
options.osc.write(diff)