diff --git a/defines.hpp b/defines.hpp index ea9796a2b4..c8d87a2e8f 100644 --- a/defines.hpp +++ b/defines.hpp @@ -29,6 +29,7 @@ #define ALTITUDES_FILE_TAG "altitudes" #define RESTRICTIONS_FILE_TAG "restrictions" #define ROUTING_FILE_TAG "routing" +#define CROSS_MWM_FILE_TAG "cross_mwm" #define FEATURE_OFFSETS_FILE_TAG "offs" #define RANKS_FILE_TAG "ranks" #define REGION_INFO_FILE_TAG "rgninfo" diff --git a/generator/generator_tool/generator_tool.cpp b/generator/generator_tool/generator_tool.cpp index 151b9d5655..cd2d86f292 100644 --- a/generator/generator_tool/generator_tool.cpp +++ b/generator/generator_tool/generator_tool.cpp @@ -69,8 +69,9 @@ DEFINE_bool(split_by_polygons, false, // Routing. DEFINE_string(osrm_file_name, "", "Input osrm file to generate routing info."); DEFINE_bool(make_routing, false, "Make routing info based on osrm file."); -DEFINE_bool(make_cross_section, false, "Make cross section in routing file for cross mwm routing."); +DEFINE_bool(make_cross_section, false, "Make cross section in routing file for cross mwm routing (for old OSRM routing)."); DEFINE_bool(make_routing_index, false, "Make sections with the routing information."); +DEFINE_bool(make_cross_mwm, false, "Make section for cross mwm routing (for new AStar routing)."); DEFINE_string(srtm_path, "", "Path to srtm directory. If set, generates a section with altitude information " "about roads."); @@ -161,7 +162,7 @@ int main(int argc, char ** argv) FLAGS_generate_index || FLAGS_generate_search_index || FLAGS_calc_statistics || FLAGS_type_statistics || FLAGS_dump_types || FLAGS_dump_prefixes || FLAGS_dump_feature_names != "" || FLAGS_check_mwm || FLAGS_srtm_path != "" || - FLAGS_make_routing_index || FLAGS_generate_traffic_keys) + FLAGS_make_routing_index || FLAGS_make_cross_mwm || FLAGS_generate_traffic_keys) { classificator::Load(); classif().SortClassificator(); @@ -256,6 +257,9 @@ int main(int argc, char ** argv) routing::BuildRoutingIndex(datFile, country); } + if (FLAGS_make_cross_mwm) + routing::BuildCrossMwmSection(path, datFile, country); + if (FLAGS_generate_traffic_keys) { if (!traffic::GenerateTrafficKeysFromDataFile(datFile)) diff --git a/generator/routing_index_generator.cpp b/generator/routing_index_generator.cpp index 6db5e03755..19c2b33c24 100644 --- a/generator/routing_index_generator.cpp +++ b/generator/routing_index_generator.cpp @@ -1,9 +1,13 @@ #include "generator/routing_index_generator.hpp" +#include "generator/borders_generator.hpp" +#include "generator/borders_loader.hpp" #include "routing/index_graph.hpp" #include "routing/index_graph_serialization.hpp" #include "routing/vehicle_mask.hpp" +#include "routing/cross_mwm_ramp.hpp" +#include "routing/cross_mwm_ramp_serialization.hpp" #include "routing_common/bicycle_model.hpp" #include "routing_common/car_model.hpp" #include "routing_common/pedestrian_model.hpp" @@ -13,6 +17,7 @@ #include "indexer/point_to_int64.hpp" #include "coding/file_container.hpp" +#include "coding/file_name_utils.hpp" #include "base/checked_cast.hpp" #include "base/logging.hpp" @@ -28,10 +33,10 @@ using namespace routing; namespace { -class Processor final +class VehicleMaskMaker final { public: - explicit Processor(string const & country) + explicit VehicleMaskMaker(string const & country) : m_pedestrianModel(PedestrianModelFactory().GetVehicleModelForCountry(country)) , m_bicycleModel(BicycleModelFactory().GetVehicleModelForCountry(country)) , m_carModel(CarModelFactory().GetVehicleModelForCountry(country)) @@ -41,6 +46,43 @@ public: CHECK(m_carModel, ()); } + VehicleMask CalcRoadMask(FeatureType const & f) const + { + VehicleMask mask = 0; + if (m_pedestrianModel->IsRoad(f)) + mask |= kPedestrianMask; + if (m_bicycleModel->IsRoad(f)) + mask |= kBicycleMask; + if (m_carModel->IsRoad(f)) + mask |= kCarMask; + + return mask; + } + + VehicleMask CalcOneWayMask(FeatureType const & f) const + { + VehicleMask mask = 0; + if (m_pedestrianModel->IsOneWay(f)) + mask |= kPedestrianMask; + if (m_bicycleModel->IsOneWay(f)) + mask |= kBicycleMask; + if (m_carModel->IsOneWay(f)) + mask |= kCarMask; + + return mask; + } + +private: + shared_ptr const m_pedestrianModel; + shared_ptr const m_bicycleModel; + shared_ptr const m_carModel; +}; + +class Processor final +{ +public: + explicit Processor(string const & country) : m_maskMaker(country) {} + void ProcessAllFeatures(string const & filename) { feature::ForEachFromDat(filename, bind(&Processor::ProcessFeature, this, _1, _2)); @@ -64,7 +106,7 @@ public: private: void ProcessFeature(FeatureType const & f, uint32_t id) { - VehicleMask const mask = CalcVehicleMask(f); + VehicleMask const mask = m_maskMaker.CalcRoadMask(f); if (mask == 0) return; @@ -78,25 +120,86 @@ private: } } - VehicleMask CalcVehicleMask(FeatureType const & f) const - { - VehicleMask mask = 0; - if (m_pedestrianModel->IsRoad(f)) - mask |= kPedestrianMask; - if (m_bicycleModel->IsRoad(f)) - mask |= kBicycleMask; - if (m_carModel->IsRoad(f)) - mask |= kCarMask; - - return mask; - } - - shared_ptr const m_pedestrianModel; - shared_ptr const m_bicycleModel; - shared_ptr const m_carModel; + VehicleMaskMaker const m_maskMaker; unordered_map m_posToJoint; unordered_map m_masks; }; + +bool BordersContains(vector const & borders, m2::PointD const & point) +{ + for (m2::RegionD const & region : borders) + { + if (region.Contains(point)) + return true; + } + + return false; +} + +void CalcCrossMwmTransitions(string const & path, string const & mwmFile, string const & country, + vector & transitions, + vector & ramps) +{ + string const polyFile = my::JoinFoldersToPath({path, BORDERS_DIR}, country + BORDERS_EXTENSION); + vector borders; + osm::LoadBorders(polyFile, borders); + + VehicleMaskMaker const maskMaker(country); + + feature::ForEachFromDat(mwmFile, [&](FeatureType const & f, uint32_t featureId) { + VehicleMask const roadMask = maskMaker.CalcRoadMask(f); + if (roadMask == 0) + return; + + f.ParseGeometry(FeatureType::BEST_GEOMETRY); + size_t const pointsCount = f.GetPointsCount(); + if (pointsCount <= 0) + return; + + bool prevPointIn = BordersContains(borders, f.GetPoint(0)); + + for (size_t i = 1; i < pointsCount; ++i) + { + bool const pointIn = BordersContains(borders, f.GetPoint(i)); + if (pointIn != prevPointIn) + { + uint32_t const segmentIdx = base::asserted_cast(i - 1); + VehicleMask const oneWayMask = maskMaker.CalcOneWayMask(f); + + transitions.emplace_back(featureId, segmentIdx, roadMask, oneWayMask, pointIn, + f.GetPoint(i - 1), f.GetPoint(i)); + + for (size_t j = 0; j < ramps.size(); ++j) + { + VehicleMask const mask = GetVehicleMask(static_cast(j)); + CrossMwmRampSerializer::AddTransition(transitions.back(), mask, ramps[j]); + } + } + + prevPointIn = pointIn; + } + }); +} + +void FillWeights(string const & path, string const & country, CrossMwmRamp & ramp) +{ + shared_ptr vehicleModel = CarModelFactory().GetVehicleModelForCountry(country); + shared_ptr estimator = + EdgeEstimator::CreateForCar(nullptr /* trafficStash */, vehicleModel->GetMaxSpeed()); + + Index index; + platform::CountryFile countryFile(country); + index.RegisterMap(LocalCountryFile(path, countryFile, 0)); + MwmSet::MwmHandle handle = index.GetMwmHandleByCountryFile(countryFile); + CHECK(handle.IsAlive(), ()); + + Geometry geometry(GeometryLoader::Create(index, handle.GetId(), vehicleModel)); + + ramp.FillWeights([&](Segment const & enter, Segment const & exit) { + return estimator->CalcHeuristic(geometry.GetPoint(enter.GetRoadPoint(true)), + geometry.GetPoint(exit.GetRoadPoint(true))); + }); +} } // namespace namespace routing @@ -129,4 +232,30 @@ bool BuildRoutingIndex(string const & filename, string const & country) return false; } } + +void BuildCrossMwmSection(string const & path, string const & mwmFile, string const & country) +{ + LOG(LINFO, ("Building cross mwm section for", country)); + my::Timer timer; + + vector ramps(static_cast(VehicleType::Count), kFakeNumMwmId); + + vector transitions; + CalcCrossMwmTransitions(path, mwmFile, country, transitions, ramps); + + FillWeights(path, country, ramps[static_cast(VehicleType::Car)]); + + FilesContainerW cont(mwmFile, FileWriter::OP_WRITE_EXISTING); + FileWriter writer = cont.GetWriter(CROSS_MWM_FILE_TAG); + + DataHeader const dataHeader(mwmFile); + serial::CodingParams const & codingParams = dataHeader.GetDefCodingParams(); + + auto const startPos = writer.Pos(); + CrossMwmRampSerializer::Serialize(transitions, ramps, codingParams, writer); + auto const sectionSize = writer.Pos() - startPos; + + LOG(LINFO, ("Cross mwm section for", country, "generated in", timer.ElapsedSeconds(), + "seconds, section size:", sectionSize, "bytes, transitions:", transitions.size())); +} } // namespace routing diff --git a/generator/routing_index_generator.hpp b/generator/routing_index_generator.hpp index f5158c904d..6d70078618 100644 --- a/generator/routing_index_generator.hpp +++ b/generator/routing_index_generator.hpp @@ -5,4 +5,5 @@ namespace routing { bool BuildRoutingIndex(string const & filename, string const & country); +void BuildCrossMwmSection(string const & path, string const & mwmFile, string const & country); } // namespace routing diff --git a/routing/cross_mwm_ramp.cpp b/routing/cross_mwm_ramp.cpp new file mode 100644 index 0000000000..8a5e82b301 --- /dev/null +++ b/routing/cross_mwm_ramp.cpp @@ -0,0 +1,110 @@ +#include "routing/cross_mwm_ramp.hpp" + +namespace +{ +uint32_t constexpr kFakeId = std::numeric_limits::max(); +} // namespace + +namespace routing +{ +// static +CrossMwmRamp::Weight constexpr CrossMwmRamp::kNoRoute; + +void CrossMwmRamp::AddTransition(uint32_t featureId, uint32_t segmentIdx, bool oneWay, + bool forwardIsEnter, m2::PointD const & backPoint, + m2::PointD const & frontPoint) +{ + Transition transition(kFakeId, kFakeId, oneWay, forwardIsEnter, backPoint, frontPoint); + + if (forwardIsEnter) + { + transition.m_enterIdx = base::asserted_cast(m_enters.size()); + m_enters.emplace_back(m_mwmId, featureId, segmentIdx, true); + } + else + { + transition.m_exitIdx = base::asserted_cast(m_exits.size()); + m_exits.emplace_back(m_mwmId, featureId, segmentIdx, true); + } + + if (!oneWay) + { + if (forwardIsEnter) + { + transition.m_exitIdx = base::asserted_cast(m_exits.size()); + m_exits.emplace_back(m_mwmId, featureId, segmentIdx, false); + } + else + { + transition.m_enterIdx = base::asserted_cast(m_enters.size()); + m_enters.emplace_back(m_mwmId, featureId, segmentIdx, false); + } + } + + m_transitions[Key(featureId, segmentIdx)] = transition; +} + +bool CrossMwmRamp::IsTransition(Segment const & segment, bool isOutgoing) const +{ + auto it = m_transitions.find(Key(segment.GetFeatureId(), segment.GetSegmentIdx())); + if (it == m_transitions.cend()) + return false; + + Transition const & transition = it->second; + if (transition.m_oneWay && !segment.IsForward()) + return false; + + return (segment.IsForward() == transition.m_forwardIsEnter) == isOutgoing; +} + +m2::PointD const & CrossMwmRamp::GetPoint(Segment const & segment, bool front) const +{ + Transition const & transition = GetTransition(segment); + return segment.IsForward() == front ? transition.m_frontPoint : transition.m_backPoint; +} + +void CrossMwmRamp::GetEdgeList(Segment const & segment, bool isOutgoing, + std::vector & edges) const +{ + Transition const & transition = GetTransition(segment); + if (isOutgoing) + { + ASSERT_NOT_EQUAL(transition.m_enterIdx, kFakeId, ()); + for (size_t exitIdx = 0; exitIdx < m_exits.size(); ++exitIdx) + { + Weight const weight = GetWeight(base::asserted_cast(transition.m_enterIdx), exitIdx); + AddEdge(m_exits[exitIdx], weight, edges); + } + } + else + { + ASSERT_NOT_EQUAL(transition.m_exitIdx, kFakeId, ()); + for (size_t enterIdx = 0; enterIdx < m_enters.size(); ++enterIdx) + { + Weight const weight = GetWeight(enterIdx, base::asserted_cast(transition.m_exitIdx)); + AddEdge(m_enters[enterIdx], weight, edges); + } + } +} + +void CrossMwmRamp::AddEdge(Segment const & segment, Weight weight, + std::vector & edges) const +{ + if (weight != kNoRoute) + edges.emplace_back(segment, static_cast(weight)); +} + +CrossMwmRamp::Transition const & CrossMwmRamp::GetTransition(Segment const & segment) const +{ + auto it = m_transitions.find(Key(segment.GetFeatureId(), segment.GetSegmentIdx())); + CHECK(it != m_transitions.cend(), ("Not transition segment:", segment)); + return it->second; +} + +CrossMwmRamp::Weight CrossMwmRamp::GetWeight(size_t enterIdx, size_t exitIdx) const +{ + size_t const i = enterIdx * m_exits.size() + exitIdx; + ASSERT_LESS(i, m_weights.size(), ()); + return m_weights[i]; +} +} // namespace routing diff --git a/routing/cross_mwm_ramp.hpp b/routing/cross_mwm_ramp.hpp new file mode 100644 index 0000000000..e8ac49101b --- /dev/null +++ b/routing/cross_mwm_ramp.hpp @@ -0,0 +1,117 @@ +#pragma once + +#include "routing/segment.hpp" + +#include "base/assert.hpp" + +#include "geometry/point2d.hpp" + +#include +#include +#include +#include + +namespace routing +{ +class CrossMwmRamp final +{ +public: + CrossMwmRamp(NumMwmId mwmId) : m_mwmId(mwmId) {} + void AddTransition(uint32_t featureId, uint32_t segmentIdx, bool oneWay, bool forwardIsEnter, + m2::PointD const & backPoint, m2::PointD const & frontPoint); + + bool IsTransition(Segment const & segment, bool isOutgoing) const; + m2::PointD const & GetPoint(Segment const & segment, bool front) const; + void GetEdgeList(Segment const & segment, bool isOutgoing, + std::vector & edges) const; + + std::vector const & GetEnters() const { return m_enters; } + std::vector const & GetExits() const { return m_exits; } + bool HasWeights() const { return !m_weights.empty(); } + bool WeightsWereLoaded() const { return m_weightsWereLoaded; } + template + void FillWeights(CalcWeight && calcWeight) + { + CHECK(m_weights.empty(), ()); + m_weights.reserve(m_enters.size() * m_exits.size()); + for (size_t i = 0; i < m_enters.size(); ++i) + { + for (size_t j = 0; j < m_exits.size(); ++j) + { + double const weight = calcWeight(m_enters[i], m_exits[j]); + m_weights.push_back(static_cast(std::ceil(weight))); + } + } + } + +private: + // This is internal type used for storing edges weights. + // Weight is the time requred for the route to pass. + // Weight is measured in seconds rounded upwards. + using Weight = uint32_t; + + static Weight constexpr kNoRoute = 0; + + struct Key + { + Key() = default; + + Key(uint32_t featureId, uint32_t segmentIdx) : m_featureId(featureId), m_segmentIdx(segmentIdx) + { + } + + bool operator==(const Key & key) const + { + return m_featureId == key.m_featureId && m_segmentIdx == key.m_segmentIdx; + } + + uint32_t m_featureId = 0; + uint32_t m_segmentIdx = 0; + }; + + struct HashKey + { + size_t operator()(const Key & key) const + { + return std::hash()((static_cast(key.m_featureId) << 32) + + static_cast(key.m_segmentIdx)); + } + }; + + struct Transition + { + Transition() = default; + + Transition(uint32_t enterIdx, uint32_t exitIdx, bool oneWay, bool forwardIsEnter, + m2::PointD const & backPoint, m2::PointD const & frontPoint) + : m_enterIdx(enterIdx) + , m_exitIdx(exitIdx) + , m_backPoint(backPoint) + , m_frontPoint(frontPoint) + , m_oneWay(oneWay) + , m_forwardIsEnter(forwardIsEnter) + { + } + + uint32_t m_enterIdx = 0; + uint32_t m_exitIdx = 0; + m2::PointD m_backPoint = {0.0, 0.0}; + m2::PointD m_frontPoint = {0.0, 0.0}; + bool m_oneWay = false; + bool m_forwardIsEnter = false; + }; + + friend class CrossMwmRampSerializer; + + void AddEdge(Segment const & segment, Weight weight, std::vector & edges) const; + Transition const & GetTransition(Segment const & segment) const; + Weight GetWeight(size_t enterIdx, size_t exitIdx) const; + + NumMwmId const m_mwmId; + std::vector m_enters; + std::vector m_exits; + std::unordered_map m_transitions; + std::vector m_weights; + bool m_weightsWereLoaded = false; +}; +} // namespace routing diff --git a/routing/cross_mwm_ramp_serialization.cpp b/routing/cross_mwm_ramp_serialization.cpp new file mode 100644 index 0000000000..46d1dda389 --- /dev/null +++ b/routing/cross_mwm_ramp_serialization.cpp @@ -0,0 +1,30 @@ +#include "routing/cross_mwm_ramp_serialization.hpp" + +using namespace std; + +namespace routing +{ +// static +uint32_t constexpr CrossMwmRampSerializer::kLastVersion; + +// static +void CrossMwmRampSerializer::WriteTransitions(vector const & transitions, + serial::CodingParams const & codingParams, + uint8_t bitsPerMask, vector & buffer) +{ + MemWriter> memWriter(buffer); + + for (Transition const & transition : transitions) + transition.Serialize(codingParams, bitsPerMask, memWriter); +} + +// static +void CrossMwmRampSerializer::WriteWeights(vector const & weights, + vector & buffer) +{ + MemWriter> memWriter(buffer); + + for (auto weight : weights) + WriteToSink(memWriter, weight); +} +} // namespace routing diff --git a/routing/cross_mwm_ramp_serialization.hpp b/routing/cross_mwm_ramp_serialization.hpp new file mode 100644 index 0000000000..dbc0d7f42b --- /dev/null +++ b/routing/cross_mwm_ramp_serialization.hpp @@ -0,0 +1,332 @@ +#pragma once + +#include "routing/cross_mwm_ramp.hpp" +#include "routing/routing_exceptions.hpp" +#include "routing/vehicle_mask.hpp" + +#include "indexer/coding_params.hpp" +#include "indexer/geometry_serialization.hpp" + +#include "coding/bit_streams.hpp" +#include "coding/reader.hpp" +#include "coding/write_to_sink.hpp" +#include "coding/writer.hpp" + +#include "base/checked_cast.hpp" + +#include +#include + +namespace routing +{ +class CrossMwmRampSerializer final +{ +public: + class Transition final + { + public: + Transition() = default; + + Transition(uint32_t featureId, uint32_t segmentIdx, VehicleMask roadMask, + VehicleMask oneWayMask, bool forwardIsEnter, m2::PointD const & backPoint, + m2::PointD const & frontPoint) + : m_featureId(featureId) + , m_segmentIdx(segmentIdx) + , m_backPoint(backPoint) + , m_frontPoint(frontPoint) + , m_roadMask(roadMask) + , m_oneWayMask(oneWayMask) + , m_forwardIsEnter(forwardIsEnter) + { + } + + template + void Serialize(serial::CodingParams const & codingParams, uint8_t bitsPerMask, + Sink & sink) const + { + WriteToSink(sink, m_featureId); + WriteToSink(sink, m_segmentIdx); + serial::SavePoint(sink, m_backPoint, codingParams); + serial::SavePoint(sink, m_frontPoint, codingParams); + + BitWriter writer(sink); + writer.WriteAtMost32Bits(static_cast(m_roadMask), bitsPerMask); + writer.WriteAtMost32Bits(static_cast(m_oneWayMask), bitsPerMask); + writer.Write(m_forwardIsEnter ? 0 : 1, 1); + } + + template + void Deserialize(serial::CodingParams const & codingParams, uint8_t bitsPerMask, Source & src) + { + m_featureId = ReadPrimitiveFromSource(src); + m_segmentIdx = ReadPrimitiveFromSource(src); + m_backPoint = serial::LoadPoint(src, codingParams); + m_frontPoint = serial::LoadPoint(src, codingParams); + + BitReader reader(src); + m_roadMask = reader.ReadAtMost32Bits(bitsPerMask); + m_oneWayMask = reader.ReadAtMost32Bits(bitsPerMask); + m_forwardIsEnter = reader.Read(1) == 0; + } + + uint32_t GetFeatureId() const { return m_featureId; } + uint32_t GetSegmentIdx() const { return m_segmentIdx; } + m2::PointD const & GetBackPoint() const { return m_backPoint; } + m2::PointD const & GetFrontPoint() const { return m_frontPoint; } + bool ForwardIsEnter() const { return m_forwardIsEnter; } + VehicleMask GetRoadMask() const { return m_roadMask; } + VehicleMask GetOneWayMask() const { return m_oneWayMask; } + + private: + uint32_t m_featureId = 0; + uint32_t m_segmentIdx = 0; + m2::PointD m_backPoint = {0.0, 0.0}; + m2::PointD m_frontPoint = {0.0, 0.0}; + VehicleMask m_roadMask = 0; + VehicleMask m_oneWayMask = 0; + bool m_forwardIsEnter = false; + }; + + CrossMwmRampSerializer() = delete; + + template + static void Serialize(std::vector const & transitions, + vector const & ramps, + serial::CodingParams const & codingParams, Sink & sink) + { + auto const bitsPerMask = static_cast(VehicleType::Count); + vector transitionsBuf; + WriteTransitions(transitions, codingParams, bitsPerMask, transitionsBuf); + + Header header(base::checked_cast(transitions.size()), + base::checked_cast(transitionsBuf.size()), codingParams, bitsPerMask); + vector> weightBuffers(ramps.size()); + + for (size_t i = 0; i < ramps.size(); ++i) + { + CrossMwmRamp const & ramp = ramps[i]; + if (!ramp.HasWeights()) + continue; + + vector & buffer = weightBuffers[i]; + WriteWeights(ramp.m_weights, buffer); + + auto numEnters = base::checked_cast(ramp.GetEnters().size()); + auto numExits = base::checked_cast(ramp.GetExits().size()); + auto const vehicleType = static_cast(i); + header.AddSection(Section(buffer.size(), numEnters, numExits, vehicleType)); + } + + header.Serialize(sink); + FlushBuffer(transitionsBuf, sink); + + for (auto & buffer : weightBuffers) + FlushBuffer(buffer, sink); + } + + template + static void DeserializeTransitions(VehicleType requiredVehicle, CrossMwmRamp & ramp, Source & src) + { + Header header; + header.Deserialize(src); + + uint64_t const transitionsEnd = src.Pos() + header.GetSizeTransitions(); + VehicleMask const requiredMask = GetVehicleMask(requiredVehicle); + auto const numTransitions = base::checked_cast(header.GetNumTransitions()); + + for (size_t i = 0; i < numTransitions; ++i) + { + Transition transition; + transition.Deserialize(header.GetCodingParams(), header.GetBitsPerMask(), src); + AddTransition(transition, requiredMask, ramp); + } + + if (src.Pos() != transitionsEnd) + { + MYTHROW(CorruptedDataException, + ("Wrong position", src.Pos(), "after decoding transitions, expected:", transitionsEnd, + "size:", header.GetSizeTransitions())); + } + } + + template + static void DeserializeWeights(VehicleType requiredVehicle, CrossMwmRamp & ramp, Source & src) + { + CHECK(!ramp.WeightsWereLoaded(), ()); + + Header header; + header.Deserialize(src); + src.Skip(header.GetSizeTransitions()); + + for (Section const & section : header.GetSections()) + { + if (section.GetVehicleType() != requiredVehicle) + { + src.Skip(section.GetSize()); + continue; + } + + size_t const numEnters = ramp.GetEnters().size(); + size_t const numExits = ramp.GetExits().size(); + + if (base::checked_cast(section.GetNumEnters()) != numEnters) + { + MYTHROW(CorruptedDataException, + ("Mismatch enters number, section:", section.GetNumEnters(), ", ramp:", numEnters)); + } + + if (base::checked_cast(section.GetNumExits()) != numExits) + { + MYTHROW(CorruptedDataException, + ("Mismatch exits number, section:", section.GetNumExits(), ", ramp:", numExits)); + } + + size_t const size = numEnters * numExits; + ramp.m_weights.reserve(size); + for (size_t i = 0; i < size; ++i) + { + auto const weight = ReadPrimitiveFromSource(src); + ramp.m_weights.push_back(static_cast(weight)); + } + break; + } + + ramp.m_weightsWereLoaded = true; + } + + static void AddTransition(Transition const & transition, VehicleMask requiredMask, + CrossMwmRamp & ramp) + { + if ((transition.GetRoadMask() & requiredMask) == 0) + return; + + bool const isOneWay = (transition.GetOneWayMask() & requiredMask) != 0; + ramp.AddTransition(transition.GetFeatureId(), transition.GetSegmentIdx(), isOneWay, + transition.ForwardIsEnter(), transition.GetBackPoint(), + transition.GetFrontPoint()); + } + +private: + static uint32_t constexpr kLastVersion = 0; + + class Section final + { + public: + Section() = default; + + Section(uint64_t size, uint32_t numEnters, uint32_t numExits, VehicleType vehicleType) + : m_size(size), m_numEnters(numEnters), m_numExits(numExits), m_vehicleType(vehicleType) + { + } + + template + void Serialize(Sink & sink) const + { + WriteToSink(sink, m_size); + WriteToSink(sink, m_numEnters); + WriteToSink(sink, m_numExits); + WriteToSink(sink, static_cast(m_vehicleType)); + } + + template + void Deserialize(Source & src) + { + m_size = ReadPrimitiveFromSource(src); + m_numEnters = ReadPrimitiveFromSource(src); + m_numExits = ReadPrimitiveFromSource(src); + m_vehicleType = static_cast(ReadPrimitiveFromSource(src)); + } + + uint64_t GetSize() const { return m_size; } + uint32_t GetNumEnters() const { return m_numEnters; } + uint32_t GetNumExits() const { return m_numExits; } + VehicleType GetVehicleType() const { return m_vehicleType; } + + private: + uint64_t m_size = 0; + uint32_t m_numEnters = 0; + uint32_t m_numExits = 0; + VehicleType m_vehicleType = VehicleType::Pedestrian; + }; + + class Header final + { + public: + Header() = default; + + Header(uint32_t numTransitions, uint64_t sizeTransitions, + serial::CodingParams const & codingParams, uint8_t bitsPerMask) + : m_numTransitions(numTransitions) + , m_sizeTransitions(sizeTransitions) + , m_codingParams(codingParams) + , m_bitsPerMask(bitsPerMask) + { + } + + template + void Serialize(Sink & sink) const + { + WriteToSink(sink, m_version); + WriteToSink(sink, m_numTransitions); + WriteToSink(sink, m_sizeTransitions); + m_codingParams.Save(sink); + WriteToSink(sink, m_bitsPerMask); + + WriteToSink(sink, base::checked_cast(m_sections.size())); + for (Section const & section : m_sections) + section.Serialize(sink); + } + + template + void Deserialize(Source & src) + { + m_version = ReadPrimitiveFromSource(src); + if (m_version != kLastVersion) + { + MYTHROW(CorruptedDataException, ("Unknown cross mwm section version ", m_version, + ", current version ", kLastVersion)); + } + + m_numTransitions = ReadPrimitiveFromSource(src); + m_sizeTransitions = ReadPrimitiveFromSource(src); + m_codingParams.Load(src); + m_bitsPerMask = ReadPrimitiveFromSource(src); + + auto const sectionsSize = ReadPrimitiveFromSource(src); + m_sections.resize(base::checked_cast(sectionsSize)); + for (Section & section : m_sections) + section.Deserialize(src); + } + + void AddSection(Section const & section) { m_sections.push_back(section); } + + uint32_t GetNumTransitions() const { return m_numTransitions; } + uint64_t GetSizeTransitions() const { return m_sizeTransitions; } + serial::CodingParams const & GetCodingParams() const { return m_codingParams; } + uint8_t GetBitsPerMask() const { return m_bitsPerMask; } + vector
const & GetSections() const { return m_sections; } + + private: + uint32_t m_version = kLastVersion; + uint32_t m_numTransitions = 0; + uint64_t m_sizeTransitions = 0; + serial::CodingParams m_codingParams; + uint8_t m_bitsPerMask = 0; + vector
m_sections; + }; + + template + static void FlushBuffer(vector & buffer, Sink & sink) + { + sink.Write(buffer.data(), buffer.size()); + buffer.clear(); + } + + static void WriteTransitions(std::vector const & transitions, + serial::CodingParams const & codingParams, uint8_t bitsPerMask, + std::vector & buffer); + + static void WriteWeights(std::vector const & weights, + std::vector & buffer); +}; +} // namespace routing diff --git a/routing/edge_estimator.cpp b/routing/edge_estimator.cpp index e1248ddb28..3a3f8dd01b 100644 --- a/routing/edge_estimator.cpp +++ b/routing/edge_estimator.cpp @@ -73,17 +73,20 @@ double CarEdgeEstimator::CalcSegmentWeight(Segment const & segment, RoadGeometry road.GetPoint(segment.GetPointId(true /* front */)), speedMPS) * kTimePenalty; - auto const * trafficColoring = m_trafficStash->Get(segment.GetMwmId()); - if (trafficColoring) + if (m_trafficStash) { - auto const dir = segment.IsForward() ? TrafficInfo::RoadSegmentId::kForwardDirection - : TrafficInfo::RoadSegmentId::kReverseDirection; - auto const it = trafficColoring->find( - TrafficInfo::RoadSegmentId(segment.GetFeatureId(), segment.GetSegmentIdx(), dir)); - SpeedGroup const speedGroup = - (it == trafficColoring->cend()) ? SpeedGroup::Unknown : it->second; - ASSERT_LESS(speedGroup, SpeedGroup::Count, ()); - result *= CalcTrafficFactor(speedGroup); + auto const * trafficColoring = m_trafficStash->Get(segment.GetMwmId()); + if (trafficColoring) + { + auto const dir = segment.IsForward() ? TrafficInfo::RoadSegmentId::kForwardDirection + : TrafficInfo::RoadSegmentId::kReverseDirection; + auto const it = trafficColoring->find( + TrafficInfo::RoadSegmentId(segment.GetFeatureId(), segment.GetSegmentIdx(), dir)); + SpeedGroup const speedGroup = + (it == trafficColoring->cend()) ? SpeedGroup::Unknown : it->second; + ASSERT_LESS(speedGroup, SpeedGroup::Count, ()); + result *= CalcTrafficFactor(speedGroup); + } } return result; diff --git a/routing/routing.pro b/routing/routing.pro index 6441f1a865..f2aac961f0 100644 --- a/routing/routing.pro +++ b/routing/routing.pro @@ -18,6 +18,8 @@ SOURCES += \ bicycle_directions.cpp \ car_router.cpp \ cross_mwm_index_graph.cpp \ + cross_mwm_ramp.cpp \ + cross_mwm_ramp_serialization.cpp \ cross_mwm_road_graph.cpp \ cross_mwm_router.cpp \ cross_routing_context.cpp \ @@ -68,6 +70,8 @@ HEADERS += \ bicycle_directions.hpp \ car_router.hpp \ cross_mwm_index_graph.hpp \ + cross_mwm_ramp.hpp \ + cross_mwm_ramp_serialization.hpp \ cross_mwm_road_graph.hpp \ cross_mwm_router.hpp \ cross_routing_context.hpp \ diff --git a/routing/routing_tests/cross_mwm_ramp_test.cpp b/routing/routing_tests/cross_mwm_ramp_test.cpp new file mode 100644 index 0000000000..18ac375124 --- /dev/null +++ b/routing/routing_tests/cross_mwm_ramp_test.cpp @@ -0,0 +1,226 @@ +#include "testing/testing.hpp" + +#include "routing/cross_mwm_ramp_serialization.hpp" + +#include "coding/writer.hpp" + +using namespace routing; +using namespace std; + +namespace +{ +NumMwmId constexpr mwmId = 777; + +void CheckRampConsistency(CrossMwmRamp const & ramp) +{ + for (Segment const & enter : ramp.GetEnters()) + { + TEST(ramp.IsTransition(enter, true /* isOutgoing */), ()); + TEST(!ramp.IsTransition(enter, false /* isOutgoing */), ()); + } + + for (Segment const & exit : ramp.GetExits()) + { + TEST(!ramp.IsTransition(exit, true /* isOutgoing */), ()); + TEST(ramp.IsTransition(exit, false /* isOutgoing */), ()); + } +} + +void CheckEdges(CrossMwmRamp const & ramp, Segment const & from, bool isOutgoing, + vector const & expectedEdges) +{ + vector edges; + ramp.GetEdgeList(from, isOutgoing, edges); + TEST_EQUAL(edges, expectedEdges, ()); +} +} + +namespace routing_test +{ +UNIT_TEST(OneWayEnter) +{ + uint32_t constexpr featureId = 1; + uint32_t constexpr segmentIdx = 1; + CrossMwmRamp ramp(mwmId); + ramp.AddTransition(featureId, segmentIdx, true /* oneWay */, true /* forwardIsEnter */, + {} /* backPoint */, {} /* frontPoint */); + + CheckRampConsistency(ramp); + TEST_EQUAL(ramp.GetEnters().size(), 1, ()); + TEST_EQUAL(ramp.GetExits().size(), 0, ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + true /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + false /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + true /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + false /* isOutgoing */), + ()); +} + +UNIT_TEST(OneWayExit) +{ + uint32_t constexpr featureId = 1; + uint32_t constexpr segmentIdx = 1; + CrossMwmRamp ramp(mwmId); + ramp.AddTransition(featureId, segmentIdx, true /* oneWay */, false /* forwardIsEnter */, + {} /* backPoint */, {} /* frontPoint */); + + CheckRampConsistency(ramp); + TEST_EQUAL(ramp.GetEnters().size(), 0, ()); + TEST_EQUAL(ramp.GetExits().size(), 1, ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + true /* isOutgoing */), + ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + false /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + true /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + false /* isOutgoing */), + ()); +} + +UNIT_TEST(TwoWayEnter) +{ + uint32_t constexpr featureId = 1; + uint32_t constexpr segmentIdx = 1; + CrossMwmRamp ramp(mwmId); + ramp.AddTransition(featureId, segmentIdx, false /* oneWay */, true /* forwardIsEnter */, + {} /* backPoint */, {} /* frontPoint */); + + CheckRampConsistency(ramp); + TEST_EQUAL(ramp.GetEnters().size(), 1, ()); + TEST_EQUAL(ramp.GetExits().size(), 1, ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + true /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + false /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + true /* isOutgoing */), + ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + false /* isOutgoing */), + ()); +} + +UNIT_TEST(TwoWayExit) +{ + uint32_t constexpr featureId = 1; + uint32_t constexpr segmentIdx = 1; + CrossMwmRamp ramp(mwmId); + ramp.AddTransition(featureId, segmentIdx, false /* oneWay */, false /* forwardIsEnter */, + {} /* backPoint */, {} /* frontPoint */); + + CheckRampConsistency(ramp); + TEST_EQUAL(ramp.GetEnters().size(), 1, ()); + TEST_EQUAL(ramp.GetExits().size(), 1, ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + true /* isOutgoing */), + ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, true /* forward */), + false /* isOutgoing */), + ()); + TEST(ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + true /* isOutgoing */), + ()); + TEST(!ramp.IsTransition(Segment(mwmId, featureId, segmentIdx, false /* forward */), + false /* isOutgoing */), + ()); +} + +UNIT_TEST(Serialization) +{ + float constexpr kEdgesWeight = 333; + + vector buffer; + { + vector transitions = { + /* featureId, segmentIdx, roadMask, oneWayMask, forwardIsEnter, backPoint, frontPoint */ + {10, 1, kCarMask, kCarMask, true, m2::PointD(1.1, 1.2), m2::PointD(1.3, 1.4)}, + {20, 2, kCarMask, 0, true, m2::PointD(2.1, 2.2), m2::PointD(2.3, 2.4)}, + {30, 3, kPedestrianMask, kCarMask, true, m2::PointD(3.1, 3.2), m2::PointD(3.3, 3.4)}}; + + vector ramps(static_cast(VehicleType::Count), mwmId); + + CrossMwmRamp & carRamp = ramps[static_cast(VehicleType::Car)]; + for (auto const & transition : transitions) + CrossMwmRampSerializer::AddTransition(transition, kCarMask, carRamp); + + carRamp.FillWeights([](Segment const & enter, Segment const & exit) { return kEdgesWeight; }); + + serial::CodingParams const codingParams; + MemWriter> writer(buffer); + CrossMwmRampSerializer::Serialize(transitions, ramps, codingParams, writer); + } + + CrossMwmRamp ramp(mwmId); + { + MemReader reader(buffer.data(), buffer.size()); + ReaderSource source(reader); + CrossMwmRampSerializer::DeserializeTransitions(VehicleType::Car, ramp, source); + } + + CheckRampConsistency(ramp); + + TEST_EQUAL(ramp.GetEnters().size(), 2, ()); + TEST_EQUAL(ramp.GetExits().size(), 1, ()); + + TEST(!ramp.IsTransition(Segment(mwmId, 0, 0, true), true /* isOutgoing */), ()); + + TEST(ramp.IsTransition(Segment(mwmId, 10, 1, true /* forward */), true /* isOutgoing */), ()); + TEST(!ramp.IsTransition(Segment(mwmId, 10, 1, true /* forward */), false /* isOutgoing */), ()); + TEST(!ramp.IsTransition(Segment(mwmId, 10, 1, false /* forward */), true /* isOutgoing */), ()); + TEST(!ramp.IsTransition(Segment(mwmId, 10, 1, false /* forward */), false /* isOutgoing */), ()); + + TEST(ramp.IsTransition(Segment(mwmId, 20, 2, true /* forward */), true /* isOutgoing */), ()); + TEST(!ramp.IsTransition(Segment(mwmId, 20, 2, true /* forward */), false /* isOutgoing */), ()); + TEST(!ramp.IsTransition(Segment(mwmId, 20, 2, false /* forward */), true /* isOutgoing */), ()); + TEST(ramp.IsTransition(Segment(mwmId, 20, 2, false /* forward */), false /* isOutgoing */), ()); + + TEST(!ramp.IsTransition(Segment(mwmId, 30, 3, true /* forward */), true /* isOutgoing */), ()); + + TEST(!ramp.WeightsWereLoaded(), ()); + TEST(!ramp.HasWeights(), ()); + + { + MemReader reader(buffer.data(), buffer.size()); + ReaderSource source(reader); + CrossMwmRampSerializer::DeserializeWeights(VehicleType::Car, ramp, source); + } + TEST(ramp.WeightsWereLoaded(), ()); + TEST(ramp.HasWeights(), ()); + + double constexpr eps = 1e-6; + TEST(AlmostEqualAbs(ramp.GetPoint(Segment(mwmId, 20, 2, true /* forward */), true /* front */), + m2::PointD(2.3, 2.4), eps), + ()); + TEST(AlmostEqualAbs(ramp.GetPoint(Segment(mwmId, 20, 2, true /* forward */), false /* front */), + m2::PointD(2.1, 2.2), eps), + ()); + TEST(AlmostEqualAbs(ramp.GetPoint(Segment(mwmId, 20, 2, false /* forward */), true /* front */), + m2::PointD(2.1, 2.2), eps), + ()); + TEST(AlmostEqualAbs(ramp.GetPoint(Segment(mwmId, 20, 2, true /* forward */), true /* front */), + m2::PointD(2.3, 2.4), eps), + ()); + + CheckEdges(ramp, Segment(mwmId, 10, 1, true /* forward */), true /* isOutgoing */, + {{Segment(mwmId, 20, 2, false /* forward */), kEdgesWeight}}); + + CheckEdges(ramp, Segment(mwmId, 20, 2, true /* forward */), true /* isOutgoing */, + {{Segment(mwmId, 20, 2, false /* forward */), kEdgesWeight}}); + + CheckEdges(ramp, Segment(mwmId, 20, 2, false /* forward */), false /* isOutgoing */, + {{Segment(mwmId, 10, 1, true /* forward */), kEdgesWeight}, + {Segment(mwmId, 20, 2, true /* forward */), kEdgesWeight}}); +} +} // namespace routing_test diff --git a/routing/routing_tests/routing_tests.pro b/routing/routing_tests/routing_tests.pro index c144590a22..3788a26582 100644 --- a/routing/routing_tests/routing_tests.pro +++ b/routing/routing_tests/routing_tests.pro @@ -26,6 +26,7 @@ SOURCES += \ astar_progress_test.cpp \ astar_router_test.cpp \ async_router_test.cpp \ + cross_mwm_ramp_test.cpp \ cross_routing_tests.cpp \ cumulative_restriction_test.cpp \ followed_polyline_test.cpp \ diff --git a/routing/segment.hpp b/routing/segment.hpp index cbeb14d449..1729b22cb2 100644 --- a/routing/segment.hpp +++ b/routing/segment.hpp @@ -75,6 +75,11 @@ public: Segment const & GetTarget() const { return m_target; } double GetWeight() const { return m_weight; } + bool operator==(SegmentEdge const & edge) const + { + return m_target == edge.m_target && m_weight == edge.m_weight; + } + private: // Target is vertex going to for outgoing edges, vertex going from for ingoing edges. Segment m_target; @@ -88,4 +93,11 @@ inline string DebugPrint(Segment const & segment) << segment.GetSegmentIdx() << ", " << segment.IsForward() << ")"; return out.str(); } + +inline string DebugPrint(SegmentEdge const & edge) +{ + ostringstream out; + out << "Edge(" << DebugPrint(edge.GetTarget()) << ", " << edge.GetWeight() << ")"; + return out.str(); +} } // namespace routing