From 5ad1d28725c931590083cb46f365e98205fef35f Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 5 May 2025 14:03:46 +0200 Subject: [PATCH 1/3] [mlir][amx] Simplify intrinsic generation Replaces separate amx named intrinsic operations with direct calls to LLVM intrinsic functions. The existing amx tests are updated and expanded. The separate conversion step translating amx intrinsics into LLVM IR is eliminated. Instead, this step is now performed by the existing llvm dialect infrastructure. Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581 --- mlir/include/mlir/Dialect/AMX/AMX.td | 157 ++++++------ mlir/include/mlir/Dialect/AMX/AMXDialect.h | 4 + .../include/mlir/Dialect/AMX/AMXInterfaces.td | 31 +++ mlir/include/mlir/Dialect/AMX/CMakeLists.txt | 5 +- mlir/include/mlir/Dialect/AMX/Transforms.h | 3 - mlir/include/mlir/InitAllExtensions.h | 2 - .../Dialect/AMX/AMXToLLVMIRTranslation.h | 31 --- mlir/include/mlir/Target/LLVMIR/Dialect/All.h | 2 - mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 190 ++++++++++++++- mlir/lib/Dialect/AMX/IR/CMakeLists.txt | 1 + .../lib/Dialect/AMX/Transforms/CMakeLists.txt | 3 - .../AMX/Transforms/LegalizeForLLVMExport.cpp | 224 ++---------------- mlir/lib/Target/LLVMIR/CMakeLists.txt | 1 - .../Dialect/AMX/AMXToLLVMIRTranslation.cpp | 56 ----- .../Target/LLVMIR/Dialect/AMX/CMakeLists.txt | 16 -- mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1 - mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 54 ++--- mlir/test/Target/LLVMIR/amx.mlir | 97 +++++++- 18 files changed, 432 insertions(+), 446 deletions(-) create mode 100644 mlir/include/mlir/Dialect/AMX/AMXInterfaces.td delete mode 100644 mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h delete mode 100644 mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp delete mode 100644 mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 8a51df1ea183f..a484f2ca009a2 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -25,10 +25,11 @@ // //===----------------------------------------------------------------------===// -#ifndef AMX -#define AMX +#ifndef AMX_OPS +#define AMX_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/AMX/AMXInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" @@ -47,8 +48,6 @@ def AMX_Dialect : Dialect { This `AMX` dialect provides a bridge between MLIR concepts such as vectors and memrefs and the lower level LLVM IR support of AMX. - The dialect is split into user-facing AMX ops (AMX_Op) and - backend-facing intrinsic ops (AMX_IntrOp). Note that since configuration changes (implicit at dialect level) are costly, it is highly recommended to use the AMX dialect on same-shaped @@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>; class AMX_Op traits = []> : Op {} -// The "internal" intrinsics are meant for compiler usage. -class AMX_IntrOp traits = []> : - LLVM_IntrOpBase; - //===----------------------------------------------------------------------===// -// AMX Op definitions (user facing). +// AMX Op definitions //===----------------------------------------------------------------------===// // // Tile reset. // -def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { +def TileZeroOp : AMX_Op<"tile_zero", [Pure, + AMXIntrinsicOpInterface + ]> { let summary = "tile zero operation"; let description = [{ Zeroes the destination tile, with the shape defined by the 2-dim @@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tilezero.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "attr-dict `:` qualified(type($res))"; let hasVerifier = 1; @@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { // Tile memory operations. // -def TileLoadOp : AMX_Op<"tile_load", [Pure]> { +def TileLoadOp : AMX_Op<"tile_load", [Pure, + AMXIntrinsicOpInterface + ]> { let summary = "tile load operation"; let description = [{ Loads a tile from memory defined by a base and indices, with the @@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> { TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tileloadd64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " "type($base) `into` qualified(type($res))"; let hasVerifier = 1; } -def TileStoreOp : AMX_Op<"tile_store"> { +def TileStoreOp : AMX_Op<"tile_store", [ + AMXIntrinsicOpInterface + ]> { let summary = "tile store operation"; let description = [{ Stores a tile to memory defined by a base and indices, with the @@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> { TileType getTileType() { return ::llvm::cast(getVal().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tilestored64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " "type($base) `,` qualified(type($val))"; @@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> { // Tile arithmetic operations. // -def TileMulFOp : AMX_Op<"tile_mulf", [ - Pure, AllTypesMatch<["acc", "res"]>]> { +def TileMulFOp : AMX_Op<"tile_mulf", [Pure, + AMXIntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { let summary = "tile multiplication operation (floating-point)"; let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results @@ -270,6 +295,19 @@ def TileMulFOp : AMX_Op<"tile_mulf", [ TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdp"; + auto elementType = + getLhsTileType().getElementType(); + intr += elementType.isF16() ? "fp16" : "bf16"; + intr += "ps.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " "qualified(type($lhs)) `,` qualified(type($rhs))" @@ -277,8 +315,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [ let hasVerifier = 1; } -def TileMulIOp : AMX_Op<"tile_muli", [ - Pure, AllTypesMatch<["acc", "res"]>]> { +def TileMulIOp : AMX_Op<"tile_muli", [Pure, + AMXIntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { let summary = "tile multiplication operation (integer)"; let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results @@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [ TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdpb"; + intr += getIsZextLhs() ? "u" : "s"; + intr += getIsZextRhs() ? "u" : "s"; + intr += "d.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) "; let hasVerifier = 1; } -//===----------------------------------------------------------------------===// -// AMX IntrOp definitions (LLVM compiler facing). -//===----------------------------------------------------------------------===// - -// -// Tile reset. Parameters define the tile size. -// - -def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>, - Arguments<(ins AnyInteger, AnyInteger)>; - -// -// Tile memory operations. Parameters define the tile size, -// base address, and stride between consecutive rows for the -// memory operation. -// - -def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>, - Arguments<(ins AnyInteger, - AnyInteger, LLVM_AnyPointer, AnyInteger)>; - -def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>, - Arguments<(ins AnyInteger, - AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>; - -// -// Tile multiplication operations (series of dot products). Parameters -// define the tile sizes and source and destination tiles for the -// operation. Note that the prefix "tdp" stands for tile dot product. -// - -// Dot product of bf16 tiles into f32 tile. -def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of f16 tiles into f32 tile. -def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with sign/sign extension). -def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with sign/zero extension). -def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with zero/sign extension). -def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with zero/zero extension). -def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -#endif // AMX +#endif // AMX_OPS diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h index c0553ad8733fd..c79f31d4c994a 100644 --- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -14,11 +14,15 @@ #define MLIR_DIALECT_AMX_AMXDIALECT_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +/// Include the generated interface declarations. +#include "mlir/Dialect/AMX/AMXInterfaces.h.inc" + #include "mlir/Dialect/AMX/AMXDialect.h.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td new file mode 100644 index 0000000000000..012d1ba7368f7 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td @@ -0,0 +1,31 @@ +//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines interfaces for the AMX dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef AMX_INTERFACES +#define AMX_INTERFACES + +include "mlir/IR/Interfaces.td" +include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" + +//===----------------------------------------------------------------------===// +// AMX Intrinsic Interface +//===----------------------------------------------------------------------===// + +def AMXIntrinsicOpInterface + : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> { + let description = [{ + A wrapper interface for operations representing AMX LLVM intrinsics. + }]; + let cppNamespace = "::mlir::amx"; +} + +#endif // AMX_INTERFACES diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt index f3f1aff5a6360..f875c78d240cc 100644 --- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_dialect(AMX amx) add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx) -set(LLVM_TARGET_DEFINITIONS AMX.td) -mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRAMXConversionsIncGen) +add_mlir_interface(AMXInterfaces) +add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h index 7391ec2ff6b14..4a751d99ceeee 100644 --- a/mlir/include/mlir/Dialect/AMX/Transforms.h +++ b/mlir/include/mlir/Dialect/AMX/Transforms.h @@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, /// intrinsics. void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); -/// Register LLVM conversion interface for AMX dialect. -void registerConvertAMXToLLVMInterface(DialectRegistry ®istry); - } // namespace mlir #endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 37e4904cb48ed..1e3f7c649a8bd 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,7 +32,6 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" @@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); - registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h deleted file mode 100644 index 4525ec3212196..0000000000000 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h +++ /dev/null @@ -1,31 +0,0 @@ -//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This provides registration calls for AMX dialect to LLVM IR translation. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H -#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H - -namespace mlir { - -class DialectRegistry; -class MLIRContext; - -/// Register the AMX dialect and the translation from it to the LLVM IR -/// in the given registry; -void registerAMXDialectTranslation(DialectRegistry ®istry); - -/// Register the AMX dialect and the translation from it in the registry -/// associated with the given context. -void registerAMXDialectTranslation(MLIRContext &context); - -} // namespace mlir - -#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index e043ff2f6825c..60615cf601655 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -14,7 +14,6 @@ #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H -#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" @@ -37,7 +36,6 @@ class DialectRegistry; /// corresponding translation interfaces. static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerArmNeonDialectTranslation(registry); - registerAMXDialectTranslation(registry); registerArmSMEDialectTranslation(registry); registerArmSVEDialectTranslation(registry); registerBuiltinDialectTranslation(registry); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 829f48e223383..69f524e1c311d 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -21,6 +23,8 @@ using namespace mlir; +#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc" + #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" void amx::AMXDialect::initialize() { @@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, return success(); } +/// Get pointer to a memref descriptor. +/// Optionally, the base pointer can be offset using linearized index computed +/// from the given indices. +static Value getBufferPtr(Location loc, MemRefType type, Value buffer, + ValueRange indices, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto [strides, offset] = type.getStridesAndOffset(); + + MemRefDescriptor memRefDescriptor(buffer); + Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type); + + int numIndices = indices.size(); + if (numIndices == 0) + return base; + + assert(type.getRank() == numIndices && + "expects number of indices equal to memref rank"); + Value index; + Type indexType = typeConverter.getIndexType(); + for (int i = 0; i < numIndices; ++i) { + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = + ShapedType::isDynamic(strides[i]) + ? memRefDescriptor.stride(rewriter, loc, i) + : rewriter.create( + loc, indexType, rewriter.getIndexAttr(strides[i])); + increment = rewriter.create(loc, increment, stride); + } + index = + index ? rewriter.create(loc, index, increment) : increment; + } + + Type elementPtrType = memRefDescriptor.getElementPtrType(); + return rewriter.create( + loc, elementPtrType, typeConverter.convertType(type.getElementType()), + base, index); +} + +/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first +/// dimension directly translates into the number of rows of the tiles. +/// The second dimensions needs to be scaled by the number of bytes. +static SmallVector getTileSizes(Location loc, amx::TileType tType, + RewriterBase &rewriter) { + Type llvmInt16Type = rewriter.getIntegerType(16); + unsigned width = tType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); + auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); + return SmallVector{ + rewriter.create(loc, llvmInt16Type, mattr), + rewriter.create(loc, llvmInt16Type, nattr)}; +} + +/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer +/// shape may "envelop" the actual tile shape, and may be dynamically sized. +/// Returns failure if proper stride couldn't be found. +static Value getStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { + assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); + int64_t preLast = mType.getRank() - 2; + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned width = mType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto [strides, offset] = mType.getStridesAndOffset(); + if (strides[preLast] == ShapedType::kDynamic) { + // Dynamic stride needs code to compute the stride at runtime. + MemRefDescriptor memrefDescriptor(base); + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = rewriter.create(loc, llvmInt64Type, attr); + return rewriter + .create(loc, llvmInt64Type, scale, + memrefDescriptor.stride(rewriter, loc, preLast)) + .getResult(); + } + // Use direct constant for static stride. + auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); + return rewriter.create(loc, llvmInt64Type, attr) + .getResult(); +} + LogicalResult amx::TileZeroOp::verify() { return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileZeroOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + return getTileSizes(getLoc(), getTileType(), rewriter); +} + LogicalResult amx::TileLoadOp::verify() { - unsigned rank = getMemRefType().getRank(); + MemRefType memrefTy = getMemRefType(); + unsigned rank = memrefTy.getRank(); + if (rank < 2) + return emitOpError("requires at least 2D memref"); if (getIndices().size() != rank) return emitOpError("requires ") << rank << " indices"; + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return emitOpError("requires memref with unit innermost stride"); return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + getBufferPtr(loc, getMemRefType(), adaptor.getBase(), + adaptor.getIndices(), typeConverter, rewriter)); + intrinsicOperands.push_back( + getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + + return intrinsicOperands; +} + LogicalResult amx::TileStoreOp::verify() { - unsigned rank = getMemRefType().getRank(); + MemRefType memrefTy = getMemRefType(); + unsigned rank = memrefTy.getRank(); + if (rank < 2) + return emitOpError("requires at least 2D memref"); if (getIndices().size() != rank) return emitOpError("requires ") << rank << " indices"; + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return emitOpError("requires memref with unit innermost stride"); return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + getBufferPtr(loc, getMemRefType(), adaptor.getBase(), + adaptor.getIndices(), typeConverter, rewriter)); + intrinsicOperands.push_back( + getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + intrinsicOperands.push_back(adaptor.getVal()); + + return intrinsicOperands; +} + LogicalResult amx::TileMulFOp::verify() { amx::TileType aType = getLhsTileType(); amx::TileType bType = getRhsTileType(); @@ -95,6 +243,25 @@ LogicalResult amx::TileMulFOp::verify() { return success(); } +SmallVector +amx::TileMulFOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + amx::TileType aType = getLhsTileType(); + amx::TileType bType = getRhsTileType(); + SmallVector tsza = getTileSizes(loc, aType, rewriter); + SmallVector tszb = getTileSizes(loc, bType, rewriter); + + SmallVector intrinsicOperands = {tsza[0], tszb[1], + tsza[1], adaptor.getAcc(), + adaptor.getLhs(), adaptor.getRhs()}; + + return intrinsicOperands; +} + LogicalResult amx::TileMulIOp::verify() { amx::TileType aType = getLhsTileType(); amx::TileType bType = getRhsTileType(); @@ -112,6 +279,25 @@ LogicalResult amx::TileMulIOp::verify() { return success(); } +SmallVector +amx::TileMulIOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + amx::TileType aType = getLhsTileType(); + amx::TileType bType = getRhsTileType(); + SmallVector tsza = getTileSizes(loc, aType, rewriter); + SmallVector tszb = getTileSizes(loc, bType, rewriter); + + SmallVector intrinsicOperands = {tsza[0], tszb[1], + tsza[1], adaptor.getAcc(), + adaptor.getLhs(), adaptor.getRhs()}; + + return intrinsicOperands; +} + Type amx::TileType::parse(AsmParser &parser) { if (parser.parseLess()) return nullptr; diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt index d109547b2438b..b6e2759843d5e 100644 --- a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRAMXDialect LINK_LIBS PUBLIC MLIRIR + MLIRLLVMCommonConversion MLIRLLVMDialect MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt index 29340d4f45dd1..e827bc475e930 100644 --- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt @@ -1,9 +1,6 @@ add_mlir_dialect_library(MLIRAMXTransforms LegalizeForLLVMExport.cpp - DEPENDS - MLIRAMXConversionsIncGen - LINK_LIBS PUBLIC MLIRAMXDialect MLIRIR diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 2168409184549..7471dc797e0fc 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -21,224 +21,42 @@ using namespace mlir::amx; namespace { -/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first -/// dimension directly translates into the number of rows of the tiles. -/// The second dimensions needs to be scaled by the number of bytes. -std::pair getTileSizes(ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &typeConverter, - amx::TileType tType, Location loc) { - Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); - unsigned width = tType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); - auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); - return std::make_pair( - rewriter.create(loc, llvmInt16Type, mattr), - rewriter.create(loc, llvmInt16Type, nattr)); -} - -/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer -/// shape may "envelop" the actual tile shape, and may be dynamically sized. -/// Returns failure if proper stride couldn't be found. -FailureOr getStride(ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &typeConverter, - MemRefType mType, Value base, Location loc) { - if (mType.getRank() < 2) - return failure(); - int64_t preLast = mType.getRank() - 2; - Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); - unsigned width = mType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - int64_t offset; - SmallVector strides; - if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1) - return failure(); - if (strides[preLast] == ShapedType::kDynamic) { - // Dynamic stride needs code to compute the stride at runtime. - MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = rewriter.create(loc, llvmInt64Type, attr); - return rewriter - .create(loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); - } - // Use direct constant for static stride. - auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); - return rewriter.create(loc, llvmInt64Type, attr) - .getResult(); -} - -struct TileZeroConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(tType); - rewriter.replaceOpWithNewOp(op, resType, tsz.first, - tsz.second); - return success(); - } -}; - -struct TileLoadConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType mType = op.getMemRefType(); - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Determine stride. - auto stride = getStride(rewriter, *getTypeConverter(), mType, - adaptor.getBase(), op.getLoc()); - if (failed(stride)) - return failure(); - // Replace operation with intrinsic. - Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, - adaptor.getBase(), adaptor.getIndices()); - Type resType = typeConverter->convertType(tType); - rewriter.replaceOpWithNewOp( - op, resType, tsz.first, tsz.second, ptr, stride.value()); - return success(); - } -}; - -struct TileStoreConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType mType = op.getMemRefType(); - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Determine stride. - auto stride = getStride(rewriter, *getTypeConverter(), mType, - adaptor.getBase(), op.getLoc()); - if (failed(stride)) - return failure(); - // Replace operation with intrinsic. - Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, - adaptor.getBase(), adaptor.getIndices()); - rewriter.replaceOpWithNewOp( - op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal()); - return success(); - } -}; +/// Generic one-to-one conversion of simply mappable operations into calls +/// to their respective LLVM intrinsics. +struct AMXIntrinsicOpConversion + : public OpInterfaceConversionPattern { + using OpInterfaceConversionPattern< + amx::AMXIntrinsicOp>::OpInterfaceConversionPattern; + + AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(), + benefit), + typeConverter(typeConverter) {} -struct TileMulFConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, + matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - amx::TileType aType = op.getLhsTileType(); - amx::TileType bType = op.getRhsTileType(); - amx::TileType cType = op.getTileType(); - // Determine m x n x k tile sizes. - std::pair tsza = - getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); - std::pair tszb = - getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(cType); - if (aType.getElementType().isBF16()) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (aType.getElementType().isF16()) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else - llvm_unreachable("Unexpected element type for amx.mulf"); - return success(); + return LLVM::detail::intrinsicRewrite( + op, rewriter.getStringAttr(op.getIntrinsicName()), + op.getIntrinsicOperands(operands, typeConverter, rewriter), + typeConverter, rewriter); } -}; -struct TileMulIConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - amx::TileType aType = op.getLhsTileType(); - amx::TileType bType = op.getRhsTileType(); - amx::TileType cType = op.getTileType(); - // Determine m x n x k tile sizes. - std::pair tsza = - getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); - std::pair tszb = - getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(cType); - bool zexta = op.getIsZextLhs(); - bool zextb = op.getIsZextRhs(); - if (zexta && zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (zexta && !zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (!zexta && zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - return success(); - } +private: + const LLVMTypeConverter &typeConverter; }; } // namespace void mlir::populateAMXLegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); + patterns.add(converter); converter.addConversion([&](amx::TileType type) { return LLVM::LLVMX86AMXType::get(&converter.getContext()); }); } void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { - target.addLegalOp(); - target.addIllegalOp(); -} - -namespace { -/// Implement the interface to convert AMX to LLVM. -struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { - using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; - - void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) const final { - populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); - } -}; -} // namespace - -void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { - dialect->addInterfaces(); - }); + target.addIllegalDialect(); } diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index 4ace3964e8ae0..af22a7ff04bf0 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -51,7 +51,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRArmNeonToLLVMIRTranslation MLIRArmSMEToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation - MLIRAMXToLLVMIRTranslation MLIRBuiltinToLLVMIRTranslation MLIRGPUToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp deleted file mode 100644 index 044462d33cfd1..0000000000000 --- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp +++ /dev/null @@ -1,56 +0,0 @@ -//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a translation between the AMX dialect and LLVM IR. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/IR/Operation.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" - -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsX86.h" - -using namespace mlir; -using namespace mlir::LLVM; - -namespace { -/// Implementation of the dialect interface that converts operations belonging -/// to the AMX dialect to LLVM IR. -class AMXDialectLLVMIRTranslationInterface - : public LLVMTranslationDialectInterface { -public: - using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - - /// Translates the given operation to LLVM IR using the provided IR builder - /// and saving the state in `moduleTranslation`. - LogicalResult - convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final { - Operation &opInst = *op; -#include "mlir/Dialect/AMX/AMXConversions.inc" - - return failure(); - } -}; -} // namespace - -void mlir::registerAMXDialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { - dialect->addInterfaces(); - }); -} - -void mlir::registerAMXDialectTranslation(MLIRContext &context) { - DialectRegistry registry; - registerAMXDialectTranslation(registry); - context.appendDialectRegistry(registry); -} diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt deleted file mode 100644 index 733b4c2e31b80..0000000000000 --- a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_translation_library(MLIRAMXToLLVMIRTranslation - AMXToLLVMIRTranslation.cpp - - DEPENDS - MLIRAMXConversionsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRAMXDialect - MLIRLLVMDialect - MLIRSupport - MLIRTargetLLVMIRExport - ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index 40df6e3f4b642..f030fa78942d5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,7 +1,6 @@ add_subdirectory(ArmNeon) add_subdirectory(ArmSME) add_subdirectory(ArmSVE) -add_subdirectory(AMX) add_subdirectory(Builtin) add_subdirectory(GPU) add_subdirectory(LLVMIR) diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir index 8085f5f59fcaf..7e562b00a46a9 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -1,17 +1,17 @@ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s // CHECK-LABEL: muli( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpbuud -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbssd -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbusd -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbsud -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbuud.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbssd.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbusd.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbsud.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @muli(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x64xi8> @@ -29,11 +29,11 @@ func.func @muli(%arg0: memref, %arg1: memref) { } // CHECK-LABEL: mulbf16( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpbf16ps -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulbf16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x32xbf16> @@ -45,11 +45,11 @@ func.func @mulbf16(%arg0: memref, %arg1: memref) { } // CHECK-LABEL: mulfp16( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpfp16ps -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpfp16ps.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulfp16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x32xf16> @@ -62,21 +62,21 @@ func.func @mulfp16(%arg0: memref, %arg1: memref) { // CHECK-LABEL: strides( // CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] // CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] // CHECK: llvm.mlir.constant(2 : i64) : i64 // CHECK: llvm.extractvalue %{{.+}}[4, 0] // CHECK: %[[STRIDE_1:.+]] = llvm.mul -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] // CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] // CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] // CHECK: llvm.mlir.constant(2 : i64) : i64 // CHECK: llvm.extractvalue %{{.+}}[4, 0] // CHECK: %[[STRIDE_2:.+]] = llvm.mul -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) { %0 = arith.constant 0 : index %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16> diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index 0281dfcd6ad69..094475040436d 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -1,13 +1,90 @@ -// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-translate --mlir-to-llvmir \ +// RUN: | FileCheck %s -// CHECK-LABEL: define void @target(ptr %0) -// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16) -// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, ptr %0, i64 32, x86_amx %[[c]] -llvm.func @target(%ptr: !llvm.ptr) { - %c = llvm.mlir.constant(16 : i16) : i16 - %s = llvm.mlir.constant(32 : i64) : i64 - %0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>> - "amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> () - llvm.return +// CHECK-LABEL: define void @amx_tile_zero +func.func @amx_tile_zero(%out: memref, %idx: index) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + // CHECK: call void @llvm.x86.tilestored64.internal + %zero = amx.tile_zero : !amx.tile<16x16xf32> + amx.tile_store %out[%idx, %idx], %zero : memref, !amx.tile<16x16xf32> + return } +// CHECK-LABEL: define void @amx_tile_load_store +func.func @amx_tile_load_store(%base: memref, %out: memref, + %idx: index) +{ + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + // CHECK: call void @llvm.x86.tilestored64.internal + %val = amx.tile_load %base[%idx, %idx] : memref into !amx.tile<16x64xi8> + amx.tile_store %out[%idx, %idx], %val : memref, !amx.tile<16x64xi8> + return +} + +// CHECK-LABEL: define void @amx_tile_mulf_bf16 +func.func @amx_tile_mulf_bf16( + %matA: memref, %matB: memref, %idx: index, + %out: memref) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %acc = amx.tile_zero : !amx.tile<16x16xf32> + // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x32xbf16> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x32xbf16> + // CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal + %tRes = amx.tile_mulf %tA, %tB, %acc + : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %tRes : memref, !amx.tile<16x16xf32> + return +} + +// CHECK-LABEL: define void @amx_tile_mulf_f16 +func.func @amx_tile_mulf_f16( + %matA: memref, %matB: memref, %idx: index, + %out: memref) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %acc = amx.tile_zero : !amx.tile<16x16xf32> + // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x32xf16> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x32xf16> + // CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal + %tRes = amx.tile_mulf %tA, %tB, %acc + : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32> + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %tRes : memref, !amx.tile<16x16xf32> + return +} + +// CHECK-LABEL: define void @amx_tile_muli +func.func @amx_tile_muli(%matA: memref, %matB: memref, + %matC: memref, %idx: index, %out: memref) +{ + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + // CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x64xi8> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x64xi8> + %acc = amx.tile_load %matC[%idx, %idx] : memref into !amx.tile<16x16xi32> + // CHECK: call x86_amx @llvm.x86.tdpbuud.internal + // CHECK: call x86_amx @llvm.x86.tdpbssd.internal + // CHECK: call x86_amx @llvm.x86.tdpbusd.internal + // CHECK: call x86_amx @llvm.x86.tdpbsud.internal + %res = amx.tile_muli %tA zext, %tB zext, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res1 = amx.tile_muli %tA, %tB, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res2 = amx.tile_muli %tA zext, %tB, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res3 = amx.tile_muli %tA, %tB zext, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%c0, %c0], %res : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c0, %c16], %res1 : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c16, %c0], %res2 : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c16, %c16], %res3 : memref, !amx.tile<16x16xi32> + return +} From 3c67d3a9529c80643fa97fdd749188438083b594 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 20 May 2025 13:04:48 +0200 Subject: [PATCH 2/3] Address comments --- .../mlir/Conversion/LLVMCommon/Pattern.h | 13 ++- mlir/include/mlir/Dialect/AMX/AMX.td | 6 +- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 94 ++++++++++--------- mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 49 +--------- 4 files changed, 69 insertions(+), 93 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 2bf9a021f48e1..011131a17fbcf 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -69,6 +69,15 @@ SmallVector decomposeValue(OpBuilder &builder, Location loc, Value src, /// function is used to combine multiple values into a single value. Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType); + +/// Performs the index computation to get to the element at `indices` of the +/// memory pointed to by `memRefDesc`, using the layout map of `type`. +/// The indices are linearized as: +/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. +Value getStridedElementPtr(OpBuilder &builder, Location loc, + const LLVMTypeConverter &converter, MemRefType type, + Value memRefDesc, ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none); } // namespace LLVM /// Base class for operation conversions targeting the LLVM IR dialect. It @@ -107,8 +116,8 @@ class ConvertToLLVMPattern : public ConversionPattern { static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value); - // This is a strided getElementPtr variant that linearizes subscripts as: - // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. + /// Convenience wrapper for the corresponding helper utility. + /// This is a strided getElementPtr variant with linearized subscripts. Value getStridedElementPtr( ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index a484f2ca009a2..6bbde43e2d011 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -25,8 +25,8 @@ // //===----------------------------------------------------------------------===// -#ifndef AMX_OPS -#define AMX_OPS +#ifndef AMX +#define AMX include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/AMX/AMXInterfaces.td" @@ -371,4 +371,4 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure, let hasVerifier = 1; } -#endif // AMX_OPS +#endif // AMX diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48fbcbcdbbde9..86fb9166b7223 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -62,49 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr( ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags) const { - - auto [strides, offset] = type.getStridesAndOffset(); - - MemRefDescriptor memRefDescriptor(memRefDesc); - // Use a canonical representation of the start address so that later - // optimizations have a longer sequence of instructions to CSE. - // If we don't do that we would sprinkle the memref.offset in various - // position of the different address computations. - Value base = - memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type); - - LLVM::IntegerOverflowFlags intOverflowFlags = - LLVM::IntegerOverflowFlags::none; - if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { - intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; - } - if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { - intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; - } - - Type indexType = getIndexType(); - Value index; - for (int i = 0, e = indices.size(); i < e; ++i) { - Value increment = indices[i]; - if (strides[i] != 1) { // Skip if stride is 1. - Value stride = - ShapedType::isDynamic(strides[i]) - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexAttrConstant(rewriter, loc, indexType, strides[i]); - increment = rewriter.create(loc, increment, stride, - intOverflowFlags); - } - index = index ? rewriter.create(loc, index, increment, - intOverflowFlags) - : increment; - } - - Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? rewriter.create( - loc, elementPtrType, - getTypeConverter()->convertType(type.getElementType()), - base, index, noWrapFlags) - : base; + return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type, + memRefDesc, indices, noWrapFlags); } // Check if the MemRefType `type` is supported by the lowering. We currently @@ -524,3 +483,52 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, return res; } + +Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, + const LLVMTypeConverter &converter, + MemRefType type, Value memRefDesc, + ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags) { + auto [strides, offset] = type.getStridesAndOffset(); + + MemRefDescriptor memRefDescriptor(memRefDesc); + // Use a canonical representation of the start address so that later + // optimizations have a longer sequence of instructions to CSE. + // If we don't do that we would sprinkle the memref.offset in various + // position of the different address computations. + Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type); + + LLVM::IntegerOverflowFlags intOverflowFlags = + LLVM::IntegerOverflowFlags::none; + if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { + intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; + } + if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { + intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; + } + + Type indexType = converter.getIndexType(); + Value index; + for (int i = 0, e = indices.size(); i < e; ++i) { + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = + ShapedType::isDynamic(strides[i]) + ? memRefDescriptor.stride(builder, loc, i) + : builder.create( + loc, indexType, builder.getIndexAttr(strides[i])); + increment = + builder.create(loc, increment, stride, intOverflowFlags); + } + index = index ? builder.create(loc, index, increment, + intOverflowFlags) + : increment; + } + + Type elementPtrType = memRefDescriptor.getElementPtrType(); + return index ? builder.create( + loc, elementPtrType, + converter.convertType(type.getElementType()), base, index, + noWrapFlags) + : base; +} diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 69f524e1c311d..12b375b373fa9 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -64,46 +64,6 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, return success(); } -/// Get pointer to a memref descriptor. -/// Optionally, the base pointer can be offset using linearized index computed -/// from the given indices. -static Value getBufferPtr(Location loc, MemRefType type, Value buffer, - ValueRange indices, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - auto [strides, offset] = type.getStridesAndOffset(); - - MemRefDescriptor memRefDescriptor(buffer); - Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type); - - int numIndices = indices.size(); - if (numIndices == 0) - return base; - - assert(type.getRank() == numIndices && - "expects number of indices equal to memref rank"); - Value index; - Type indexType = typeConverter.getIndexType(); - for (int i = 0; i < numIndices; ++i) { - Value increment = indices[i]; - if (strides[i] != 1) { // Skip if stride is 1. - Value stride = - ShapedType::isDynamic(strides[i]) - ? memRefDescriptor.stride(rewriter, loc, i) - : rewriter.create( - loc, indexType, rewriter.getIndexAttr(strides[i])); - increment = rewriter.create(loc, increment, stride); - } - index = - index ? rewriter.create(loc, index, increment) : increment; - } - - Type elementPtrType = memRefDescriptor.getElementPtrType(); - return rewriter.create( - loc, elementPtrType, typeConverter.convertType(type.getElementType()), - base, index); -} - /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first /// dimension directly translates into the number of rows of the tiles. /// The second dimensions needs to be scaled by the number of bytes. @@ -122,7 +82,6 @@ static SmallVector getTileSizes(Location loc, amx::TileType tType, /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -/// Returns failure if proper stride couldn't be found. static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); @@ -184,8 +143,8 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, SmallVector intrinsicOperands; intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); intrinsicOperands.push_back( - getBufferPtr(loc, getMemRefType(), adaptor.getBase(), - adaptor.getIndices(), typeConverter, rewriter)); + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); intrinsicOperands.push_back( getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); @@ -217,8 +176,8 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, SmallVector intrinsicOperands; intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); intrinsicOperands.push_back( - getBufferPtr(loc, getMemRefType(), adaptor.getBase(), - adaptor.getIndices(), typeConverter, rewriter)); + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); intrinsicOperands.push_back( getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); From a6593b0d0f01f2e8884ecbe3ec7ccf7338f4726f Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 21 May 2025 10:49:38 +0200 Subject: [PATCH 3/3] Fix formatting after rebase --- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 011131a17fbcf..7e946495e3e7f 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -74,10 +74,10 @@ Value composeValue(OpBuilder &builder, Location loc, ValueRange src, /// memory pointed to by `memRefDesc`, using the layout map of `type`. /// The indices are linearized as: /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. -Value getStridedElementPtr(OpBuilder &builder, Location loc, - const LLVMTypeConverter &converter, MemRefType type, - Value memRefDesc, ValueRange indices, - LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none); +Value getStridedElementPtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, + MemRefType type, Value memRefDesc, ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none); } // namespace LLVM /// Base class for operation conversions targeting the LLVM IR dialect. It