diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index f4876256a378f..02454543d0a60 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1407,8 +1407,7 @@ bool ClauseProcessor::processUseDeviceAddr( const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); @@ -1429,8 +1428,7 @@ bool ClauseProcessor::processUseDevicePtr( const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 3f4cfb8c11a9d..173dceb07b193 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -398,14 +398,16 @@ mlir::Value createParentSymAndGenIntermediateMaps( interimBounds, treatIndexAsSection); } - // Remove all map TO, FROM and TOFROM bits, from the intermediate - // allocatable maps, we simply wish to alloc or release them. It may be - // safer to just pass OMP_MAP_NONE as the map type, but we may still + // Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate + // allocatable maps, as we simply wish to alloc or release them. It may + // be safer to just pass OMP_MAP_NONE as the map type, but we may still // need some of the other map types the mapped member utilises, so for // now it's good to keep an eye on this. llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits; interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + interimMapType &= + ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; // Create a map for the intermediate member and insert it and it's // indices into the parentMemberIndices list to track it. diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 8019ecf7f6a05..b13921f822b4d 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() { func.func @_QPomp_target_data_empty() { %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"} - omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref>) { + %1 = omp.map.info var_ptr(%0 : !fir.ref>, !fir.ref>) map_clauses(return_param) capture(ByRef) -> !fir.ref> {name = ""} + omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref>) { omp.terminator } return } // CHECK-LABEL: llvm.func @_QPomp_target_data_empty -// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) { +// CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) { // CHECK: } // ----- diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 index 4815e6564fc7e..f04aacc63fc2b 100644 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -544,7 +544,7 @@ subroutine omp_target_device_addr !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} - !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} + !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref>>, !fir.llvm_ptr>) { !$omp target data map(tofrom: a) use_device_addr(a) diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2bf7aaa46db11..deff86d5c5ecb 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1521,6 +1521,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { if (mapTypeMod == "delete") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + if (mapTypeMod == "return_param") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + return success(); }; @@ -1583,6 +1586,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, emitAllocRelease = false; mapTypeStrs.push_back("delete"); } + if (mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) { + emitAllocRelease = false; + mapTypeStrs.push_back("return_param"); + } if (emitAllocRelease) mapTypeStrs.push_back("exit_release_or_enter_alloc"); @@ -1777,6 +1786,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) { // MapInfoOp //===----------------------------------------------------------------------===// +static LogicalResult verifyMapInfoDefinedArgs(Operation *op, + StringRef clauseName, + OperandRange vars) { + for (Value var : vars) + if (!llvm::isa_and_present(var.getDefiningOp())) + return op->emitOpError() + << "'" << clauseName + << "' arguments must be defined by 'omp.map.info' ops"; + return success(); +} + LogicalResult MapInfoOp::verify() { if (getMapperId() && !SymbolTable::lookupNearestSymbolFrom( @@ -1784,6 +1804,9 @@ LogicalResult MapInfoOp::verify() { return emitError("invalid mapper id"); } + if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers()))) + return failure(); + return success(); } @@ -1805,6 +1828,15 @@ LogicalResult TargetDataOp::verify() { "At least one of map, use_device_ptr_vars, or " "use_device_addr_vars operand must be present"); } + + if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr", + getUseDevicePtrVars()))) + return failure(); + + if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr", + getUseDeviceAddrVars()))) + return failure(); + return verifyMapClause(*this, getMapVars()); } @@ -1889,16 +1921,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, } LogicalResult TargetOp::verify() { - LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); - - if (failed(verifyDependVars)) - return verifyDependVars; + if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) + return failure(); - LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars()); + if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr", + getHasDeviceAddrVars()))) + return failure(); - if (failed(verifyMapVars)) - return verifyMapVars; + if (failed(verifyMapClause(*this, getMapVars()))) + return failure(); return verifyPrivateVarsMapping(*this); } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index b7e16b7ec35e2..a9e4af035dbd7 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref, tensor) map_clauses(always, from) capture(ByRef) -> memref {name = ""} omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref){} - // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref) + // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} + // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref) %mapv2 = omp.map.info var_ptr(%map1 : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - omp.target_data map_entries(%mapv2 : memref) use_device_addr(%device_addr -> %arg0 : memref) use_device_ptr(%device_ptr -> %arg1 : memref) { + %device_addrv1 = omp.map.info var_ptr(%device_addr : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv2 : memref) use_device_addr(%device_addrv1 -> %arg0 : memref) use_device_ptr(%device_ptrv1 -> %arg1 : memref) { omp.terminator }