Skip to content

Commit cf566c6

Browse files
committed
[OpenMP][clang] 6.0: num_threads strict (part 2: device runtime)
OpenMP 6.0 12.1.2 specifies the behavior of the strict modifier for the num_threads clause on parallel directives, along with the message and severity clauses. This commit implements necessary device runtime changes.
1 parent b3e7d2e commit cf566c6

File tree

3 files changed

+67
-18
lines changed

3 files changed

+67
-18
lines changed

offload/DeviceRTL/include/DeviceTypes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ struct omp_lock_t {
136136
void *Lock;
137137
};
138138

139+
// see definition in openmp/runtime kmp.h
140+
typedef enum omp_severity_t {
141+
severity_warning = 1,
142+
severity_fatal = 2
143+
} omp_severity_t;
144+
139145
using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num);
140146
using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id,
141147
int16_t lane_offset, int16_t shortCircuit);

offload/DeviceRTL/src/Parallelism.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,24 @@ using namespace ompx;
4545

4646
namespace {
4747

48-
uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
48+
void num_threads_strict_error(int32_t nt_strict, int32_t nt_severity,
49+
const char *nt_message, int32_t requested,
50+
int32_t actual) {
51+
if (nt_message)
52+
printf("%s\n", nt_message);
53+
else
54+
printf("The computed number of threads (%u) does not match the requested "
55+
"number of threads (%d). Consider that it might not be supported "
56+
"to select exactly %d threads on this target device.\n",
57+
actual, requested, requested);
58+
if (nt_severity == severity_fatal)
59+
__builtin_trap();
60+
}
61+
62+
uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
63+
int32_t nt_strict = false,
64+
int32_t nt_severity = severity_fatal,
65+
const char *nt_message = nullptr) {
4966
uint32_t NThreadsICV =
5067
NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
5168
uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
5572

5673
// SPMD mode allows any number of threads, for generic mode we round down to a
5774
// multiple of WARPSIZE since it is legal to do so in OpenMP.
58-
if (mapping::isSPMDMode())
59-
return NumThreads;
75+
if (!mapping::isSPMDMode()) {
76+
if (NumThreads < mapping::getWarpSize())
77+
NumThreads = 1;
78+
else
79+
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
80+
}
6081

61-
if (NumThreads < mapping::getWarpSize())
62-
NumThreads = 1;
63-
else
64-
NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
82+
if (NumThreadsClause != -1 && nt_strict &&
83+
NumThreads != static_cast<uint32_t>(NumThreadsClause))
84+
num_threads_strict_error(nt_strict, nt_severity, nt_message,
85+
NumThreadsClause, NumThreads);
6586

6687
return NumThreads;
6788
}
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
82103

83104
extern "C" {
84105

85-
[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
86-
int32_t num_threads,
87-
void *fn, void **args,
88-
const int64_t nargs) {
106+
[[clang::always_inline]] void
107+
__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
108+
const int64_t nargs, int32_t nt_strict = false,
109+
int32_t nt_severity = severity_fatal,
110+
const char *nt_message = nullptr) {
89111
uint32_t TId = mapping::getThreadIdInBlock();
90-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
112+
uint32_t NumThreads =
113+
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
91114
uint32_t PTeamSize =
92115
NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
93116
// Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
140163
return;
141164
}
142165

143-
[[clang::always_inline]] void
144-
__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
145-
int32_t num_threads, int proc_bind, void *fn,
146-
void *wrapper_fn, void **args, int64_t nargs) {
166+
[[clang::always_inline]] void __kmpc_parallel_51(
167+
IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
168+
int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
169+
int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
170+
const char *nt_message = nullptr) {
147171
uint32_t TId = mapping::getThreadIdInBlock();
148172

149173
// Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
156180
// 3) nested parallel regions
157181
if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
158182
(config::mayUseNestedParallelism() && icv::Level))) {
183+
// OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
184+
// effect when parallel execution is disabled by a corresponding if clause
185+
// attached to the parallel directive.
186+
if (nt_strict && num_threads > 1)
187+
num_threads_strict_error(nt_strict, nt_severity, nt_message, num_threads,
188+
1);
159189
state::DateEnvironmentRAII DERAII(ident);
160190
++icv::Level;
161191
invokeMicrotask(TId, 0, fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
169199
// This was moved to its own routine so it could be called directly
170200
// in certain situations to avoid resource consumption of unused
171201
// logic in parallel_51.
172-
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
202+
__kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
203+
nt_severity, nt_message);
173204

174205
return;
175206
}
176207

177-
uint32_t NumThreads = determineNumberOfThreads(num_threads);
208+
uint32_t NumThreads =
209+
determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
178210
uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
179211
uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
180212

@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
277309
__kmpc_end_sharing_variables();
278310
}
279311

312+
[[clang::always_inline]] void __kmpc_parallel_60(
313+
IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
314+
int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
315+
int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
316+
const char *nt_message = nullptr) {
317+
return __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
318+
wrapper_fn, args, nargs, nt_strict, nt_severity,
319+
nt_message);
320+
}
321+
280322
[[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
281323
// Work function and arguments for L1 parallel region.
282324
*WorkFn = state::ParallelRegionFn;

openmp/runtime/src/kmp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4666,6 +4666,7 @@ static inline int __kmp_adjust_gtid_for_hidden_helpers(int gtid) {
46664666
}
46674667

46684668
// Support for error directive
4669+
// See definition in offload/DeviceRTL DeviceTypes.h
46694670
typedef enum kmp_severity_t {
46704671
severity_warning = 1,
46714672
severity_fatal = 2

0 commit comments

Comments
 (0)