From 8c987b3ab47bcdfc8996f6bdedac186addfc8d98 Mon Sep 17 00:00:00 2001 From: Maksim Andrianov Date: Wed, 16 Jan 2019 19:10:03 +0300 Subject: [PATCH] Added PrimitiveThreadPool class. --- base/CMakeLists.txt | 2 + base/base_tests/CMakeLists.txt | 1 + .../primitive_thread_pool_tests.cpp | 130 ++++++++++++++++++ base/primitive_thread_pool.hpp | 105 ++++++++++++++ base/thread_utils.hpp | 64 +++++++++ 5 files changed, 302 insertions(+) create mode 100644 base/base_tests/primitive_thread_pool_tests.cpp create mode 100644 base/primitive_thread_pool.hpp create mode 100644 base/thread_utils.hpp diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index e8133931d1..7744b6c475 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -51,6 +51,7 @@ set( observer_list.hpp pprof.cpp pprof.hpp + primitive_thread_pool.hpp random.cpp random.hpp range_iterator.hpp @@ -83,6 +84,7 @@ set( thread_checker.hpp thread_pool.cpp thread_pool.hpp + thread_utils.hpp threaded_container.cpp threaded_container.hpp threaded_list.hpp diff --git a/base/base_tests/CMakeLists.txt b/base/base_tests/CMakeLists.txt index 443934b739..90c92f694f 100644 --- a/base/base_tests/CMakeLists.txt +++ b/base/base_tests/CMakeLists.txt @@ -25,6 +25,7 @@ set( move_to_front_tests.cpp newtype_test.cpp observer_list_test.cpp + primitive_thread_pool_tests.cpp range_iterator_test.cpp ref_counted_tests.cpp regexp_test.cpp diff --git a/base/base_tests/primitive_thread_pool_tests.cpp b/base/base_tests/primitive_thread_pool_tests.cpp new file mode 100644 index 0000000000..6180c82c1c --- /dev/null +++ b/base/base_tests/primitive_thread_pool_tests.cpp @@ -0,0 +1,130 @@ +#include "testing/testing.hpp" + +#include +#include +#include +#include +#include + +#include "base/primitive_thread_pool.hpp" + +namespace +{ +size_t const kTimes = 500; +} // namespace + +UNIT_TEST(PrimitiveThreadPool_SomeThreads) +{ + for (size_t t = 0; t < kTimes; ++t) + { + size_t threadCount = 4; + size_t counter = 0; + { + std::mutex mutex; + threads::PrimitiveThreadPool threadPool(threadCount); + for (size_t i = 0; i < threadCount; ++i) + { + threadPool.Submit([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(mutex); + ++counter; + }); + } + } + + TEST_EQUAL(threadCount, counter, ()); + } +} + +UNIT_TEST(PrimitiveThreadPool_OneThread) +{ + for (size_t t = 0; t < kTimes; ++t) + { + size_t threadCount = 1; + size_t counter = 0; + { + std::mutex mutex; + threads::PrimitiveThreadPool threadPool(threadCount); + for (size_t i = 0; i < threadCount; ++i) + { + threadPool.Submit([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(mutex); + ++counter; + }); + } + } + + TEST_EQUAL(threadCount, counter, ()); + } +} + +UNIT_TEST(PrimitiveThreadPool_ManyThread) +{ + for (size_t t = 0; t < kTimes; ++t) + { + size_t threadCount = std::thread::hardware_concurrency(); + CHECK_NOT_EQUAL(threadCount, 0, ()); + threadCount *= 2; + size_t counter = 0; + { + std::mutex mutex; + threads::PrimitiveThreadPool threadPool(threadCount); + for (size_t i = 0; i < threadCount; ++i) + { + threadPool.Submit([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(mutex); + ++counter; + }); + } + } + + TEST_EQUAL(threadCount, counter, ()); + } +} + +UNIT_TEST(PrimitiveThreadPool_ReturnValue) +{ + for (size_t t = 0; t < kTimes; ++t) + { + size_t threadCount = 4; + threads::PrimitiveThreadPool threadPool(threadCount); + std::vector> futures; + for (size_t i = 0; i < threadCount; ++i) + { + auto f = threadPool.Submit([=]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + return i; + }); + + futures.push_back(std::move(f)); + } + + for (size_t i = 0; i < threadCount; ++i) + TEST_EQUAL(futures[i].get(), i, ()); + } +} + +UNIT_TEST(PrimitiveThreadPool_ManyTasks) +{ + for (size_t t = 0; t < kTimes; ++t) + { + size_t taskCount = 11; + size_t counter = 0; + { + std::mutex mutex; + threads::PrimitiveThreadPool threadPool(4); + for (size_t i = 0; i < taskCount; ++i) + { + threadPool.Submit([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(mutex); + ++counter; + }); + } + } + + TEST_EQUAL(taskCount, counter, ()); + } +} diff --git a/base/primitive_thread_pool.hpp b/base/primitive_thread_pool.hpp new file mode 100644 index 0000000000..e86d88e5ab --- /dev/null +++ b/base/primitive_thread_pool.hpp @@ -0,0 +1,105 @@ +// This file contains PrimitiveThreadPool class. +#pragma once + +#include "base/assert.hpp" +#include "base/thread_utils.hpp" + +#include +#include +#include +#include +#include +#include + +namespace threads +{ +// PrimitiveThreadPool is needed for easy parallelization of tasks. +// PrimitiveThreadPool can accept tasks that return result as std::future. +// When the destructor is called, all threads will join. +// +// Usage example: +// size_t threadCount = 4; +// size_t counter = 0; +// { +// std::mutex mutex; +// threads::PrimitiveThreadPool threadPool(threadCount); +// for (size_t i = 0; i < threadCount; ++i) +// { +// threadPool.Submit([&]() { +// std::this_thread::sleep_for(std::chrono::milliseconds(1)); +// std::lock_guard lock(mutex); +// ++counter; +// }); +// } +// } +// TEST_EQUAL(threadCount, counter, ()); +// +class PrimitiveThreadPool +{ +public: + using FuntionType = FunctionWrapper; + using Threads = std::vector; + + PrimitiveThreadPool(size_t threadCount) : m_done(false), m_joiner(m_threads) + { + CHECK_GREATER(threadCount, 0, ()); + + for (size_t i = 0; i < threadCount; i++) + m_threads.push_back(std::thread(&PrimitiveThreadPool::Worker, this)); + } + + ~PrimitiveThreadPool() + { + { + std::unique_lock lock(m_mutex); + m_done = true; + } + m_condition.notify_all(); + } + + template + auto Submit(F && func, Args &&... args) ->std::future + { + using ResultType = decltype(func(args...)); + std::packaged_task task(std::bind(std::forward(func), + std::forward(args)...)); + std::future result(task.get_future()); + { + std::unique_lock lock(m_mutex); + m_queue.push(std::move(task)); + } + m_condition.notify_one(); + return result; + } + +private: + void Worker() + { + while (true) + { + FuntionType task; + { + std::unique_lock lock(m_mutex); + m_condition.wait(lock, [&] { + return m_done || !m_queue.empty(); + }); + + if (m_done && m_queue.empty()) + return; + + task = std::move(m_queue.front()); + m_queue.pop(); + } + + task(); + } + } + + bool m_done; + std::mutex m_mutex; + std::condition_variable m_condition; + std::queue m_queue; + Threads m_threads; + StandartThreadsJoiner m_joiner; +}; +} // namespace threads diff --git a/base/thread_utils.hpp b/base/thread_utils.hpp new file mode 100644 index 0000000000..553103677d --- /dev/null +++ b/base/thread_utils.hpp @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include + +namespace threads +{ +template> +class ThreadsJoiner +{ +public: + explicit ThreadsJoiner(ThreadColl & threads) : m_threads(threads) {} + ~ThreadsJoiner() + { + for (auto & thread : m_threads) + { + if (thread.joinable()) + thread.join(); + } + } + +private: + ThreadColl & m_threads; +}; + +using StandartThreadsJoiner = ThreadsJoiner<>; + +class FunctionWrapper : boost::noncopyable +{ +public: + template + FunctionWrapper(F && func) : m_impl(new ImplType(std::move(func))) {} + FunctionWrapper() = default; + + FunctionWrapper(FunctionWrapper && other) : m_impl(std::move(other.m_impl)) {} + FunctionWrapper & operator=(FunctionWrapper && other) + { + m_impl = std::move(other.m_impl); + return *this; + } + + void operator()() { m_impl->Call(); } + +private: + struct ImplBase + { + virtual ~ImplBase() = default; + virtual void Call() = 0; + }; + + template + struct ImplType : ImplBase + { + ImplType(F && func) : m_func(std::move(func)) {} + void Call() override { m_func(); } + + F m_func; + }; + + std::unique_ptr m_impl; +}; +} // namespace threads