diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 299d9d438f115..febc6adcf9d6f 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -344,6 +344,20 @@ bool ClauseProcessor::processDistSchedule( return false; } +bool ClauseProcessor::processExclusive( + mlir::Location currentLocation, + mlir::omp::ExclusiveClauseOps &result) const { + if (auto *clause = findUniqueClause()) { + for (const Object &object : clause->v) { + const semantics::Symbol *symbol = object.sym(); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + result.exclusiveVars.push_back(symVal); + } + return true; + } + return false; +} + bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx, mlir::omp::FilterClauseOps &result) const { if (auto *clause = findUniqueClause()) { @@ -380,6 +394,20 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const { return false; } +bool ClauseProcessor::processInclusive( + mlir::Location currentLocation, + mlir::omp::InclusiveClauseOps &result) const { + if (auto *clause = findUniqueClause()) { + for (const Object &object : clause->v) { + const semantics::Symbol *symbol = object.sym(); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + result.inclusiveVars.push_back(symVal); + } + return true; + } + return false; +} + bool ClauseProcessor::processMergeable( mlir::omp::MergeableClauseOps &result) const { return markClauseOccurrence(result.mergeable); @@ -1135,10 +1163,9 @@ bool ClauseProcessor::processReduction( llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction(currentLocation, converter, clause, - reductionVars, reduceVarByRef, - reductionDeclSymbols, reductionSyms); - + rp.processReductionArguments( + currentLocation, converter, clause, reductionVars, reduceVarByRef, + reductionDeclSymbols, reductionSyms, result.reductionMod); // Copy local lists into the output. llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 7b047d4a7567a..e05f66c766684 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -64,6 +64,8 @@ class ClauseProcessor { bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; bool processDistSchedule(lower::StatementContext &stmtCtx, mlir::omp::DistScheduleClauseOps &result) const; + bool processExclusive(mlir::Location currentLocation, + mlir::omp::ExclusiveClauseOps &result) const; bool processFilter(lower::StatementContext &stmtCtx, mlir::omp::FilterClauseOps &result) const; bool processFinal(lower::StatementContext &stmtCtx, @@ -72,6 +74,8 @@ class ClauseProcessor { mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; + bool processInclusive(mlir::Location currentLocation, + mlir::omp::InclusiveClauseOps &result) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTeams(lower::StatementContext &stmtCtx, diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index b424e209d56da..a26bdcdf343e1 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp, Exclusive make(const parser::OmpClause::Exclusive &inp, semantics::SemanticsContext &semaCtx) { - // inp -> empty - llvm_unreachable("Empty: exclusive"); + // inp.v -> parser::OmpObjectList + return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)}; } Fail make(const parser::OmpClause::Fail &inp, @@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp, Inclusive make(const parser::OmpClause::Inclusive &inp, semantics::SemanticsContext &semaCtx) { - // inp -> empty - llvm_unreachable("Empty: inclusive"); + // inp.v -> parser::OmpObjectList + return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)}; } Indirect make(const parser::OmpClause::Indirect &inp, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1434bcd6330e0..48e8e433e1f1f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1578,6 +1578,15 @@ static void genParallelClauses( cp.processReduction(loc, clauseOps, reductionSyms); } +static void genScanClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::ScanOperands &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processInclusive(loc, clauseOps); + cp.processExclusive(loc, clauseOps); +} + static void genSectionsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List &clauses, mlir::Location loc, @@ -1975,6 +1984,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return parallelOp; } +static mlir::omp::ScanOp +genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, mlir::Location loc, + const ConstructQueue &queue, ConstructQueue::const_iterator item) { + mlir::omp::ScanOperands clauseOps; + genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps); + return converter.getFirOpBuilder().create( + converter.getCurrentLocation(), clauseOps); +} + /// This breaks the normal prototype of the gen*Op functions: adding the /// sectionBlocks argument so that the enclosed section constructs can be /// lowered here with correct reduction symbol remapping. @@ -2978,7 +2997,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item); break; case llvm::omp::Directive::OMPD_scan: - TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir)); + genScanOp(converter, symTable, semaCtx, loc, queue, item); break; case llvm::omp::Directive::OMPD_section: llvm_unreachable("genOMPDispatch: OMPD_section"); diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index 2cd21107a916e..2036dc82d1aa0 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -31,6 +31,9 @@ static llvm::cl::opt forceByrefReduction( llvm::cl::desc("Pass all reduction arguments by reference"), llvm::cl::Hidden); +using ReductionModifier = + Fortran::lower::omp::clause::Reduction::ReductionModifier; + namespace Fortran { namespace lower { namespace omp { @@ -514,18 +517,36 @@ static bool doReductionByRef(mlir::Value reductionVar) { return false; } -void ReductionProcessor::addDeclareReduction( +mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { + switch (mod) { + case ReductionModifier::Default: + return mlir::omp::ReductionModifier::defaultmod; + case ReductionModifier::Inscan: + return mlir::omp::ReductionModifier::inscan; + case ReductionModifier::Task: + return mlir::omp::ReductionModifier::task; + } + return mlir::omp::ReductionModifier::defaultmod; +} + +void ReductionProcessor::processReductionArguments( mlir::Location currentLocation, lower::AbstractConverter &converter, const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl &reductionSymbols) { + llvm::SmallVectorImpl &reductionSymbols, + mlir::omp::ReductionModifierAttr &reductionMod) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - if (std::get>( - reduction.t)) - TODO(currentLocation, "Reduction modifiers are not supported"); + auto mod = std::get>(reduction.t); + if (mod.has_value()) { + if (mod.value() == ReductionModifier::Task) + TODO(currentLocation, "Reduction modifier `task` is not supported"); + else + reductionMod = mlir::omp::ReductionModifierAttr::get( + firOpBuilder.getContext(), translateReductionModifier(mod.value())); + } mlir::omp::DeclareReductionOp decl; const auto &redOperatorList{ diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 5f4d742b62cb1..8a463b08faa8e 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -19,6 +19,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/type.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" @@ -120,13 +121,14 @@ class ReductionProcessor { /// Creates a reduction declaration and associates it with an OpenMP block /// directive. - static void addDeclareReduction( + static void processReductionArguments( mlir::Location currentLocation, lower::AbstractConverter &converter, const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl &reductionSymbols); + llvm::SmallVectorImpl &reductionSymbols, + mlir::omp::ReductionModifierAttr &reductionMod); }; template diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 deleted file mode 100644 index 152d91a16f80f..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 +++ /dev/null @@ -1,15 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction modifiers are not supported -subroutine reduction_inscan() - integer :: i,j - i = 0 - - !$omp do reduction(inscan, +:i) - do j=1,10 - !$omp scan inclusive(i) - i = i + 1 - end do - !$omp end do -end subroutine reduction_inscan diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 deleted file mode 100644 index 82625ed8c5f31..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 +++ /dev/null @@ -1,14 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction modifiers are not supported - -subroutine foo() - integer :: i, j - j = 0 - !$omp do reduction (inscan, *: j) - do i = 1, 10 - !$omp scan inclusive(j) - j = j + 1 - end do -end subroutine diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90 index 6707f65e1a4cc..b8bfc37d1758f 100644 --- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 +++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90 @@ -1,7 +1,7 @@ ! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s ! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! CHECK: not yet implemented: Reduction modifiers are not supported +! CHECK: not yet implemented: Reduction modifier `task` is not supported subroutine reduction_task() integer :: i i = 0 diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90 new file mode 100644 index 0000000000000..97b672ec41f20 --- /dev/null +++ b/flang/test/Lower/OpenMP/scan.f90 @@ -0,0 +1,36 @@ +! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_1:.*]] : {{.*}}) { +! CHECK: %[[RED_DECL_1:.*]]:2 = hlfir.declare %[[RED_ARG_1]] +! CHECK: omp.scan inclusive(%[[RED_DECL_1]]#1 : {{.*}}) + +subroutine inclusive_scan(a, b, n) + implicit none + integer a(:), b(:) + integer x, k, n + + !$omp parallel do reduction(inscan, +: x) + do k = 1, n + x = x + a(k) + !$omp scan inclusive(x) + b(k) = x + end do +end subroutine inclusive_scan + + +! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_2:.*]] : {{.*}}) { +! CHECK: %[[RED_DECL_2:.*]]:2 = hlfir.declare %[[RED_ARG_2]] +! CHECK: omp.scan exclusive(%[[RED_DECL_2]]#1 : {{.*}}) +subroutine exclusive_scan(a, b, n) + implicit none + integer a(:), b(:) + integer x, k, n + + !$omp parallel do reduction(inscan, +: x) + do k = 1, n + x = x + a(k) + !$omp scan exclusive(x) + b(k) = x + end do +end subroutine exclusive_scan diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 5d0003911bca8..89444118a9d04 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality( target.addDynamicallyLegalOp< omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp, omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp, - omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp, + omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && @@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, + RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion,