diff --git a/routing/base/bfs.hpp b/routing/base/bfs.hpp new file mode 100644 index 0000000000..33c9d59d21 --- /dev/null +++ b/routing/base/bfs.hpp @@ -0,0 +1,103 @@ +#pragma once + +#include "base/scope_guard.hpp" + +#include +#include +#include +#include +#include +#include + +namespace routing +{ +template +class BFS +{ +public: + using Vertex = typename Graph::Vertex; + using Edge = typename Graph::Edge; + using Weight = typename Graph::Weight; + + struct State + { + State(Vertex const & v, Vertex const & p) : m_vertex(v), m_parent(p) {} + + Vertex m_vertex; + Vertex m_parent; + }; + + explicit BFS(Graph & graph): m_graph(graph) {} + + void Run(Vertex const & start, bool isOutgoing, + std::function && onVisitCallback); + + std::vector ReconstructPath(Vertex from, bool reverse); + +private: + Graph & m_graph; + std::map m_parents; +}; + +template +void BFS::Run(Vertex const & start, bool isOutgoing, + std::function && onVisitCallback) +{ + m_parents.clear(); + + m_parents.emplace(start, start); + SCOPE_GUARD(removeStart, [&]() { + m_parents.erase(start); + }); + + std::queue queue; + queue.emplace(start); + + std::vector edges; + while (!queue.empty()) + { + Vertex const current = queue.front(); + queue.pop(); + + if (isOutgoing) + m_graph.GetOutgoingEdgesList(current, edges); + else + m_graph.GetIngoingEdgesList(current, edges); + + for (auto const & edge : edges) + { + Vertex const & child = edge.GetTarget(); + if (m_parents.count(child) != 0) + continue; + + State const state(child, current); + if (!onVisitCallback(state)) + continue; + + m_parents.emplace(child, current); + queue.emplace(child); + } + } +} + +template +auto BFS::ReconstructPath(Vertex from, bool reverse) -> std::vector +{ + std::vector result; + auto it = m_parents.find(from); + while (it != m_parents.end()) + { + result.emplace_back(from); + from = it->second; + it = m_parents.find(from); + } + + // Here stored path in inverse order (from end to begin). + result.emplace_back(from); + + if (!reverse) + std::reverse(result.begin(), result.end()); + + return result; +} +} // namespace routing