diff --git a/base/base_tests/mem_trie_test.cpp b/base/base_tests/mem_trie_test.cpp index de806d61e0..42e4c627f6 100644 --- a/base/base_tests/mem_trie_test.cpp +++ b/base/base_tests/mem_trie_test.cpp @@ -2,31 +2,50 @@ #include "base/mem_trie.hpp" -#include "std/algorithm.hpp" -#include "std/string.hpp" -#include "std/utility.hpp" -#include "std/vector.hpp" +#include +#include +#include +#include + +namespace +{ +using Trie = my::MemTrie; +using Data = std::vector>; + +void GetTrieContents(Trie const & trie, Data & data) +{ + data.clear(); + trie.ForEachInTrie([&data](std::string const & k, int v) { data.emplace_back(k, v); }); + std::sort(data.begin(), data.end()); +} UNIT_TEST(MemTrie_Basic) { - vector> data = {{"roger", 3}, - {"amy", 1}, - {"emma", 1}, - {"ann", 1}, - {"rob", 1}, - {"roger", 2}, - {"", 0}, - {"roger", 1}}; - my::MemTrie trie; + Data data = {{"roger", 3}, {"amy", 1}, {"emma", 1}, {"ann", 1}, + {"rob", 1}, {"roger", 2}, {"", 0}, {"roger", 1}}; + Trie trie; + TEST_EQUAL(trie.GetNumNodes(), 1, ()); + for (auto const & p : data) trie.Add(p.first, p.second); + TEST_EQUAL(trie.GetNumNodes(), 16, ()); - vector> trie_data; - trie.ForEach([&trie_data](string const & k, int v) - { - trie_data.emplace_back(k, v); - }); - sort(data.begin(), data.end()); - sort(trie_data.begin(), trie_data.end()); - TEST_EQUAL(data, trie_data, ()); + std::sort(data.begin(), data.end()); + + Data contents; + GetTrieContents(trie, contents); + TEST_EQUAL(contents, data, ()); + + TEST_EQUAL(trie.GetNumNodes(), 16, ()); + + Trie newTrie(move(trie)); + + TEST_EQUAL(trie.GetNumNodes(), 1, ()); + GetTrieContents(trie, contents); + TEST(contents.empty(), ()); + + TEST_EQUAL(newTrie.GetNumNodes(), 16, ()); + GetTrieContents(newTrie, contents); + TEST_EQUAL(contents, data, ()); } +} // namespace diff --git a/base/mem_trie.hpp b/base/mem_trie.hpp index 46affa1cdc..b460c8d46c 100644 --- a/base/mem_trie.hpp +++ b/base/mem_trie.hpp @@ -1,61 +1,66 @@ #pragma once #include "base/macros.hpp" +#include "base/stl_add.hpp" +#include #include +#include #include namespace my { // This class is a simple in-memory trie which allows to add // key-value pairs and then traverse them in a sorted order. -template +template class MemTrie { public: MemTrie() = default; + MemTrie(MemTrie && rhs) { *this = std::move(rhs); } - MemTrie(MemTrie && other) : m_root(move(other.m_root)) + MemTrie & operator=(MemTrie && rhs) { - m_numNodes = other.m_numNodes; - other.m_numNodes = 0; - } - - MemTrie & operator=(MemTrie && other) - { - m_root = move(other.m_root); - m_numNodes = other.m_numNodes; - other.m_numNodes = 0; + m_root = std::move(rhs.m_root); + m_numNodes = rhs.m_numNodes; + rhs.m_numNodes = 1; return *this; } // Adds a key-value pair to the trie. - void Add(TString const & key, TValue const & value) + void Add(String const & key, Value const & value) { - Node * cur = &m_root; + auto * cur = &m_root; for (auto const & c : key) { - size_t numNewNodes; - cur = cur->GetMove(c, numNewNodes); - m_numNodes += numNewNodes; + bool created; + cur = &cur->GetMove(c, created); + if (created) + ++m_numNodes; } cur->AddValue(value); } // Traverses all key-value pairs in the trie and calls |toDo| on each of them. template - void ForEach(ToDo && toDo) + void ForEachInTrie(ToDo && toDo) const { - TString prefix; - ForEach(&m_root, prefix, std::forward(toDo)); + String prefix; + ForEachInSubtree(m_root, prefix, std::forward(toDo)); } template - void ForEachInSubtree(TString prefix, ToDo && toDo) const + void ForEachInNode(String const & prefix, ToDo && toDo) const { - Node const * node = MoveTo(prefix); - if (node) - ForEach(node, prefix, std::forward(toDo)); + if (auto const * root = MoveTo(prefix)) + ForEachInNode(*root, prefix, std::forward(toDo)); + } + + template + void ForEachInSubtree(String prefix, ToDo && toDo) const + { + if (auto const * root = MoveTo(prefix)) + ForEachInSubtree(*root, prefix, std::forward(toDo)); } size_t GetNumNodes() const { return m_numNodes; } @@ -63,72 +68,71 @@ public: private: struct Node { - using TChar = typename TString::value_type; + using Char = typename String::value_type; Node() = default; + Node(Node && /* rhs */) = default; - Node(Node && other) = default; + Node & operator=(Node && /* rhs */) = default; - ~Node() + Node & GetMove(Char const & c, bool & created) { - for (auto const & move : m_moves) - delete move.second; - } - - Node & operator=(Node && other) = default; - - Node * GetMove(TChar const & c, size_t & numNewNodes) - { - numNewNodes = 0; - Node *& node = m_moves[c]; + auto & node = m_moves[c]; if (!node) { - node = new Node(); - ++numNewNodes; + node = my::make_unique(); + created = true; } - return node; + else + { + created = false; + } + return *node; } - void AddValue(TValue const & value) { m_values.push_back(value); } + void AddValue(Value const & value) { m_values.push_back(value); } - std::map m_moves; - std::vector m_values; + std::map> m_moves; + std::vector m_values; DISALLOW_COPY(Node); }; - Node const * MoveTo(TString const & key) const + Node const * MoveTo(String const & key) const { - Node const * cur = &m_root; + auto const * cur = &m_root; for (auto const & c : key) { auto const it = cur->m_moves.find(c); if (it == cur->m_moves.end()) return nullptr; - cur = it->second; + cur = it->second.get(); } return cur; } template - void ForEach(Node const * root, TString & prefix, ToDo && toDo) const + void ForEachInNode(Node const & root, String const & prefix, ToDo && toDo) const { - if (!root->m_values.empty()) - { - for (auto const & value : root->m_values) - toDo(prefix, value); - } + for (auto const & value : root.m_values) + toDo(prefix, value); + } - for (auto const & move : root->m_moves) + template + void ForEachInSubtree(Node const & root, String & prefix, ToDo && toDo) const + { + ForEachInNode(root, prefix, toDo); + + for (auto const & move : root.m_moves) { prefix.push_back(move.first); - ForEach(move.second, prefix, toDo); + ForEachInSubtree(*move.second, prefix, toDo); prefix.pop_back(); } } Node m_root; - size_t m_numNodes = 0; + size_t m_numNodes = 1; DISALLOW_COPY(MemTrie); }; // class MemTrie diff --git a/base/stl_add.hpp b/base/stl_add.hpp index 5df723ee69..37476d0bc1 100644 --- a/base/stl_add.hpp +++ b/base/stl_add.hpp @@ -14,8 +14,7 @@ std::unique_ptr make_unique(Args &&... args) { return std::unique_ptr(new T(std::forward(args)...)); } -} - +} // namespace my template class BackInsertFunctor {