NOTE(review): this patch was recovered from a rendering that collapsed its line
structure and dropped angle-bracket argument lists. C++ template arguments and
MLIR type parameters below were reconstructed from the surrounding context;
confirm them against the upstream commit before applying.

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 082f2b15512b8..75546c600f733 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -24,6 +24,7 @@
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Support/Utils.h"
 #include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
 #include "flang/Runtime/allocator-registry-consts.h"
 #include "flang/Runtime/descriptor-consts.h"
 #include "flang/Semantics/runtime-type-info.h"
@@ -1135,6 +1136,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
   return result;
 }
 
+static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
+    auto fn = flc.getFilename().str() + '\0';
+    std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
+
+    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    }
+
+    auto crtInsPt = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
+        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
+    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
+        globalName, mlir::Attribute());
+
+    mlir::Region &region = globalOp.getInitializerRegion();
+    mlir::Block *block = rewriter.createBlock(&region);
+    rewriter.setInsertionPoint(block, block->begin());
+    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
+        loc, arrayTy, rewriter.getStringAttr(fn));
+    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
+    rewriter.restoreInsertionPoint(crtInsPt);
+    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
+                                                    globalOp.getName());
+  }
+  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
+}
+
+static mlir::Value genSourceLine(mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   flc.getLine());
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
+}
+
+static mlir::Value
+genCUFAllocDescriptor(mlir::Location loc,
+                      mlir::ConversionPatternRewriter &rewriter,
+                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
+                      const fir::LLVMTypeConverter &typeConverter) {
+  std::optional<mlir::DataLayout> dl =
+      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+  if (!dl)
+    mlir::emitError(mod.getLoc(),
+                    "module operation must carry a data layout attribute "
+                    "to generate llvm IR from FIR");
+
+  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
+  mlir::Value sourceLine = genSourceLine(loc, rewriter);
+
+  mlir::MLIRContext *ctx = mod.getContext();
+
+  mlir::LLVM::LLVMPointerType llvmPointerType =
+      mlir::LLVM::LLVMPointerType::get(ctx);
+  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
+  mlir::Type llvmIntPtrType =
+      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
+  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
+      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
+
+  auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
+      RTNAME_STRING(CUFAllocDesciptor));
+  auto funcFunc =
+      mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
+  if (!llvmFunc && !funcFunc)
+    mlir::OpBuilder::atBlockEnd(mod.getBody())
+        .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
+                                        fctTy);
+
+  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
+  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+  mlir::Value sizeInBytes =
+      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
+  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
+  return rewriter
+      .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
+                                  args)
+      .getResult();
+}
+
 /// Common base class for embox to descriptor conversion.
 template <typename OP>
 struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
@@ -1548,15 +1636,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
   mlir::Value
   placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
                                mlir::Location loc, mlir::Type boxTy,
-                               mlir::Value boxValue) const {
+                               mlir::Value boxValue,
+                               bool needDeviceAllocation = false) const {
     if (isInGlobalOp(rewriter))
       return boxValue;
     mlir::Type llvmBoxTy = boxValue.getType();
-    auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy,
-                                                     defaultAlign, rewriter);
-    auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
+    mlir::Value storage;
+    if (needDeviceAllocation) {
+      auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>();
+      auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy);
+      storage =
+          genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy());
+    } else {
+      storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign,
+                                                   rewriter);
+    }
+    auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage);
     this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
-    return alloca;
+    return storage;
   }
 };
 
@@ -1608,6 +1705,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
   }
 };
 
+static bool isDeviceAllocation(mlir::Value val) {
+  if (auto convertOp =
+          mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
+    val = convertOp.getValue();
+  if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
+    if (callOp.getCallee() &&
+        callOp.getCallee().value().getRootReference().getValue().starts_with(
+            RTNAME_STRING(CUFMemAlloc)))
+      return true;
+  return false;
+}
+
 /// Create a generic box on a memory reference.
 struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
   using EmboxCommonConversion::EmboxCommonConversion;
@@ -1791,9 +1900,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
     dest = insertBaseAddress(rewriter, loc, dest, base);
     if (fir::isDerivedTypeWithLenParams(boxTy))
       TODO(loc, "fir.embox codegen of derived with length parameters");
-
-    mlir::Value result =
-        placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
+    mlir::Value result = placeInMemoryIfNotGlobalInit(
+        rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
     rewriter.replaceOp(xbox, result);
     return mlir::success();
   }
@@ -2971,93 +3079,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
   }
 };
 
-static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
-                                 mlir::ConversionPatternRewriter &rewriter) {
-  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
-  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
-    auto fn = flc.getFilename().str() + '\0';
-    std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
-
-    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
-      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
-    } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
-      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
-    }
-
-    auto crtInsPt = rewriter.saveInsertionPoint();
-    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
-    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
-        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
-    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
-        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
-        globalName, mlir::Attribute());
-
-    mlir::Region &region = globalOp.getInitializerRegion();
-    mlir::Block *block = rewriter.createBlock(&region);
-    rewriter.setInsertionPoint(block, block->begin());
-    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
-        loc, arrayTy, rewriter.getStringAttr(fn));
-    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
-    rewriter.restoreInsertionPoint(crtInsPt);
-    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
-                                                    globalOp.getName());
-  }
-  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
-}
-
-static mlir::Value genSourceLine(mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter) {
-  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
-    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                   flc.getLine());
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
-}
-
-static mlir::Value
-genCUFAllocDescriptor(mlir::Location loc,
-                      mlir::ConversionPatternRewriter &rewriter,
-                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
-                      const fir::LLVMTypeConverter &typeConverter) {
-  std::optional<mlir::DataLayout> dl =
-      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
-  if (!dl)
-    mlir::emitError(mod.getLoc(),
-                    "module operation must carry a data layout attribute "
-                    "to generate llvm IR from FIR");
-
-  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
-  mlir::Value sourceLine = genSourceLine(loc, rewriter);
-
-  mlir::MLIRContext *ctx = mod.getContext();
-
-  mlir::LLVM::LLVMPointerType llvmPointerType =
-      mlir::LLVM::LLVMPointerType::get(ctx);
-  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
-  mlir::Type llvmIntPtrType =
-      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
-  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
-      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
-
-  auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
-      RTNAME_STRING(CUFAllocDesciptor));
-  auto funcFunc =
-      mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
-  if (!llvmFunc && !funcFunc)
-    mlir::OpBuilder::atBlockEnd(mod.getBody())
-        .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
-                                        fctTy);
-
-  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
-  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
-  mlir::Value sizeInBytes =
-      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
-  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
-  return rewriter
-      .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
-                                  args)
-      .getResult();
-}
-
 /// `fir.load` --> `llvm.load`
 struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
   using FIROpConversion::FIROpConversion;
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index 55e473ef2549e..a34c2770c5f6c 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -1,7 +1,6 @@
 // RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
-
   func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
     %c0 = arith.constant 0 : index
     %0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
@@ -27,3 +26,33 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
   }
   func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
 }
+
+// -----
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
+  func.func @_QQmain() attributes {fir.bindc_name = "test"} {
+    %c10 = arith.constant 10 : index
+    %c20 = arith.constant 20 : index
+    %0 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c4 = arith.constant 4 : index
+    %c200 = arith.constant 200 : index
+    %1 = arith.muli %c200, %c4 : index
+    %c6_i32 = arith.constant 6 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %2 = fir.convert %1 : (index) -> i64
+    %3 = fir.convert %0 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %4 = fir.call @_FortranACUFMemAlloc(%2, %c0_i32, %3, %c6_i32) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+    %5 = fir.convert %4 : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10x20xi32>>
+    %6 = fircg.ext_embox %5(%c10, %c20) : (!fir.ref<!fir.array<10x20xi32>>, index, index) -> !fir.box<!fir.array<10x20xi32>>
+    return
+  }
+  fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
+    %0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
+    fir.has_value %0 : !fir.char<1,11>
+  }
+  func.func private @_FortranACUFMemAlloc(i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8> attributes {fir.runtime}
+}
+
+// CHECK-LABEL: llvm.func @_QQmain()
+// CHECK: llvm.call @_FortranACUFMemAlloc
+// CHECK: llvm.call @_FortranACUFAllocDesciptor