@@ -45,7 +45,24 @@ using namespace ompx;
 
 namespace {
 
-uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
+void num_threads_strict_error(int32_t nt_strict, int32_t nt_severity,
+                              const char *nt_message, int32_t requested,
+                              int32_t actual) {
+  if (nt_message)
+    printf("%s\n", nt_message);
+  else
+    printf("The computed number of threads (%u) does not match the requested "
+           "number of threads (%d). Consider that it might not be supported "
+           "to select exactly %d threads on this target device.\n",
+           actual, requested, requested);
+  if (nt_severity == severity_fatal)
+    __builtin_trap();
+}
+
+uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
+                                  int32_t nt_strict = false,
+                                  int32_t nt_severity = severity_fatal,
+                                  const char *nt_message = nullptr) {
   uint32_t NThreadsICV =
       NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
   uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
 
   // SPMD mode allows any number of threads, for generic mode we round down to a
   // multiple of WARPSIZE since it is legal to do so in OpenMP.
-  if (mapping::isSPMDMode())
-    return NumThreads;
+  if (!mapping::isSPMDMode()) {
+    if (NumThreads < mapping::getWarpSize())
+      NumThreads = 1;
+    else
+      NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+  }
 
-  if (NumThreads < mapping::getWarpSize())
-    NumThreads = 1;
-  else
-    NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+  if (NumThreadsClause != -1 && nt_strict &&
+      NumThreads != static_cast<uint32_t>(NumThreadsClause))
+    num_threads_strict_error(nt_strict, nt_severity, nt_message,
+                             NumThreadsClause, NumThreads);
 
   return NumThreads;
 }
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
 
 extern "C" {
 
-[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
-                                                   int32_t num_threads,
-                                                   void *fn, void **args,
-                                                   const int64_t nargs) {
+[[clang::always_inline]] void
+__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
+                     const int64_t nargs, int32_t nt_strict = false,
+                     int32_t nt_severity = severity_fatal,
+                     const char *nt_message = nullptr) {
   uint32_t TId = mapping::getThreadIdInBlock();
-  uint32_t NumThreads = determineNumberOfThreads(num_threads);
+  uint32_t NumThreads =
+      determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
   uint32_t PTeamSize =
       NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
   // Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
   return;
 }
 
-[[clang::always_inline]] void
-__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
-                   int32_t num_threads, int proc_bind, void *fn,
-                   void *wrapper_fn, void **args, int64_t nargs) {
+[[clang::always_inline]] void __kmpc_parallel_51(
+    IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
+    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
+    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
+    const char *nt_message = nullptr) {
   uint32_t TId = mapping::getThreadIdInBlock();
 
   // Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
   // 3) nested parallel regions
   if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
                    (config::mayUseNestedParallelism() && icv::Level))) {
+    // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
+    // effect when parallel execution is disabled by a corresponding if clause
+    // attached to the parallel directive.
+    if (nt_strict && num_threads > 1)
+      num_threads_strict_error(nt_strict, nt_severity, nt_message, num_threads,
+                               1);
     state::DateEnvironmentRAII DERAII(ident);
     ++icv::Level;
     invokeMicrotask(TId, 0, fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
     // This was moved to its own routine so it could be called directly
     // in certain situations to avoid resource consumption of unused
     // logic in parallel_51.
-    __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
+    __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
+                         nt_severity, nt_message);
 
     return;
   }
 
-  uint32_t NumThreads = determineNumberOfThreads(num_threads);
+  uint32_t NumThreads =
+      determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
   uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
   uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
 
@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
   __kmpc_end_sharing_variables();
 }
 
+[[clang::always_inline]] void __kmpc_parallel_60(
+    IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
+    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
+    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
+    const char *nt_message = nullptr) {
+  return __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
+                            wrapper_fn, args, nargs, nt_strict, nt_severity,
+                            nt_message);
+}
+
 [[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
   // Work function and arguments for L1 parallel region.
   *WorkFn = state::ParallelRegionFn;
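
For illustration only, here is a minimal, self-contained host-side sketch of how the generic-mode warp rounding interacts with the new strict check in this patch. The warp size of 32, the helper name roundForGenericMode, and the standalone main() harness are assumptions made for this example; in the patch itself these steps happen on the device inside determineNumberOfThreads and num_threads_strict_error.

#include <cstdint>
#include <cstdio>

// Illustrative sketch, not the device runtime: mirrors the generic-mode
// rounding and the strict mismatch check, assuming a warp size of 32.
static uint32_t roundForGenericMode(uint32_t NumThreads, uint32_t WarpSize) {
  if (NumThreads < WarpSize)
    return 1; // fewer threads than one warp collapses to a single thread
  return NumThreads & ~(WarpSize - 1); // round down to a multiple of WarpSize
}

int main() {
  const uint32_t WarpSize = 32;
  const int32_t Requested = 70; // e.g. num_threads(strict: 70) in generic mode
  uint32_t Computed = roundForGenericMode(Requested, WarpSize);
  if (Computed != static_cast<uint32_t>(Requested))
    // The runtime would call num_threads_strict_error() here and, with
    // severity_fatal, trap the kernel instead of continuing.
    printf("strict violation: computed %u, requested %d\n", Computed, Requested);
  return 0;
}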