2
2
3
3
#include < atomic>
4
4
#include < thread>
5
- #include < climits >
5
+ #include < limits >
6
6
7
7
namespace tp
8
8
{
@@ -16,6 +16,8 @@ namespace tp
16
16
template <typename Task, template <typename > class Queue >
17
17
class Worker
18
18
{
19
+ using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
20
+
19
21
public:
20
22
/* *
21
23
* @brief Worker Constructor.
@@ -36,9 +38,9 @@ class Worker
36
38
/* *
37
39
* @brief start Create the executing thread and start tasks execution.
38
40
* @param id Worker ID.
39
- * @param steal_donor Sibling worker to steal task from it .
41
+ * @param workers Sibling workers for performing round robin work stealing .
40
42
*/
41
- void start (size_t id, Worker* steal_donor );
43
+ void start (size_t id, WorkerVector* workers );
42
44
43
45
/* *
44
46
* @brief stop Stop all worker's thread and stealing activity.
@@ -47,19 +49,19 @@ class Worker
47
49
void stop ();
48
50
49
51
/* *
50
- * @brief post Post task to queue.
52
+ * @brief tryPost Post task to queue.
51
53
* @param handler Handler to be executed in executing thread.
52
54
* @return true on success.
53
55
*/
54
56
template <typename Handler>
55
- bool post (Handler&& handler);
57
+ bool tryPost (Handler&& handler);
56
58
57
59
/* *
58
- * @brief steal Steal one task from this worker queue.
59
- * @param task Place for stealed task to be stored.
60
+ * @brief tryGetLocalTask Get one task from this worker queue.
61
+ * @param task Place for the obtained task to be stored.
60
62
* @return true on success.
61
63
*/
62
- bool steal (Task& task);
64
+ bool tryGetLocalTask (Task& task);
63
65
64
66
/* *
65
67
* @brief getWorkerIdForCurrentThread Return worker ID associated with
@@ -69,16 +71,24 @@ class Worker
69
71
static size_t getWorkerIdForCurrentThread ();
70
72
71
73
private:
74
+ /* *
75
+ * @brief tryRoundRobinSteal Try stealing a thread from sibling workers in a round-robin fashion.
76
+ * @param task Place for the obtained task to be stored.
77
+ * @param workers Sibling workers for performing round robin work stealing.
78
+ */
79
+ bool tryRoundRobinSteal (Task& task, WorkerVector* workers);
80
+
72
81
/* *
73
82
* @brief threadFunc Executing thread function.
74
83
* @param id Worker ID to be associated with this thread.
75
- * @param steal_donor Sibling worker to steal task from it .
84
+ * @param workers Sibling workers for performing round robin work stealing .
76
85
*/
77
- void threadFunc (size_t id, Worker* steal_donor );
86
+ void threadFunc (size_t id, WorkerVector* workers );
78
87
79
88
Queue<Task> m_queue;
80
89
std::atomic<bool > m_running_flag;
81
90
std::thread m_thread;
91
+ size_t m_next_donor;
82
92
};
83
93
84
94
@@ -88,7 +98,7 @@ namespace detail
88
98
{
89
99
inline size_t * thread_id ()
90
100
{
91
- static thread_local size_t tss_id = UINT_MAX ;
101
+ static thread_local size_t tss_id = std::numeric_limits< size_t >:: max () ;
92
102
return &tss_id;
93
103
}
94
104
}
@@ -97,6 +107,7 @@ template <typename Task, template<typename> class Queue>
97
107
inline Worker<Task, Queue>::Worker(size_t queue_size)
98
108
: m_queue(queue_size)
99
109
, m_running_flag(true )
110
+ , m_next_donor(0 ) // Initialized in threadFunc.
100
111
{
101
112
}
102
113
@@ -126,9 +137,9 @@ inline void Worker<Task, Queue>::stop()
126
137
}
127
138
128
139
template <typename Task, template <typename > class Queue >
129
- inline void Worker<Task, Queue>::start(size_t id, Worker* steal_donor )
140
+ inline void Worker<Task, Queue>::start(size_t id, WorkerVector* workers )
130
141
{
131
- m_thread = std::thread (&Worker<Task, Queue>::threadFunc, this , id, steal_donor );
142
+ m_thread = std::thread (&Worker<Task, Queue>::threadFunc, this , id, workers );
132
143
}
133
144
134
145
template <typename Task, template <typename > class Queue >
@@ -139,35 +150,60 @@ inline size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()
139
150
140
151
template <typename Task, template <typename > class Queue >
141
152
template <typename Handler>
142
- inline bool Worker<Task, Queue>::post (Handler&& handler)
153
+ inline bool Worker<Task, Queue>::tryPost (Handler&& handler)
143
154
{
144
155
return m_queue.push (std::forward<Handler>(handler));
145
156
}
146
157
147
158
template <typename Task, template <typename > class Queue >
148
- inline bool Worker<Task, Queue>::steal (Task& task)
159
+ inline bool Worker<Task, Queue>::tryGetLocalTask (Task& task)
149
160
{
150
161
return m_queue.pop (task);
151
162
}
152
163
153
164
template <typename Task, template <typename > class Queue >
154
- inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
165
+ inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
166
+ {
167
+ auto starting_index = m_next_donor;
168
+
169
+ // Iterate once through the worker ring, checking for queued work items on each thread.
170
+ do
171
+ {
172
+ // Don't steal from local queue.
173
+ if (m_next_donor != *detail::thread_id () && workers->at (m_next_donor)->tryGetLocalTask (task))
174
+ {
175
+ // Increment before returning so that m_next_donor always points to the worker that has gone the longest
176
+ // without a steal attempt. This helps enforce fairness in the stealing.
177
+ ++m_next_donor %= workers->size ();
178
+ return true ;
179
+ }
180
+
181
+ ++m_next_donor %= workers->size ();
182
+ } while (m_next_donor != starting_index);
183
+
184
+ return false ;
185
+ }
186
+
187
+ template <typename Task, template <typename > class Queue >
188
+ inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
155
189
{
156
190
*detail::thread_id () = id;
191
+ m_next_donor = ++id % workers->size ();
157
192
158
193
Task handler;
159
194
160
195
while (m_running_flag.load (std::memory_order_relaxed))
161
196
{
162
- if (m_queue.pop (handler) || steal_donor->steal (handler))
197
+ // Prioritize local queue, then try stealing from sibling workers.
198
+ if (tryGetLocalTask (handler) || tryRoundRobinSteal (handler, workers))
163
199
{
164
200
try
165
201
{
166
202
handler ();
167
203
}
168
204
catch (...)
169
205
{
170
- // suppress all exceptions
206
+ // Suppress all exceptions.
171
207
}
172
208
}
173
209
else
0 commit comments