diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt
index 70f658e657..4bf65b9622 100644
--- a/base/CMakeLists.txt
+++ b/base/CMakeLists.txt
@@ -13,6 +13,7 @@ set(
   cache.hpp
   cancellable.hpp
   checked_cast.hpp
+  clustering_map.hpp
   collection_cast.hpp
   condition.cpp
   condition.hpp
diff --git a/base/base.pro b/base/base.pro
index a963709b98..11c8a56e7e 100644
--- a/base/base.pro
+++ b/base/base.pro
@@ -49,6 +49,7 @@ HEADERS += \
     cache.hpp \
     cancellable.hpp \
     checked_cast.hpp \
+    clustering_map.hpp \
     collection_cast.hpp \
     condition.hpp \
     deferred_task.hpp \
diff --git a/base/base_tests/CMakeLists.txt b/base/base_tests/CMakeLists.txt
index 8cc8ddf949..910aeaa0fc 100644
--- a/base/base_tests/CMakeLists.txt
+++ b/base/base_tests/CMakeLists.txt
@@ -9,6 +9,7 @@ set(
   buffer_vector_test.cpp
   bwt_tests.cpp
   cache_test.cpp
+  clustering_map_tests.cpp
   collection_cast_test.cpp
   condition_test.cpp
   containers_test.cpp
diff --git a/base/base_tests/base_tests.pro b/base/base_tests/base_tests.pro
index 32c72bb91b..44cf6b8921 100644
--- a/base/base_tests/base_tests.pro
+++ b/base/base_tests/base_tests.pro
@@ -19,6 +19,7 @@ SOURCES += \
   buffer_vector_test.cpp \
   bwt_tests.cpp \
   cache_test.cpp \
+  clustering_map_tests.cpp \
   collection_cast_test.cpp \
   condition_test.cpp \
   containers_test.cpp \
diff --git a/base/base_tests/clustering_map_tests.cpp b/base/base_tests/clustering_map_tests.cpp
new file mode 100644
index 0000000000..f1d670b0d2
--- /dev/null
+++ b/base/base_tests/clustering_map_tests.cpp
@@ -0,0 +1,84 @@
+#include "testing/testing.hpp"
+
+#include "base/clustering_map.hpp"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace base;
+using namespace std;
+
+namespace
+{
+template <typename T>
+vector<T> Sort(vector<T> vs)
+{
+  sort(vs.begin(), vs.end());
+  return vs;
+}
+
+template <typename Key, typename Value, typename Hash = std::hash<Key>>
+class ClusteringMapAdapter
+{
+public:
+  template <typename V>
+  void Append(Key const & key, V && value)
+  {
+    m_m.Append(key, std::forward<V>(value));
+  }
+
+  void Union(Key const & u, Key const & v) { m_m.Union(u, v); }
+
+  std::vector<Value> Get(Key const & key) { return Sort(m_m.Get(key)); }
+
+private:
+  ClusteringMap<Key, Value, Hash> m_m;
+};
+
+UNIT_TEST(ClusteringMap_Smoke)
+{
+  {
+    ClusteringMapAdapter<int, string> m;
+    TEST(m.Get(0).empty(), ());
+    TEST(m.Get(1).empty(), ());
+
+    m.Union(0, 1);
+    TEST(m.Get(0).empty(), ());
+    TEST(m.Get(1).empty(), ());
+  }
+
+  {
+    ClusteringMapAdapter<int, string> m;
+    m.Append(0, "Hello");
+    m.Append(1, "World!");
+
+    TEST_EQUAL(m.Get(0), vector<string>({"Hello"}), ());
+    TEST_EQUAL(m.Get(1), vector<string>({"World!"}), ());
+
+    m.Union(0, 1);
+    TEST_EQUAL(m.Get(0), vector<string>({"Hello", "World!"}), ());
+    TEST_EQUAL(m.Get(1), vector<string>({"Hello", "World!"}), ());
+
+    m.Append(2, "alpha");
+    m.Append(3, "beta");
+    m.Append(4, "gamma");
+
+    TEST_EQUAL(m.Get(2), vector<string>({"alpha"}), ());
+    TEST_EQUAL(m.Get(3), vector<string>({"beta"}), ());
+    TEST_EQUAL(m.Get(4), vector<string>({"gamma"}), ());
+
+    m.Union(2, 3);
+    m.Union(3, 4);
+
+    TEST_EQUAL(m.Get(2), vector<string>({"alpha", "beta", "gamma"}), ());
+    TEST_EQUAL(m.Get(3), vector<string>({"alpha", "beta", "gamma"}), ());
+    TEST_EQUAL(m.Get(4), vector<string>({"alpha", "beta", "gamma"}), ());
+
+    TEST_EQUAL(m.Get(5), vector<string>(), ());
+    m.Union(2, 5);
+    TEST_EQUAL(m.Get(5), vector<string>({"alpha", "beta", "gamma"}), ());
+  }
+}
+} // namespace
diff --git a/base/clustering_map.hpp b/base/clustering_map.hpp
new file mode 100644
index 0000000000..6eeed876f8
--- /dev/null
+++ b/base/clustering_map.hpp
@@ -0,0 +1,132 @@
+#pragma once
+
+#include "base/assert.hpp"
+
+#include <cstddef>
+#include <functional>
+#include <iterator>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace base
+{
+// Maps keys to lists of values, but allows to clusterize keys
+// together, and to get all values from a cluster.
+//
+// Internally this is a disjoint-set forest (union by rank with path
+// compression) where the values of a cluster are stored in the
+// entry of the cluster's root key.
+//
+// NOTE: the map is NOT thread-safe.
+template <typename Key, typename Value, typename Hash = std::hash<Key>>
+class ClusteringMap
+{
+public:
+  // Appends |value| to the list of values in the cluster
+  // corresponding to |key|. An unknown |key| starts a new cluster.
+  //
+  // Amortized complexity: O(log*(n) * F), where n is the total number
+  // of keys in the map, F is the complexity of find in unordered_map.
+  template <typename V>
+  void Append(Key const & key, V && value)
+  {
+    auto & entry = GetRoot(key);
+    entry.m_values.push_back(std::forward<V>(value));
+  }
+
+  // Unions clusters corresponding to |u| and |v|.
+  //
+  // Amortized complexity: O(log*(n) * F + log(m)), where n is the
+  // total number of keys and m is the total number of values in the
+  // map, F is the complexity of find in unordered_map.
+  void Union(Key const & u, Key const & v)
+  {
+    // References into unordered_map stay valid when the second
+    // lookup inserts a new entry, so |ru| can be used after it.
+    auto & ru = GetRoot(u);
+    auto & rv = GetRoot(v);
+    if (ru.m_root == rv.m_root)
+      return;
+
+    if (ru.m_rank < rv.m_rank)
+      Attach(rv /* root */, ru /* child */);
+    else
+      Attach(ru /* root */, rv /* child */);
+  }
+
+  // Returns all values from the cluster corresponding to |key|.
+  // The method is not const: it compresses paths and registers
+  // |key| as a new empty cluster when it is unknown.
+  //
+  // Amortized complexity: O(log*(n) * F), where n is the total number
+  // of keys in the map, F is the complexity of find in unordered_map.
+  std::vector<Value> const & Get(Key const & key)
+  {
+    auto const & entry = GetRoot(key);
+    return entry.m_values;
+  }
+
+private:
+  struct Entry
+  {
+    // Key of the parent entry; equals the entry's own key for a root.
+    Key m_root;
+    // Upper bound on the height of the tree rooted at this entry.
+    size_t m_rank = 0;
+    // Values of the whole cluster; non-empty only for root entries.
+    std::vector<Value> m_values;
+  };
+
+  // Returns the root entry of the cluster containing |key| and
+  // re-points every entry on the path directly to that root.
+  Entry & GetRoot(Key const & key)
+  {
+    auto & entry = GetEntry(key);
+    if (entry.m_root == key)
+      return entry;
+
+    auto & root = GetRoot(entry.m_root);
+    entry.m_root = root.m_root;
+    return root;
+  }
+
+  // Makes root entry |child| a child of root entry |parent| and
+  // transfers all values of |child| to |parent|.
+  void Attach(Entry & parent, Entry & child)
+  {
+    ASSERT_LESS_OR_EQUAL(child.m_rank, parent.m_rank, ());
+
+    child.m_root = parent.m_root;
+    if (child.m_rank == parent.m_rank)
+      ++parent.m_rank;
+
+    auto & pv = parent.m_values;
+    auto & cv = child.m_values;
+    // Always move the smaller list into the larger one.
+    if (pv.size() < cv.size())
+      pv.swap(cv);
+    pv.insert(pv.end(), std::make_move_iterator(cv.begin()),
+              std::make_move_iterator(cv.end()));
+    // A non-root entry is never read again, so release its storage
+    // instead of keeping a second copy of the values.
+    cv.clear();
+    cv.shrink_to_fit();
+  }
+
+  // Returns the entry for |key|, creating a single-key cluster
+  // rooted at |key| when the key is not in the table yet.
+  Entry & GetEntry(Key const & key)
+  {
+    auto it = m_table.find(key);
+    if (it != m_table.end())
+      return it->second;
+
+    auto & entry = m_table[key];
+    entry.m_root = key;
+    return entry;
+  }
+
+  std::unordered_map<Key, Entry, Hash> m_table;
+};
+} // namespace base