diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 4edea86b417c3..5ba93fefab3f9 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -3928,6 +3928,7 @@ class FIRToLLVMLowering mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, pattern); + mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern); // Math operations that have not been converted yet must be converted // to Libm. if (!isAMDGCN) diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 3ad70e7279692..123d114ae1635 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation to lower from the `toy` dialect, is the diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h index b88c1e8b20f32..88f18022da9bb 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -29,6 +29,10 @@ namespace cf { /// Collect the patterns to convert from the ControlFlow dialect to LLVM. The /// conversion patterns capture the LLVMTypeConverter by reference meaning the /// references have to remain alive during the entire pattern lifetime. 
+/// +/// Note: This function does not populate the default cf.assert lowering. That +/// is because some platforms have a custom cf.assert lowering. The default +/// lowering can be populated with `populateAssertToLLVMConversionPattern`. void populateControlFlowToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 8672e7b849d9d..d0ffb94f3f96a 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< - AssertOpLowering, BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); @@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM LLVMTypeConverter converter(ctx, options); RewritePatternSet patterns(ctx); mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface RewritePatternSet &patterns) const final { mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); } }; } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index b3c3fd4956d0b..544fc57949e24 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -19,6 +19,59 @@ using namespace mlir; +LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, + Location loc, OpBuilder &b, + StringRef name, + 
LLVM::LLVMFunctionType type) { + LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.template lookupSymbol(name))) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + ret = b.create(loc, name, type, LLVM::Linkage::External); + } + return ret; +} + +static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, + StringRef prefix) { + // Get a unique global name. + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (prefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + return stringConstName; +} + +LLVM::GlobalOp +mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, + gpu::GPUModuleOp moduleOp, Type llvmI8, + StringRef namePrefix, StringRef str, + uint64_t alignment, unsigned addrSpace) { + llvm::SmallString<20> nullTermStr(str); + nullTermStr.push_back('\0'); // Null terminate for C + auto globalType = + LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); + StringAttr attr = b.getStringAttr(nullTermStr); + + // Try to find existing global. + for (auto globalOp : moduleOp.getOps()) + if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && + globalOp.getValueAttr() == attr && + globalOp.getAlignment().value_or(0) == alignment && + globalOp.getAddrSpace() == addrSpace) + return globalOp; + + // Not found: create new global. 
+ OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); + return b.create(loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); +} + LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, return success(); } -static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { - const char formatStringPrefix[] = "printfFormat_"; - // Get a unique global name. - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); - return stringConstName; -} - -/// Create an global that contains the given format string. If a global with -/// the same format string exists already in the module, return that global. -static LLVM::GlobalOp getOrCreateFormatStringConstant( - OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8, - StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) { - llvm::SmallString<20> formatString(str); - formatString.push_back('\0'); // Null terminate for C - auto globalType = - LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); - StringAttr attr = b.getStringAttr(formatString); - - // Try to find existing global. - for (auto globalOp : moduleOp.getOps()) - if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && - globalOp.getValueAttr() == attr && - globalOp.getAlignment().value_or(0) == alignment && - globalOp.getAddrSpace() == addrSpace) - return globalOp; - - // Not found: create new global. 
- OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); - SmallString<16> name = getUniqueFormatGlobalName(moduleOp); - return b.create(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); -} - -template -static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, - LLVM::LLVMFunctionType type) { - LLVM::LLVMFuncOp ret; - if (!(ret = moduleOp.template lookupSymbol(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.create(loc, name, type, - LLVM::Linkage::External); - } - return ret; -} - LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() Value globalPtr = rewriter.create( @@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); // Create the global op or find an existing one. 
- LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0, - addressSpace); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), + /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element Value globalPtr = rewriter.create( @@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element Value globalPtr = rewriter.create(loc, global); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 444a07a93ca36..e73a74845d2b6 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -14,6 +14,27 @@ namespace mlir { +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Find or create an external function declaration in the given module. +LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, + OpBuilder &b, StringRef name, + LLVM::LLVMFunctionType type); + +/// Create a global that contains the given string. If a global with the same +/// string already exists in the module, return that global. 
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, + gpu::GPUModuleOp moduleOp, Type llvmI8, + StringRef namePrefix, StringRef str, + uint64_t alignment = 0, + unsigned addrSpace = 0); + +//===----------------------------------------------------------------------===// +// Lowering Patterns +//===----------------------------------------------------------------------===// + /// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first /// create a 0-sized global array symbol similar as LLVM expects. It constructs /// a memref descriptor with these values and return it. diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index e022d3ce6f636..2768929f460e2 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -25,6 +25,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { } }; +/// Lowering of cf.assert into a conditional __assertfail. 
+struct AssertOpToAssertfailLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = rewriter.getContext(); + Location loc = assertOp.getLoc(); + Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8)); + Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32)); + Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64)); + Type ptrType = LLVM::LLVMPointerType::get(ctx); + Type voidType = LLVM::LLVMVoidType::get(ctx); + + // Find or create __assertfail function declaration. + auto moduleOp = assertOp->getParentOfType(); + auto assertfailType = LLVM::LLVMFunctionType::get( + voidType, {ptrType, ptrType, i32Type, ptrType, i64Type}); + LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction( + moduleOp, loc, rewriter, "__assertfail", assertfailType); + assertfailDecl.setPassthroughAttr( + ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn"))); + + // Split blocks and insert conditional branch. + // ^before: + // ... + // cf.cond_br %condition, ^after, ^assert + // ^assert: + // cf.assert + // cf.br ^after + // ^after: + // ... + Block *beforeBlock = assertOp->getBlock(); + Block *assertBlock = + rewriter.splitBlock(beforeBlock, assertOp->getIterator()); + Block *afterBlock = + rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); + rewriter.setInsertionPointToEnd(beforeBlock); + rewriter.create(loc, adaptor.getArg(), afterBlock, + assertBlock); + rewriter.setInsertionPointToEnd(assertBlock); + rewriter.create(loc, afterBlock); + + // Continue cf.assert lowering. + rewriter.setInsertionPoint(assertOp); + + // Populate file name, line number and function name from the location of + // the AssertOp. 
+ StringRef fileName = "(unknown)"; + StringRef funcName = "(unknown)"; + int32_t fileLine = 0; + while (auto callSiteLoc = dyn_cast(loc)) + loc = callSiteLoc.getCallee(); + if (auto fileLineColLoc = dyn_cast(loc)) { + fileName = fileLineColLoc.getFilename().strref(); + fileLine = fileLineColLoc.getStartLine(); + } else if (auto nameLoc = dyn_cast(loc)) { + funcName = nameLoc.getName().strref(); + if (auto fileLineColLoc = + dyn_cast(nameLoc.getChildLoc())) { + fileName = fileLineColLoc.getFilename().strref(); + fileLine = fileLineColLoc.getStartLine(); + } + } + + // Create constants. + auto getGlobal = [&](LLVM::GlobalOp global) { + // Get a pointer to the string constant's first element. + Value globalPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), + global.getSymNameAttr()); + Value start = + rewriter.create(loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef{0, 0}); + return start; + }; + Value assertMessage = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg())); + Value assertFile = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_file_", fileName)); + Value assertFunc = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); + Value assertLine = + rewriter.create(loc, i32Type, fileLine); + Value c1 = rewriter.create(loc, i64Type, 1); + + // Call __assertfail(message, file, line, function, 1). + SmallVector arguments{assertMessage, assertFile, assertLine, + assertFunc, c1}; + rewriter.replaceOpWithNewOp(assertOp, assertfailDecl, + arguments); + return success(); + } +}; + /// Import the GPU Ops to NVVM Patterns. 
#include "GPUToNVVM.cpp.inc" @@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns( using gpu::index_lowering::IndexKind; using gpu::index_lowering::IntrType; populateWithGenerated(patterns); - patterns.add(converter); + patterns.add( + converter); patterns.add< gpu::index_lowering::OpLowering>( diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index a1cefe289a696..afebded1c3ea4 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -296,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateMathToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); + cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 58fd3d565fce5..5d0003911bca8 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext()); arith::populateArithToLLVMConversionPatterns(converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + cf::populateAssertToLLVMConversionPattern(converter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir 
b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 748dfe8c68fc7..318f0f78efa5b 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -969,6 +969,35 @@ gpu.module @test_module_50 { } } +// CHECK-LABEL: gpu.module @test_module_51 +// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32} +// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32} +// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32} +// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]} +// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} { +// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]] +// CHECK: ^[[assert_block]]: +// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr +// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8> +// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr +// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8> +// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr +// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8> +// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32 +// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> () +// CHECK: llvm.br ^[[after_block]] +// CHECK: ^[[after_block]]: +// CHECK: llvm.return +// CHECK: } + +gpu.module @test_module_51 { + gpu.func 
@test_assert(%arg0: i1) kernel { + cf.assert %arg0, "assert message" + gpu.return + } +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) { %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir new file mode 100644 index 0000000000000..06a9c1ca0d114 --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/assert.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-cpu-runner \ +// RUN: --shared-libs=%mlir_cuda_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void 2>&1 \ +// RUN: | FileCheck %s + +// CHECK-DAG: thread 0: print after passing assertion +// CHECK-DAG: thread 1: print after passing assertion +// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed. +// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed. +// CHECK-NOT: print after failing assertion + +module attributes {gpu.container_module} { +gpu.module @kernels { +gpu.func @test_assert(%c0: i1, %c1: i1) kernel { + %0 = gpu.thread_id x + cf.assert %c1, "passing assertion" + gpu.printf "thread %lld: print after passing assertion\n" %0 : index + // Test callsite(callsite(name)) location. 
+ cf.assert %c0, "failing assertion" loc(callsite(callsite("callee_func_name"("callee_file.cc":7:9) at "caller_file.cc":10:8) at "caller2_file.cc":11:12)) + gpu.printf "thread %lld: print after failing assertion\n" %0 : index + gpu.return +} +} + +func.func @main() { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0_i1 = arith.constant 0 : i1 + %c1_i1 = arith.constant 1 : i1 + gpu.launch_func @kernels::@test_assert + blocks in (%c1, %c1, %c1) + threads in (%c2, %c1, %c1) + args(%c0_i1 : i1, %c1_i1 : i1) + return +} +}