diff --git a/generator/routing_index_generator.cpp b/generator/routing_index_generator.cpp index 505f77548c..5ab8730273 100644 --- a/generator/routing_index_generator.cpp +++ b/generator/routing_index_generator.cpp @@ -3,9 +3,11 @@ #include "generator/borders_generator.hpp" #include "generator/borders_loader.hpp" +#include "routing/base/astar_algorithm.hpp" #include "routing/cross_mwm_connector.hpp" #include "routing/cross_mwm_connector_serialization.hpp" #include "routing/index_graph.hpp" +#include "routing/index_graph_loader.hpp" #include "routing/index_graph_serialization.hpp" #include "routing/vehicle_mask.hpp" @@ -26,6 +28,7 @@ #include "base/logging.hpp" #include +#include #include #include #include @@ -33,6 +36,7 @@ using namespace feature; using namespace platform; using namespace routing; +using namespace std; namespace { @@ -85,6 +89,7 @@ class Processor final { public: explicit Processor(string const & country) : m_maskBuilder(country) {} + void ProcessAllFeatures(string const & filename) { feature::ForEachFromDat(filename, bind(&Processor::ProcessFeature, this, _1, _2)); @@ -104,6 +109,7 @@ public: } unordered_map const & GetMasks() const { return m_masks; } + private: void ProcessFeature(FeatureType const & f, uint32_t id) { @@ -126,6 +132,33 @@ private: unordered_map m_masks; }; +class DijkstraWrapper final +{ +public: + // AStarAlgorithm types aliases: + using TVertexType = Segment; + using TEdgeType = SegmentEdge; + + DijkstraWrapper(IndexGraph & graph) : m_graph(graph) {} + + void GetOutgoingEdgesList(TVertexType const & vertex, vector & edges) + { + edges.clear(); + m_graph.GetEdgeList(vertex, true /* isOutgoing */, edges); + } + + void GetIngoingEdgesList(TVertexType const & vertex, vector & edges) + { + edges.clear(); + m_graph.GetEdgeList(vertex, false /* isOutgoing */, edges); + } + + double HeuristicCostEstimate(TVertexType const & from, TVertexType const & to) { return 0.0; } + +private: + IndexGraph & m_graph; +}; + bool RegionsContain(vector const & regions, m2::PointD const & point) { for (auto const & region : regions) @@ -182,16 +215,55 @@ void CalcCrossMwmTransitions(string const & path, string const & mwmFile, string }); } -void FillWeights(string const & path, string const & country, CrossMwmConnector & connector) +void FillWeights(string const & path, string const & mwmFile, string const & country, + CrossMwmConnector & connector) { shared_ptr vehicleModel = CarModelFactory().GetVehicleModelForCountry(country); - shared_ptr estimator = - EdgeEstimator::CreateForCar(nullptr /* trafficStash */, vehicleModel->GetMaxSpeed()); + IndexGraph graph( + GeometryLoader::CreateFromFile(mwmFile, vehicleModel), + EdgeEstimator::CreateForCar(nullptr /* trafficStash */, vehicleModel->GetMaxSpeed())); + + MwmValue mwmValue(LocalCountryFile(path, platform::CountryFile(country), 0)); + DeserializeIndexGraph(mwmValue, graph); + + map> weights; + map distanceMap; + map parent; + + auto const numEnters = connector.GetEnters().size(); + cout << "Building leaps: 0/" << numEnters << " waves passed" << flush; + for (size_t i = 0; i < numEnters; ++i) + { + Segment const & enter = connector.GetEnter(i); + + AStarAlgorithm astar; + DijkstraWrapper wrapper(graph); + astar.PropagateWave( + wrapper, enter, [](Segment const & vertex) { return false; }, + [](Segment const & vertex, SegmentEdge const & edge) { return edge.GetWeight(); }, + distanceMap, parent); + + for (Segment const & exit : connector.GetExits()) + { + auto it = distanceMap.find(exit); + if (it != distanceMap.end()) + weights[enter][exit] = it->second; + } + + cout << "\rBuilding leaps: " << (i + 1) << "/" << numEnters << " waves passed" << flush; + } + cout << endl; connector.FillWeights([&](Segment const & enter, Segment const & exit) { - // TODO replace fake weights with weights calculated by routing. - return estimator->CalcHeuristic(connector.GetPoint(enter, true /* front */), - connector.GetPoint(exit, true /* front */)); + auto it0 = weights.find(enter); + if (it0 == weights.end()) + return CrossMwmConnector::kNoRoute; + + auto it1 = it0->second.find(exit); + if (it1 == it0->second.end()) + return CrossMwmConnector::kNoRoute; + + return it1->second; }); } @@ -243,7 +315,19 @@ void BuildCrossMwmSection(string const & path, string const & mwmFile, string co vector transitions; CalcCrossMwmTransitions(path, mwmFile, country, transitions, connectors); - FillWeights(path, country, connectors[static_cast(VehicleType::Car)]); + LOG(LINFO, ("Transitions finished, transitions:", transitions.size(), ", elapsed:", + timer.ElapsedSeconds(), "seconds")); + for (size_t i = 0; i < connectors.size(); ++i) + { + VehicleType const vehicleType = static_cast(i); + CrossMwmConnector const & connector = connectors[i]; + LOG(LINFO, (vehicleType, "model:", "enters:", connector.GetEnters().size(), ", exits:", + connector.GetExits().size())); + } + timer.Reset(); + + FillWeights(path, mwmFile, country, connectors[static_cast(VehicleType::Car)]); + LOG(LINFO, ("Leaps finished, elapsed:", timer.ElapsedSeconds(), "seconds")); serial::CodingParams const codingParams = LoadCodingParams(mwmFile); FilesContainerW cont(mwmFile, FileWriter::OP_WRITE_EXISTING); @@ -252,7 +336,6 @@ void BuildCrossMwmSection(string const & path, string const & mwmFile, string co CrossMwmConnectorSerializer::Serialize(transitions, connectors, codingParams, writer); auto const sectionSize = writer.Pos() - startPos; - LOG(LINFO, ("Cross mwm section for", country, "generated in", timer.ElapsedSeconds(), - "seconds, section size:", sectionSize, "bytes, transitions:", transitions.size())); + LOG(LINFO, ("Cross mwm section generated, size:", sectionSize, "bytes")); } } // namespace routing diff --git a/routing/base/astar_algorithm.hpp b/routing/base/astar_algorithm.hpp index 1c753701b3..f26d50bd9f 100644 --- a/routing/base/astar_algorithm.hpp +++ b/routing/base/astar_algorithm.hpp @@ -61,9 +61,14 @@ public: using TOnVisitedVertexCallback = std::function; - Result FindPath(TGraphType & graph, - TVertexType const & startVertex, TVertexType const & finalVertex, - RoutingResult & result, + template + void PropagateWave(TGraphType & graph, TVertexType const & startVertex, + CheckForStop && checkForStop, AdjustEdgeWeight && adjustEdgeWeight, + map & bestDistance, + map & parent) const; + + Result FindPath(TGraphType & graph, TVertexType const & startVertex, + TVertexType const & finalVertex, RoutingResult & result, my::Cancellable const & cancellable = my::Cancellable(), TOnVisitedVertexCallback onVisitedVertexCallback = nullptr) const; @@ -169,6 +174,57 @@ private: vector & path); }; +template +template +void AStarAlgorithm::PropagateWave(TGraphType & graph, TVertexType const & startVertex, + CheckForStop && checkForStop, + AdjustEdgeWeight && adjustEdgeWeight, + map & bestDistance, + map & parent) const +{ + bestDistance.clear(); + parent.clear(); + + priority_queue, greater> queue; + + bestDistance[startVertex] = 0.0; + queue.push(State(startVertex, 0.0)); + + vector adj; + + while (!queue.empty()) + { + State const stateV = queue.top(); + queue.pop(); + + if (stateV.distance > bestDistance[stateV.vertex]) + continue; + + if (checkForStop(stateV.vertex)) + return; + + graph.GetOutgoingEdgesList(stateV.vertex, adj); + for (auto const & edge : adj) + { + State stateW(edge.GetTarget(), 0.0); + if (stateV.vertex == stateW.vertex) + continue; + + double const edgeWeight = adjustEdgeWeight(stateV.vertex, edge); + double const newReducedDist = stateV.distance + edgeWeight; + + auto const t = bestDistance.find(stateW.vertex); + if (t != bestDistance.end() && newReducedDist >= t->second - kEpsilon) + continue; + + stateW.distance = newReducedDist; + bestDistance[stateW.vertex] = newReducedDist; + parent[stateW.vertex] = stateV.vertex; + queue.push(stateW); + } + } +} + // This implementation is based on the view that the A* algorithm // is equivalent to Dijkstra's algorithm that is run on a reweighted // version of the graph. If an edge (v, w) has length l(v, w), its reduced @@ -182,77 +238,56 @@ private: template typename AStarAlgorithm::Result AStarAlgorithm::FindPath( - TGraphType & graph, - TVertexType const & startVertex, TVertexType const & finalVertex, - RoutingResult & result, - my::Cancellable const & cancellable, + TGraphType & graph, TVertexType const & startVertex, TVertexType const & finalVertex, + RoutingResult & result, my::Cancellable const & cancellable, TOnVisitedVertexCallback onVisitedVertexCallback) const { result.Clear(); if (nullptr == onVisitedVertexCallback) - onVisitedVertexCallback = [](TVertexType const &, TVertexType const &){}; + onVisitedVertexCallback = [](TVertexType const &, TVertexType const &) {}; map bestDistance; - priority_queue, greater> queue; map parent; - - bestDistance[startVertex] = 0.0; - queue.push(State(startVertex, 0.0)); - - vector adj; - uint32_t steps = 0; - while (!queue.empty()) - { + Result resultCode = Result::NoPath; + + auto checkForStop = [&](TVertexType const & vertex) { ++steps; if (steps % kCancelledPollPeriod == 0 && cancellable.IsCancelled()) - return Result::Cancelled; - - State const stateV = queue.top(); - queue.pop(); - - if (stateV.distance > bestDistance[stateV.vertex]) - continue; + { + resultCode = Result::Cancelled; + return true; + } if (steps % kVisitedVerticesPeriod == 0) - onVisitedVertexCallback(stateV.vertex, finalVertex); + onVisitedVertexCallback(vertex, finalVertex); - if (stateV.vertex == finalVertex) + if (vertex == finalVertex) { - ReconstructPath(stateV.vertex, parent, result.path); - result.distance = stateV.distance - graph.HeuristicCostEstimate(stateV.vertex, finalVertex) + graph.HeuristicCostEstimate(startVertex, finalVertex); - ASSERT_EQUAL(graph.HeuristicCostEstimate(stateV.vertex, finalVertex), 0, ()); - return Result::OK; + ReconstructPath(vertex, parent, result.path); + result.distance = bestDistance[vertex] - graph.HeuristicCostEstimate(vertex, finalVertex) + + graph.HeuristicCostEstimate(startVertex, finalVertex); + ASSERT_EQUAL(graph.HeuristicCostEstimate(vertex, finalVertex), 0, ()); + resultCode = Result::OK; + return true; } - graph.GetOutgoingEdgesList(stateV.vertex, adj); - for (auto const & edge : adj) - { - State stateW(edge.GetTarget(), 0.0); - if (stateV.vertex == stateW.vertex) - continue; + return false; + }; - double const len = edge.GetWeight(); - double const piV = graph.HeuristicCostEstimate(stateV.vertex, finalVertex); - double const piW = graph.HeuristicCostEstimate(stateW.vertex, finalVertex); - double const reducedLen = len + piW - piV; + auto adjustEdgeWeight = [&](TVertexType const & vertex, TEdgeType const & edge) { + double const len = edge.GetWeight(); + double const piV = graph.HeuristicCostEstimate(vertex, finalVertex); + double const piW = graph.HeuristicCostEstimate(edge.GetTarget(), finalVertex); + double const reducedLen = len + piW - piV; - CHECK(reducedLen >= -kEpsilon, ("Invariant violated:", reducedLen, "<", -kEpsilon)); - double const newReducedDist = stateV.distance + max(reducedLen, 0.0); + CHECK(reducedLen >= -kEpsilon, ("Invariant violated:", reducedLen, "<", -kEpsilon)); + return max(reducedLen, 0.0); + }; - auto const t = bestDistance.find(stateW.vertex); - if (t != bestDistance.end() && newReducedDist >= t->second - kEpsilon) - continue; - - stateW.distance = newReducedDist; - bestDistance[stateW.vertex] = newReducedDist; - parent[stateW.vertex] = stateV.vertex; - queue.push(stateW); - } - } - - return Result::NoPath; + PropagateWave(graph, startVertex, checkForStop, adjustEdgeWeight, bestDistance, parent); + return resultCode; } template diff --git a/routing/cross_mwm_connector.cpp b/routing/cross_mwm_connector.cpp index cf6d62469f..14f28e940a 100644 --- a/routing/cross_mwm_connector.cpp +++ b/routing/cross_mwm_connector.cpp @@ -8,7 +8,7 @@ uint32_t constexpr kFakeId = std::numeric_limits::max(); namespace routing { // static -CrossMwmConnector::Weight constexpr CrossMwmConnector::kNoRoute; +double constexpr CrossMwmConnector::kNoRoute; void CrossMwmConnector::AddTransition(uint32_t featureId, uint32_t segmentIdx, bool oneWay, bool forwardIsEnter, m2::PointD const & backPoint, @@ -112,7 +112,7 @@ std::string DebugPrint(CrossMwmConnector::WeightsLoadState state) void CrossMwmConnector::AddEdge(Segment const & segment, Weight weight, std::vector & edges) const { - if (weight != kNoRoute) + if (weight != static_cast(kNoRoute)) edges.emplace_back(segment, static_cast(weight)); } diff --git a/routing/cross_mwm_connector.hpp b/routing/cross_mwm_connector.hpp index 22ef84de61..2de2a25a96 100644 --- a/routing/cross_mwm_connector.hpp +++ b/routing/cross_mwm_connector.hpp @@ -16,6 +16,8 @@ namespace routing class CrossMwmConnector final { public: + static double constexpr kNoRoute = 0.0; + CrossMwmConnector() : m_mwmId(kFakeNumMwmId) {} explicit CrossMwmConnector(NumMwmId mwmId) : m_mwmId(mwmId) {} @@ -28,7 +30,20 @@ public: std::vector & edges) const; std::vector const & GetEnters() const { return m_enters; } - std::vector const & GetExits() const { return m_exits; } + std::vector const & GetExits() const { return m_exits; } + + Segment const & GetEnter(size_t i) const + { + ASSERT_LESS(i, m_enters.size(), ()); + return m_enters[i]; + } + + Segment const & GetExit(size_t i) const + { + ASSERT_LESS(i, m_exits.size(), ()); + return m_exits[i]; + } + bool HasWeights() const { return !m_weights.empty(); } bool WeightsWereLoaded() const; @@ -56,8 +71,6 @@ private: // Weight is measured in seconds rounded upwards. using Weight = uint32_t; - static Weight constexpr kNoRoute = 0; - struct Key { Key() = default; diff --git a/routing/geometry.cpp b/routing/geometry.cpp index 0b65626c5b..9f6bba35d1 100644 --- a/routing/geometry.cpp +++ b/routing/geometry.cpp @@ -12,6 +12,7 @@ using namespace routing; namespace { +// GeometryLoaderImpl ------------------------------------------------------------------------------ class GeometryLoaderImpl final : public GeometryLoader { public: @@ -32,7 +33,7 @@ GeometryLoaderImpl::GeometryLoaderImpl(Index const & index, MwmSet::MwmId const shared_ptr vehicleModel) : m_vehicleModel(vehicleModel), m_guard(index, mwmId), m_country(country) { - ASSERT(m_vehicleModel, ()); + CHECK(m_vehicleModel, ()); } void GeometryLoaderImpl::Load(uint32_t featureId, RoadGeometry & road) const @@ -45,6 +46,36 @@ void GeometryLoaderImpl::Load(uint32_t featureId, RoadGeometry & road) const feature.ParseGeometry(FeatureType::BEST_GEOMETRY); road.Load(*m_vehicleModel, feature); } + +// FileGeometryLoader ------------------------------------------------------------------------------ +class FileGeometryLoader final : public GeometryLoader +{ +public: + FileGeometryLoader(string const & fileName, shared_ptr vehicleModel); + + // GeometryLoader overrides: + virtual void Load(uint32_t featureId, RoadGeometry & road) const override; + +private: + FeaturesVectorTest m_featuresVector; + shared_ptr m_vehicleModel; +}; + +FileGeometryLoader::FileGeometryLoader(string const & fileName, + shared_ptr vehicleModel) + : m_featuresVector(FilesContainerR(make_unique(fileName))) + , m_vehicleModel(vehicleModel) +{ + CHECK(m_vehicleModel, ()); +} + +void FileGeometryLoader::Load(uint32_t featureId, RoadGeometry & road) const +{ + FeatureType feature; + m_featuresVector.GetVector().GetByIndex(featureId, feature); + feature.ParseGeometry(FeatureType::BEST_GEOMETRY); + road.Load(*m_vehicleModel, feature); +} } // namespace namespace routing @@ -104,4 +135,11 @@ unique_ptr GeometryLoader::Create(Index const & index, MwmSet::M return make_unique(index, mwmId, mwmId.GetInfo()->GetCountryName(), vehicleModel); } + +// static +unique_ptr GeometryLoader::CreateFromFile(string const & fileName, + shared_ptr vehicleModel) +{ + return make_unique(fileName, vehicleModel); +} } // namespace routing diff --git a/routing/geometry.hpp b/routing/geometry.hpp index 909cbc52b9..d97d312072 100644 --- a/routing/geometry.hpp +++ b/routing/geometry.hpp @@ -66,6 +66,9 @@ public: // mwmId should be alive: it is caller responsibility to check it. static unique_ptr Create(Index const & index, MwmSet::MwmId const & mwmId, shared_ptr vehicleModel); + + static unique_ptr CreateFromFile(string const & fileName, + shared_ptr vehicleModel); }; class Geometry final diff --git a/routing/index_graph_loader.cpp b/routing/index_graph_loader.cpp index 0d166fdee5..380df7598b 100644 --- a/routing/index_graph_loader.cpp +++ b/routing/index_graph_loader.cpp @@ -70,18 +70,11 @@ IndexGraph & IndexGraphLoaderImpl::Load(NumMwmId numMwmId) auto const mwmId = MwmSet::MwmId(handle.GetInfo()); auto graphPtr = make_unique(GeometryLoader::Create(m_index, mwmId, vehicleModel), m_estimator); - auto & graph = *graphPtr; - - MwmValue const & mwmValue = *handle.GetValue(); + IndexGraph & graph = *graphPtr; my::Timer timer; - FilesContainerR::TReader reader(mwmValue.m_cont.GetReader(ROUTING_FILE_TAG)); - ReaderSource src(reader); - IndexGraphSerializer::Deserialize(graph, src, kCarMask); - RestrictionLoader restrictionLoader(mwmValue, graph); - if (restrictionLoader.HasRestrictions()) - graph.SetRestrictions(restrictionLoader.StealRestrictions()); - + MwmValue const & mwmValue = *handle.GetValue(); + DeserializeIndexGraph(mwmValue, graph); m_graphs[numMwmId] = move(graphPtr); LOG(LINFO, (ROUTING_FILE_TAG, "section for", file.GetName(), "loaded in", timer.ElapsedSeconds(), "seconds")); @@ -100,4 +93,14 @@ unique_ptr IndexGraphLoader::Create( { return make_unique(numMwmIds, vehicleModelFactory, estimator, index); } + +void DeserializeIndexGraph(MwmValue const & mwmValue, IndexGraph & graph) +{ + FilesContainerR::TReader reader(mwmValue.m_cont.GetReader(ROUTING_FILE_TAG)); + ReaderSource src(reader); + IndexGraphSerializer::Deserialize(graph, src, kCarMask); + RestrictionLoader restrictionLoader(mwmValue, graph); + if (restrictionLoader.HasRestrictions()) + graph.SetRestrictions(restrictionLoader.StealRestrictions()); +} } // namespace routing diff --git a/routing/index_graph_loader.hpp b/routing/index_graph_loader.hpp index eb942e201f..217cec4f6e 100644 --- a/routing/index_graph_loader.hpp +++ b/routing/index_graph_loader.hpp @@ -25,4 +25,6 @@ public: std::shared_ptr vehicleModelFactory, std::shared_ptr estimator, Index & index); }; + +void DeserializeIndexGraph(MwmValue const & mwmValue, IndexGraph & graph); } // namespace routing diff --git a/routing/vehicle_mask.hpp b/routing/vehicle_mask.hpp index 95416015ee..d201f9805a 100644 --- a/routing/vehicle_mask.hpp +++ b/routing/vehicle_mask.hpp @@ -1,6 +1,7 @@ #pragma once -#include "std/cstdint.hpp" +#include +#include namespace routing { @@ -25,4 +26,15 @@ VehicleMask constexpr kAllVehiclesMask = kNumVehicleMasks - 1; VehicleMask constexpr kPedestrianMask = GetVehicleMask(VehicleType::Pedestrian); VehicleMask constexpr kBicycleMask = GetVehicleMask(VehicleType::Bicycle); VehicleMask constexpr kCarMask = GetVehicleMask(VehicleType::Car); + +inline std::string DebugPrint(VehicleType vehicleType) +{ + switch (vehicleType) + { + case VehicleType::Pedestrian: return "Pedestrian"; + case VehicleType::Bicycle: return "Bicycle"; + case VehicleType::Car: return "Car"; + case VehicleType::Count: return "Count"; + } +} } // namespace routing