diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8af054be322a5..a8d97a36df79e 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -334,6 +334,43 @@ class OpenMP_DoacrossClauseSkip< def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [5.4.7] `exclusive` clause +//===----------------------------------------------------------------------===// + +class OpenMP_ExclusiveClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + Variadic:$exclusive_vars + ); + + let optAssemblyFormat = [{ + `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)` + }]; + + let extraClassDeclaration = [{ + bool hasExclusiveVars() { + return !getExclusiveVars().empty(); + } + }]; + + let description = [{ + The exclusive clause is used on a separating directive that separates a + structured block into two structured block sequences. If it + is specified, the input phase excludes the preceding structured block + sequence and instead includes the following structured block sequence, + while the scan phase includes the preceding structured block sequence. + + The `exclusive_vars` is a variadic list of operands that specifies the + scan-reduction accumulator symbols. + }]; +} + +def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [10.5.1] `filter` clause //===----------------------------------------------------------------------===// @@ -444,6 +481,43 @@ class OpenMP_HasDeviceAddrClauseSkip< def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [5.4.7] `inclusive` clause +//===----------------------------------------------------------------------===// + +class OpenMP_InclusiveClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + Variadic:$inclusive_vars + ); + + let optAssemblyFormat = [{ + `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)` + }]; + + let extraClassDeclaration = [{ + bool hasInclusiveVars() { + return !getInclusiveVars().empty(); + } + }]; + + let description = [{ + The inclusive clause is used on a separating directive that separates a + structured block into two structured block sequences. If it is specified, + the input phase includes the preceding structured block sequence and the + scan phase includes the following structured block sequence. + + The `inclusive_vars` is a variadic list of operands that specifies the + scan-reduction accumulator symbols. + }]; +} + +def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>; + + //===----------------------------------------------------------------------===// // V5.2: [15.1.2] `hint` clause //===----------------------------------------------------------------------===// @@ -1100,6 +1174,7 @@ class OpenMP_ReductionClauseSkip< ]; let arguments = (ins + OptionalAttr:$reduction_mod, Variadic:$reduction_vars, OptionalAttr:$reduction_byref, OptionalAttr:$reduction_syms @@ -1113,10 +1188,11 @@ class OpenMP_ReductionClauseSkip< // Description varies depending on the operation. let description = [{ - Reductions can be performed by specifying reduction accumulator variables in - `reduction_vars`, symbols referring to reduction declarations in the - `reduction_syms` attribute, and whether the reduction variable should be - passed into the reduction region by value or by reference in + Reductions can be performed by specifying the reduction modifer + (`default`, `inscan` or `task`) in `reduction_mod`, reduction accumulator + variables in `reduction_vars`, symbols referring to reduction declarations + in the `reduction_syms` attribute, and whether the reduction variable + should be passed into the reduction region by value or by reference in `reduction_byref`. Each reduction is identified by the accumulator it uses and accumulators must not be repeated in the same reduction. A private variable corresponding to the accumulator is used in place of the diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index 2091c0c76dff7..690e3df1f685e 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -179,6 +179,27 @@ def OrderModifier def OrderModifierAttr : EnumAttr; +//===----------------------------------------------------------------------===// +// reduction_modifier enum. +//===----------------------------------------------------------------------===// + +def ReductionModifierDefault : I32EnumAttrCase<"defaultmod", 0>; +def ReductionModifierInscan : I32EnumAttrCase<"inscan", 1>; +def ReductionModifierTask : I32EnumAttrCase<"task", 2>; + +def ReductionModifier : OpenMP_I32EnumAttr< + "ReductionModifier", + "reduction modifier", [ + ReductionModifierDefault, + ReductionModifierInscan, + ReductionModifierTask + ]>; + +def ReductionModifierAttr : OpenMP_EnumAttr { + let assemblyFormat = "`(` $value `)`"; +} + //===----------------------------------------------------------------------===// // sched_mod enum. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index c5b8890436708..580c9c6ef6fde 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -178,7 +178,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -223,7 +223,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -282,7 +282,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -469,7 +469,7 @@ def LoopOp : OpenMP_Op<"loop", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -521,7 +521,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -575,7 +575,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref, + $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -782,7 +782,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [ custom( $region, $in_reduction_vars, type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, $private_vars, - type($private_vars), $private_syms, $reduction_vars, + type($private_vars), $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; @@ -1706,6 +1706,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [ let hasVerifier = 1; } +def ScanOp : OpenMP_Op<"scan", [ + AttrSizedOperandSegments, MemoryEffects<[MemWrite]> + ], clauses = [ + OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> { + let summary = "scan directive"; + let description = [{ + The scan directive allows to specify scan reductions. It should be + enclosed within a parent directive along with which a reduction clause + with `inscan` modifier must be specified. The scan directive allows to + split code blocks into input phase and scan phase in the region + enclosed by the parent. + }] # clausesDescription; + + let builders = [ + OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)> + ]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // 2.19.5.7 declare reduction Directive //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index aa241b91d758c..233739e1d6d91 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern { /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{}, + /* reduction_mod = */ nullptr, /* reduction_vars = */ llvm::SmallVector{}, /* reduction_byref = */ DenseBoolArrayAttr{}, /* reduction_syms = */ ArrayAttr{}); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 5a619254a5ee1..88f56dc514422 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -494,16 +494,19 @@ struct PrivateParseArgs { DenseI64ArrayAttr *mapIndices = nullptr) : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} }; + struct ReductionParseArgs { SmallVectorImpl &vars; SmallVectorImpl &types; DenseBoolArrayAttr &byref; ArrayAttr &syms; + ReductionModifierAttr *modifier; ReductionParseArgs(SmallVectorImpl &vars, SmallVectorImpl &types, DenseBoolArrayAttr &byref, - ArrayAttr &syms) - : vars(vars), types(types), byref(byref), syms(syms) {} + ArrayAttr &syms, ReductionModifierAttr *mod = nullptr) + : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {} }; + struct AllRegionParseArgs { std::optional hostEvalArgs; std::optional inReductionArgs; @@ -522,7 +525,8 @@ static ParseResult parseClauseWithRegionArgs( SmallVectorImpl &types, SmallVectorImpl ®ionPrivateArgs, ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr, - DenseBoolArrayAttr *byref = nullptr) { + DenseBoolArrayAttr *byref = nullptr, + ReductionModifierAttr *modifier = nullptr) { SmallVector symbolVec; SmallVector mapIndicesVec; SmallVector isByRefVec; @@ -531,6 +535,20 @@ static ParseResult parseClauseWithRegionArgs( if (parser.parseLParen()) return failure(); + if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) { + StringRef enumStr; + if (parser.parseColon() || parser.parseKeyword(&enumStr) || + parser.parseComma()) + return failure(); + std::optional enumValue = + symbolizeReductionModifier(enumStr); + if (!enumValue.has_value()) + return failure(); + *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue); + if (!*modifier) + return failure(); + } + if (parser.parseCommaSeparatedList([&]() { if (byref) isByRefVec.push_back( @@ -635,11 +653,10 @@ static ParseResult parseBlockArgClause( if (succeeded(parser.parseOptionalKeyword(keyword))) { if (!reductionArgs) return failure(); - if (failed(parseClauseWithRegionArgs( parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs, - &reductionArgs->syms, /*mapIndices=*/nullptr, - &reductionArgs->byref))) + &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref, + reductionArgs->modifier))) return failure(); } return success(); @@ -735,6 +752,7 @@ static ParseResult parseInReductionPrivateReductionRegion( DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, + ReductionModifierAttr &reductionMod, SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { @@ -743,7 +761,7 @@ static ParseResult parseInReductionPrivateReductionRegion( inReductionByref, inReductionSyms); args.privateArgs.emplace(privateVars, privateTypes, privateSyms); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, - reductionSyms); + reductionSyms, &reductionMod); return parseBlockArgRegion(parser, region, args); } @@ -760,13 +778,14 @@ static ParseResult parsePrivateReductionRegion( OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, + ReductionModifierAttr &reductionMod, SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { AllRegionParseArgs args; args.privateArgs.emplace(privateVars, privateTypes, privateSyms); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, - reductionSyms); + reductionSyms, &reductionMod); return parseBlockArgRegion(parser, region, args); } @@ -817,9 +836,10 @@ struct ReductionPrintArgs { TypeRange types; DenseBoolArrayAttr byref; ArrayAttr syms; + ReductionModifierAttr modifier; ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref, - ArrayAttr syms) - : vars(vars), types(types), byref(byref), syms(syms) {} + ArrayAttr syms, ReductionModifierAttr mod = nullptr) + : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {} }; struct AllRegionPrintArgs { std::optional hostEvalArgs; @@ -833,18 +853,20 @@ struct AllRegionPrintArgs { }; } // namespace -static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, - StringRef clauseName, - ValueRange argsSubrange, - ValueRange operands, TypeRange types, - ArrayAttr symbols = nullptr, - DenseI64ArrayAttr mapIndices = nullptr, - DenseBoolArrayAttr byref = nullptr) { +static void printClauseWithRegionArgs( + OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, + ValueRange argsSubrange, ValueRange operands, TypeRange types, + ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr, + DenseBoolArrayAttr byref = nullptr, + ReductionModifierAttr modifier = nullptr) { if (argsSubrange.empty()) return; p << clauseName << "("; + if (modifier) + p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", "; + if (!symbols) { llvm::SmallVector values(operands.size(), nullptr); symbols = ArrayAttr::get(ctx, values); @@ -905,7 +927,7 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, reductionArgs->vars, reductionArgs->types, reductionArgs->syms, /*mapIndices=*/nullptr, - reductionArgs->byref); + reductionArgs->byref, reductionArgs->modifier); } static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, @@ -968,7 +990,8 @@ static void printInReductionPrivateReductionRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, - ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes, + ArrayAttr privateSyms, ReductionModifierAttr reductionMod, + ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, @@ -976,7 +999,7 @@ static void printInReductionPrivateReductionRegion( args.privateArgs.emplace(privateVars, privateTypes, privateSyms, /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, - reductionSyms); + reductionSyms, reductionMod); printBlockArgRegion(p, op, region, args); } @@ -991,14 +1014,15 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, static void printPrivateReductionRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, - TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars, + TypeRange privateTypes, ArrayAttr privateSyms, + ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { AllRegionPrintArgs args; args.privateArgs.emplace(privateVars, privateTypes, privateSyms, /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, - reductionSyms); + reductionSyms, reductionMod); printBlockArgRegion(p, op, region, args); } @@ -1942,7 +1966,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr, - /*reduction_vars=*/ValueRange(), + /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr); state.addAttributes(attributes); } @@ -1953,7 +1977,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.ifExpr, clauses.numThreads, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.procBindKind, clauses.reductionVars, + clauses.procBindKind, clauses.reductionMod, + clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2052,12 +2077,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, const TeamsOperands &clauses) { MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: privateVars, privateSyms. - TeamsOp::build( - builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, - /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit); + TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, + /*private_vars=*/{}, /*private_syms=*/nullptr, + clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), + clauses.threadLimit); } LogicalResult TeamsOp::verify() { @@ -2114,7 +2140,8 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state, // TODO Store clauses in op: privateVars, privateSyms. SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.nowait, /*private_vars=*/{}, - /*private_syms=*/nullptr, clauses.reductionVars, + /*private_syms=*/nullptr, clauses.reductionMod, + clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2221,7 +2248,7 @@ void LoopOp::build(OpBuilder &builder, OperationState &state, LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.order, - clauses.orderMod, clauses.reductionVars, + clauses.orderMod, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2249,7 +2276,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, - /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr, + /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(), + /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr, /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr, /*schedule_simd=*/false); @@ -2261,15 +2289,16 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars, // privateSyms. - WsloopOp::build( - builder, state, - /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, - clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, - clauses.ordered, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind, - clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd); + WsloopOp::build(builder, state, + /*allocate_vars=*/{}, /*allocator_vars=*/{}, + clauses.linearVars, clauses.linearStepVars, clauses.nowait, + clauses.order, clauses.orderMod, clauses.ordered, + clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), + clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), + clauses.scheduleKind, clauses.scheduleChunk, + clauses.scheduleMod, clauses.scheduleSimd); } LogicalResult WsloopOp::verify() { @@ -2316,7 +2345,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state, /*linear_vars=*/{}, /*linear_step_vars=*/{}, clauses.nontemporalVars, clauses.order, clauses.orderMod, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.reductionVars, + clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, clauses.simdlen); @@ -2548,7 +2577,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state, makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{}, - /*private_syms=*/nullptr, clauses.reductionVars, + /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); } @@ -3125,6 +3154,36 @@ void MaskedOp::build(OpBuilder &builder, OperationState &state, MaskedOp::build(builder, state, clauses.filteredThreadId); } +//===----------------------------------------------------------------------===// +// Spec 5.2: Scan construct (5.6) +//===----------------------------------------------------------------------===// + +void ScanOp::build(OpBuilder &builder, OperationState &state, + const ScanOperands &clauses) { + ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars); +} + +LogicalResult ScanOp::verify() { + if (hasExclusiveVars() == hasInclusiveVars()) + return emitError( + "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected"); + if (WsloopOp parentWsLoopOp = (*this)->getParentOfType()) { + if (parentWsLoopOp.getReductionModAttr() && + parentWsLoopOp.getReductionModAttr().getValue() == + ReductionModifier::inscan) + return success(); + } + if (SimdOp parentSimdOp = (*this)->getParentOfType()) { + if (parentSimdOp.getReductionModAttr() && + parentSimdOp.getReductionModAttr().getValue() == + ReductionModifier::inscan) + return success(); + } + return emitError("SCAN directive needs to be enclosed within a parent " + "worksharing loop construct or SIMD construct with INSCAN " + "reduction modifier"); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0be515e63b470..a1b23011e41ef 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -241,9 +241,13 @@ static LogicalResult checkImplementationStatus(Operation &op) { } }; auto checkReduction = [&todo](auto op, LogicalResult &result) { - if (!op.getReductionVars().empty() || op.getReductionByref() || - op.getReductionSyms()) - result = todo("reduction"); + if (isa(op) || isa(op)) + if (!op.getReductionVars().empty() || op.getReductionByref() || + op.getReductionSyms()) + result = todo("reduction"); + if (op.getReductionMod() && + op.getReductionMod().value() != omp::ReductionModifier::defaultmod) + result = todo("reduction with modifier"); }; auto checkTaskReduction = [&todo](auto op, LogicalResult &result) { if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() || @@ -261,6 +265,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::SectionsOp op) { checkAllocate(op, result); checkPrivate(op, result); + checkReduction(op, result); }) .Case([&](omp::SingleOp op) { checkAllocate(op, result); @@ -289,8 +294,12 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkLinear(op, result); checkOrder(op, result); + checkReduction(op, result); + }) + .Case([&](omp::ParallelOp op) { + checkAllocate(op, result); + checkReduction(op, result); }) - .Case([&](omp::ParallelOp op) { checkAllocate(op, result); }) .Case([&](omp::SimdOp op) { checkLinear(op, result); checkNontemporal(op, result); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index c611614265592..06fcf90e34480 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1825,6 +1825,110 @@ func.func @omp_cancellationpoint2() { // ----- +omp.declare_reduction @add_f32 : f32 +init { + ^bb0(%arg: f32): + %0 = arith.constant 0.0 : f32 + omp.yield (%0 : f32) +} +combiner { + ^bb1(%arg0: f32, %arg1: f32): + %1 = arith.addf %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} + +func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) { + %test1f32 = "test.f32"() : () -> (!llvm.ptr) + omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}} + omp.scan + omp.yield + } + } + return +} + +// ----- + +omp.declare_reduction @add_f32 : f32 +init { + ^bb0(%arg: f32): + %0 = arith.constant 0.0 : f32 + omp.yield (%0 : f32) +} +combiner { + ^bb1(%arg0: f32, %arg1: f32): + %1 = arith.addf %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} + +func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) { + %test1f32 = "test.f32"() : () -> (!llvm.ptr) + omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}} + omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr) + omp.yield + } + } + return +} + +// ----- + +omp.declare_reduction @add_f32 : f32 +init { + ^bb0(%arg: f32): + %0 = arith.constant 0.0 : f32 + omp.yield (%0 : f32) +} +combiner { + ^bb1(%arg0: f32, %arg1: f32): + %1 = arith.addf %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} + +func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) { + %test1f32 = "test.f32"() : () -> (!llvm.ptr) + omp.wsloop reduction(@add_f32 %test1f32 -> %arg1 : !llvm.ptr) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}} + omp.scan inclusive(%test1f32 : !llvm.ptr) + omp.yield + } + } + return +} + +// ----- + +omp.declare_reduction @add_f32 : f32 +init { + ^bb0(%arg: f32): + %0 = arith.constant 0.0 : f32 + omp.yield (%0 : f32) +} +combiner { + ^bb1(%arg0: f32, %arg1: f32): + %1 = arith.addf %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} + +func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) { + %test1f32 = "test.f32"() : () -> (!llvm.ptr) + omp.taskloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}} + omp.scan inclusive(%test1f32 : !llvm.ptr) + omp.yield + } + } + return +} + +// ----- + func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { %testmemref = "test.memref"() : () -> (memref) // expected-error @below {{expected equal sizes for allocate and allocator variables}} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index b1901c333ade8..c1259fabe82fb 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -900,6 +900,29 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) { return } +// CHECK-LABEL: func @wsloop_inscan_reduction +func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) { + %c1 = arith.constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr) + omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr) + omp.scan inclusive(%prv : !llvm.ptr) + omp.yield + } + } + // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr) + omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr) + omp.scan exclusive(%prv : !llvm.ptr) + omp.yield + } + } + return +} + // CHECK-LABEL: func @wsloop_reduction_byref func.func @wsloop_reduction_byref(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 392a6558dcfa6..30833474256a4 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -186,6 +186,37 @@ llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // ----- +omp.declare_reduction @add_f32 : f32 +init { +^bb0(%arg: f32): + %0 = llvm.mlir.constant(0.0 : f32) : f32 + omp.yield (%0 : f32) +} +combiner { +^bb1(%arg0: f32, %arg1: f32): + %1 = llvm.fadd %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} +atomic { +^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr): + %2 = llvm.load %arg3 : !llvm.ptr -> f32 + llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 + omp.yield +} +llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} + omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.scan inclusive(%prv : !llvm.ptr) + omp.yield + } + } + llvm.return +} + +// ----- + llvm.func @single_allocate(%x : !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.single operation}} // expected-error@below {{LLVM Translation failed for operation: omp.single}}