Skip to content

Commit 3fb0298

Browse files
authored
Merge pull request #1 from SynaptiveMedical/synaptive/dev/sever/COM-471_RoundRobinWorkStealing
COM-471 Implement round robin stealing semantics.
2 parents 23c23a7 + f9e1f4c commit 3fb0298

File tree

2 files changed

+60
-23
lines changed

2 files changed

+60
-23
lines changed

include/thread_pool/thread_pool.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ using ThreadPool = ThreadPoolImpl<FixedFunction<void(), 128>,
2828
*/
2929
template <typename Task, template<typename> class Queue>
3030
class ThreadPoolImpl {
31+
32+
using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
33+
3134
public:
3235
/**
3336
* @brief ThreadPool Construct and start new thread pool.
@@ -74,7 +77,7 @@ class ThreadPoolImpl {
7477
private:
7578
Worker<Task, Queue>& getWorker();
7679

77-
std::vector<std::unique_ptr<Worker<Task, Queue>>> m_workers;
80+
WorkerVector m_workers;
7881
std::atomic<size_t> m_next_worker;
7982
};
8083

@@ -94,9 +97,7 @@ inline ThreadPoolImpl<Task, Queue>::ThreadPoolImpl(
9497

9598
for(size_t i = 0; i < m_workers.size(); ++i)
9699
{
97-
Worker<Task, Queue>* steal_donor =
98-
m_workers[(i + 1) % m_workers.size()].get();
99-
m_workers[i]->start(i, steal_donor);
100+
m_workers[i]->start(i, &m_workers);
100101
}
101102
}
102103

@@ -131,7 +132,7 @@ template <typename Task, template<typename> class Queue>
131132
template <typename Handler>
132133
inline bool ThreadPoolImpl<Task, Queue>::tryPost(Handler&& handler)
133134
{
134-
return getWorker().post(std::forward<Handler>(handler));
135+
return getWorker().tryPost(std::forward<Handler>(handler));
135136
}
136137

137138
template <typename Task, template<typename> class Queue>

include/thread_pool/worker.hpp

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <atomic>
44
#include <thread>
5-
#include <climits>
5+
#include <limits>
66

77
namespace tp
88
{
@@ -16,6 +16,8 @@ namespace tp
1616
template <typename Task, template<typename> class Queue>
1717
class Worker
1818
{
19+
using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
20+
1921
public:
2022
/**
2123
* @brief Worker Constructor.
@@ -36,9 +38,9 @@ class Worker
3638
/**
3739
* @brief start Create the executing thread and start tasks execution.
3840
* @param id Worker ID.
39-
* @param steal_donor Sibling worker to steal task from it.
41+
* @param workers Sibling workers for performing round robin work stealing.
4042
*/
41-
void start(size_t id, Worker* steal_donor);
43+
void start(size_t id, WorkerVector* workers);
4244

4345
/**
4446
* @brief stop Stop all worker's thread and stealing activity.
@@ -47,19 +49,19 @@ class Worker
4749
void stop();
4850

4951
/**
50-
* @brief post Post task to queue.
52+
* @brief tryPost Post task to queue.
5153
* @param handler Handler to be executed in executing thread.
5254
* @return true on success.
5355
*/
5456
template <typename Handler>
55-
bool post(Handler&& handler);
57+
bool tryPost(Handler&& handler);
5658

5759
/**
58-
* @brief steal Steal one task from this worker queue.
59-
* @param task Place for stealed task to be stored.
60+
* @brief tryGetLocalTask Get one task from this worker queue.
61+
* @param task Place for the obtained task to be stored.
6062
* @return true on success.
6163
*/
62-
bool steal(Task& task);
64+
bool tryGetLocalTask(Task& task);
6365

6466
/**
6567
* @brief getWorkerIdForCurrentThread Return worker ID associated with
@@ -69,16 +71,24 @@ class Worker
6971
static size_t getWorkerIdForCurrentThread();
7072

7173
private:
74+
/**
75+
* @brief tryRoundRobinSteal Try stealing a thread from sibling workers in a round-robin fashion.
76+
* @param task Place for the obtained task to be stored.
77+
* @param workers Sibling workers for performing round robin work stealing.
78+
*/
79+
bool tryRoundRobinSteal(Task& task, WorkerVector* workers);
80+
7281
/**
7382
* @brief threadFunc Executing thread function.
7483
* @param id Worker ID to be associated with this thread.
75-
* @param steal_donor Sibling worker to steal task from it.
84+
* @param workers Sibling workers for performing round robin work stealing.
7685
*/
77-
void threadFunc(size_t id, Worker* steal_donor);
86+
void threadFunc(size_t id, WorkerVector* workers);
7887

7988
Queue<Task> m_queue;
8089
std::atomic<bool> m_running_flag;
8190
std::thread m_thread;
91+
size_t m_next_donor;
8292
};
8393

8494

@@ -88,7 +98,7 @@ namespace detail
8898
{
8999
inline size_t* thread_id()
90100
{
91-
static thread_local size_t tss_id = UINT_MAX;
101+
static thread_local size_t tss_id = std::numeric_limits<size_t>::max();
92102
return &tss_id;
93103
}
94104
}
@@ -97,6 +107,7 @@ template <typename Task, template<typename> class Queue>
97107
inline Worker<Task, Queue>::Worker(size_t queue_size)
98108
: m_queue(queue_size)
99109
, m_running_flag(true)
110+
, m_next_donor(0) // Initialized in threadFunc.
100111
{
101112
}
102113

@@ -126,9 +137,9 @@ inline void Worker<Task, Queue>::stop()
126137
}
127138

128139
template <typename Task, template<typename> class Queue>
129-
inline void Worker<Task, Queue>::start(size_t id, Worker* steal_donor)
140+
inline void Worker<Task, Queue>::start(size_t id, WorkerVector* workers)
130141
{
131-
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, steal_donor);
142+
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, workers);
132143
}
133144

134145
template <typename Task, template<typename> class Queue>
@@ -139,35 +150,60 @@ inline size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()
139150

140151
template <typename Task, template<typename> class Queue>
141152
template <typename Handler>
142-
inline bool Worker<Task, Queue>::post(Handler&& handler)
153+
inline bool Worker<Task, Queue>::tryPost(Handler&& handler)
143154
{
144155
return m_queue.push(std::forward<Handler>(handler));
145156
}
146157

147158
template <typename Task, template<typename> class Queue>
148-
inline bool Worker<Task, Queue>::steal(Task& task)
159+
inline bool Worker<Task, Queue>::tryGetLocalTask(Task& task)
149160
{
150161
return m_queue.pop(task);
151162
}
152163

153164
template <typename Task, template<typename> class Queue>
154-
inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
165+
inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
166+
{
167+
auto starting_index = m_next_donor;
168+
169+
// Iterate once through the worker ring, checking for queued work items on each thread.
170+
do
171+
{
172+
// Don't steal from local queue.
173+
if (m_next_donor != *detail::thread_id() && workers->at(m_next_donor)->tryGetLocalTask(task))
174+
{
175+
// Increment before returning so that m_next_donor always points to the worker that has gone the longest
176+
// without a steal attempt. This helps enforce fairness in the stealing.
177+
++m_next_donor %= workers->size();
178+
return true;
179+
}
180+
181+
++m_next_donor %= workers->size();
182+
} while (m_next_donor != starting_index);
183+
184+
return false;
185+
}
186+
187+
template <typename Task, template<typename> class Queue>
188+
inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
155189
{
156190
*detail::thread_id() = id;
191+
m_next_donor = ++id % workers->size();
157192

158193
Task handler;
159194

160195
while (m_running_flag.load(std::memory_order_relaxed))
161196
{
162-
if (m_queue.pop(handler) || steal_donor->steal(handler))
197+
// Prioritize local queue, then try stealing from sibling workers.
198+
if (tryGetLocalTask(handler) || tryRoundRobinSteal(handler, workers))
163199
{
164200
try
165201
{
166202
handler();
167203
}
168204
catch(...)
169205
{
170-
// suppress all exceptions
206+
// Suppress all exceptions.
171207
}
172208
}
173209
else

0 commit comments

Comments
 (0)