From e7e14efa6df888889b576027305d3a083ca5fd5b Mon Sep 17 00:00:00 2001 From: Yi Qian Date: Sat, 15 Mar 2025 05:07:21 +0000 Subject: [PATCH 1/3] [AMD][ROCDL][AMDGPU] Support packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 Add packed conversions fp8/bf8->bf16 in gfx950 and fp8/bf8->fp32 in gfx942 Update amdgpu.ext_packed_fp8 lowering to use ROCDL CvtPkF32Fp8Op --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 +- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 145 ++++++++++++------ .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 +- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 32 ++-- .../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 32 ++-- .../AMDGPUToROCDL/8-bit-floats.mlir | 32 ++-- .../ArithToAMDGPU/8-bit-floats-ocp.mlir | 90 +++++------ .../ArithToAMDGPU/8-bit-floats.mlir | 94 +++++------- mlir/test/Dialect/AMDGPU/ops.mlir | 6 +- mlir/test/Dialect/LLVMIR/rocdl.mlir | 8 + mlir/test/Target/LLVMIR/rocdl.mlir | 4 + 11 files changed, 251 insertions(+), 215 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 3acc383923ca8..3ed6e84d19044 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -85,11 +85,12 @@ def AMDGPU_ExtPackedFp8Op : AMDGPU_Op<"ext_packed_fp8", [Pure]>, Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN, VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source, - ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>, - Results<(outs F32:$res)> { - let summary = "Extend one of a vector of packed fp8 values to a float"; + ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>, + Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> { + let summary = "Extend a vector of packed fp8 values to two floats"; + let description = [{ - Extend the value `source[index]` to a 32-bit float and return it. + Extend the two 8-bit floats in `source[wordIndex]` to two 32-bit floats and return them. 
This rather unusual signature arises from the fact that AMD GPUs cannot easily work with sub 32-bit quantities, so the compiler intrinsics for @@ -97,11 +98,11 @@ def AMDGPU_ExtPackedFp8Op : this operation) take packed vectors of 4 such floats. If the passed-in vector has fewer than four elements, or the input is scalar, - the remaining values in the <4 x i8> will be filled with with + the remaining values in the <4 x i8> will be filled with undefined values as needed. }]; let assemblyFormat = [{ - attr-dict $source `[` $index `]` `:` type($source) `to` type($res) + attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res) }]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index f194e70ee275b..9a433202e3149 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz: }]; } -def ROCDL_CvtScaleF32PkFp8F16 : +def ROCDL_CvtScaleF32PkFp8F16Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert f16 to packed fp8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed fp8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. 
}]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF32PkFp8Bf16 : +def ROCDL_CvtScaleF32PkFp8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed bf16 to packed fp8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed fp8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. }]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) @@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 : } -def ROCDL_CvtScaleF32PkBf8F16 : +def ROCDL_CvtScaleF32PkBf8F16Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert f16 to packed bf8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed bf8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. }]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) @@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 : } -def ROCDL_CvtScaleF32PkBf8Bf16 : +def ROCDL_CvtScaleF32PkBf8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert bf16 to packed bf8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8. 
- Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed bf8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. }]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF32SrFp8F16 : +def ROCDL_CvtScaleF32SrFp8F16Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>, Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -749,13 +749,13 @@ }]; } -def ROCDL_CvtScaleF32SrBf8F16 : +def ROCDL_CvtScaleF32SrBf8F16Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>, Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f16 to packed bf8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. 
}]; let assemblyFormat = [{ @@ -763,13 +763,13 @@ }]; } -def ROCDL_CvtScaleF32SrFp8Bf16 : +def ROCDL_CvtScaleF32SrFp8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>, Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -777,13 +777,13 @@ }]; } -def ROCDL_CvtScaleF32SrBf8Bf16: +def ROCDL_CvtScaleF32SrBf8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>, Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -791,48 +791,74 @@ }]; } -def ROCDL_CvtScaleF32PkF16Fp8 : +def ROCDL_CvtScaleF32PkF16Fp8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert fp8 to packed f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to packed f16. 
+ let summary = "Convert fp8 to packed f16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed f16, then scale + the packed values by the exponent in `scale`. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScaleF32PkF16Bf8 : +def ROCDL_CvtScaleF32PkF16Bf8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert bf8 to packed f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to packed f16. + let summary = "convert bf8 to packed f16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed f16, then scale + the packed values by exponent in `scale`. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScaleF16Fp8 : +def ROCDL_CvtScaleF32PkBf16Fp8Op : + ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { + let summary = "Convert fp8 to packed bf16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed bf16, then scale + the packed values by the exponent in `scale`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) + }]; +} + +def ROCDL_CvtScaleF32PkBf16Bf8Op : + ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { + let summary = "Convert bf8 to packed bf16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed bf16, then scale + the packed values by the exponent in `scale`. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) + }]; +} + +def ROCDL_CvtScaleF16Fp8Op : ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { let summary = "Scale and convert fp8 to f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to f16 store into the `byteSel`th byte of `old`, preserving the others. + let description = [{ Convert `src` based on $wordSel to f16, then scale the value + by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, + preserving the others. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF16Bf8 : +def ROCDL_CvtScaleF16Bf8Op : ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { let summary = "Scale and convert fp8 to f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to f16 store into the `byteSel`th byte of `old`, preserving the others. + let description = [{ Convert `src` based on $wordSel to f16, then scale the value + by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, + preserving the others. 
}]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) @@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 : //===---------------------------------------------------------------------===// // 32-bit float intrinsics //===---------------------------------------------------------------------===// -def ROCDL_CvtScale32PkF32Fp8 : +def ROCDL_CvtScaleF32PkF32Fp8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed fp8 to packed f32"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp32. - Store the result in low/high word based on $wordSel, preserving the other word. + Convert `src` based on $wordSel to packed fp32, then scale the packed values by + the exponent in `scale`. Store the result in a vector. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScale32PkF32Bf8 : +def ROCDL_CvtScaleF32PkF32Bf8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed bf8 to packed f32"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp32. - Store the result in low/high word based on $wordSel, preserving the other word. + Convert `src` based on $wordSel to packed fp32, then scale the packed values by + the exponent in `scale`. Store the result in a vector. 
}]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) @@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 : //===---------------------------------------------------------------------===// // 8-bit float scale intrinsics //===---------------------------------------------------------------------===// -def ROCDL_CvtScaleF32PkFp8F32: +def ROCDL_CvtScaleF32PkFp8F32Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> { let summary = "Scale and convert two f32's to packed fp8"; @@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32: }]; } -def ROCDL_CvtScaleF32PkBf8F32: +def ROCDL_CvtScaleF32PkBf8F32Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert two f32's to packed bf8"; @@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32: }]; } -def ROCDL_CvtScaleF32SrFp8F32: +def ROCDL_CvtScaleF32SrFp8F32Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>, Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f32 to fp8 using stochastic rounding"; @@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32: } -def ROCDL_CvtScaleF32SrBf8F32: +def ROCDL_CvtScaleF32SrBf8F32Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>, Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f32 to bf8 using stochastic rounding"; @@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op : }]; } +def ROCDL_CvtPkF32Fp8Op : + ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, I1:$wordSel)> { + let summary = "Convert packed fp8 to packed f32"; + let description = [{ + Convert `src` based on $wordSel to packed fp32. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtPkF32Bf8Op : + ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, I1:$wordSel)> { + let summary = "Convert packed bf8 to packed f32"; + let description = [{ + Convert `src` based on $wordSel to packed fp32. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `:` type($res) + }]; +} def ROCDL_CvtPkBf8F32Op : ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>, diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 949424db7c4d6..768d21384412d 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -977,13 +977,13 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( source = longVec; } Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); - Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); + Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, - wordSel); + rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source, + wordSel); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, - wordSel); + rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source, + wordSel); } return success(); } diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 27be54728c1a1..f9b685d1e90f6 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) { return false; } -static Value castF32To(Type elementType, Value f32, Location loc, +static Value castF32To(Type desType, Value f32, Location loc, PatternRewriter 
&rewriter) { + Type elementType = getElementTypeOrSelf(desType); if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, elementType, f32); + return rewriter.create(loc, desType, f32); if (elementType.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, elementType, f32); + return rewriter.create(loc, desType, f32); llvm_unreachable("The only 32-bit float type is f32"); } @@ -110,10 +111,12 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); + VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { - Value asFloat = rewriter.create( - loc, rewriter.getF32Type(), in, 0); - Value result = castF32To(outElemType, asFloat, loc, rewriter); + Value asFloats = + rewriter.create(loc, extResType, in, 0); + Value resFloat = rewriter.create(loc, asFloats, 0); + Value result = castF32To(outElemType, resFloat, loc, rewriter); rewriter.replaceOp(op, result); return success(); } @@ -150,11 +153,18 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value inSlice = rewriter.create( loc, in, i, elemsThisOp, 1); - for (int64_t j = 0; j < elemsThisOp; ++j) { - Value asFloat = rewriter.create( - loc, rewriter.getF32Type(), inSlice, j); - Value asType = castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create(loc, asType, result, i + j); + for (int64_t j = 0; j < elemsThisOp; j += 2) { + Value asFloats = rewriter.create(loc, extResType, + inSlice, j / 2); + Type desType = VectorType::get(2, outElemType); + Value asType = castF32To(desType, asFloats, loc, rewriter); + if (i + j + 1 < numElements) + result = rewriter.create( + loc, asType, result, i + j, 1); + else { + asType = rewriter.create(loc, asType, 0); + result = rewriter.create(loc, asType, result, i + j); + } } } diff --git 
a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir index 70775a603e54d..0fb03ff13b558 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir @@ -7,12 +7,12 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_scalar(%v: f8E5M2) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32 - func.return %ret : f32 +func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @ext_short_vec @@ -25,24 +25,24 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32 - func.return %ret : f32 +func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { + %ret = 
amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @ext_full_vec( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 -// CHECK: return [[EXT]] : f32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: return [[EXT]] : vector<2xf32> -func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32 - func.return %ret : f32 +func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @packed_trunc diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir index a313aaffdf5cc..0a4a960d59ce8 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -6,12 +6,12 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32 - func.return %ret : f32 +func.func @ext_scalar(%v: 
f8E5M2FNUZ) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @ext_short_vec @@ -24,24 +24,24 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32 - func.return %ret : f32 +func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @ext_full_vec( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 -// CHECK: return [[EXT]] : f32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: return [[EXT]] : vector<2xf32> -func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32 - func.return %ret : f32 +func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32> + func.return 
%ret : vector<2xf32> } // CHECK-LABEL: func @packed_trunc diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir index 0e7f58c9e6749..b75b69c1b5d27 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir @@ -1,10 +1,11 @@ // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s - + // CHECK-LABEL: func.func @scalar_ext // CHECK-SAME: ([[V:%.+]]: f8E5M2) -// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32 -// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32> +// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32> +// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16 // CHECK: return [[W]] func.func @scalar_ext(%v: f8E5M2) -> f16 { %w = arith.extf %v : f8E5M2 to f16 @@ -17,14 +18,9 @@ func.func @scalar_ext(%v: f8E5M2) -> f16 { // CHECK-LABEL: func.func @vector_ext_short // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>) -// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32 -// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 -// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] -// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32 -// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] -// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] -// CHECK: return [[W1]] : vector<2xf64> +// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to vector<2xf32> +// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT0]] : vector<2xf32> to vector<2xf64> +// CHECK: return [[EXT]] : vector<2xf64> func.func 
@vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64> @@ -35,30 +31,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { // CHECK-LABEL: func.func @vector_ext_long // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>) -// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: return [[W8]] +// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> +// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FN> to 
vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32> +// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32> +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32> return %w : vector<9xf32> @@ -144,31 +132,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> { // CHECK-LABEL: func.func @vector_ext_long_2d // CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>) +// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> // CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN> // CHECK: [[V0:%.+]] = 
vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into 
vector<9xf32> // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32> +// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32> // CHECK: return [[CAST]] func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> { %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32> diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir index 6bb5b9771c015..2ed3f47e8ab73 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -2,8 +2,9 @@ // CHECK-LABEL: func.func @scalar_ext // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ) -// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32 -// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32> +// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32> +// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16 // CHECK: return [[W]] func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { %w = arith.extf %v : f8E5M2FNUZ to f16 @@ -16,8 +17,9 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector) -> vector // CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector // CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector -// CHECK: %[[CONVERT:.+]] 
= amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32 -// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector +// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32> +// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector // CHECK: return %[[RESULT]] : vector func.func @vector_zero_d(%v: vector) -> vector { %w = arith.extf %v : vector to vector @@ -28,15 +30,9 @@ func.func @vector_zero_d(%v: vector) -> vector { // CHECK-LABEL: func.func @vector_ext_short // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>) -// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32 -// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 -// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] -// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32 -// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] -// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] -// CHECK: return [[W1]] : vector<2xf64> - +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to vector<2xf32> +// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT]] : vector<2xf32> to vector<2xf64> +// CHECK: return [[EXT]] : vector<2xf64> func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64> return %w : vector<2xf64> @@ -46,30 +42,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { // CHECK-LABEL: func.func @vector_ext_long // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>) -// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert 
[[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: return [[W8]] +// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> +// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], 
strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32> +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32> return %w : vector<9xf32> @@ -155,31 +143,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> { // CHECK-LABEL: func.func @vector_ext_long_2d // CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>) +// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> // CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ> // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// 
CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32> 
+// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32> // CHECK: return [[CAST]] func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> { %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32> diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 567e6498330a3..bf312ead32712 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -6,9 +6,9 @@ // CHECK-LABEL: func @ext_packed_fp8 // CHECK: amdgpu.ext_packed_fp8 -func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32 - func.return %ret : f32 +func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32> + func.return %ret : vector<2xf32> } // CHECK-LABEL: func @packed_trunc_2xfp8 diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index bc917041998d8..cce2c0aee62f3 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -767,10 +767,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: rocdl.cvt.scalef32.f32.fp8 // CHECK: rocdl.cvt.scalef32.pk.f16.bf8 // CHECK: rocdl.cvt.scalef32.pk.f16.fp8 +// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8 +// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8 // CHECK: rocdl.cvt.scalef32.f16.fp8 // CHECK: rocdl.cvt.scalef32.f16.bf8 // CHECK: rocdl.cvt.pk.bf8.f32 // CHECK: rocdl.cvt.pk.fp8.f32 +// CHECK: rocdl.cvt.pk.f32.bf8 +// CHECK: rocdl.cvt.pk.f32.fp8 // CHECK: rocdl.cvt.sr.bf8.f32 // CHECK: rocdl.cvt.sr.fp8.f32 // CHECK: rocdl.cvt.scalef32.sr.fp8.f32 @@ -793,10 +797,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: 
bf %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32 %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16> %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16> + %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16> + %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16> %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 + %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32> + %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32> %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32 diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 11f2faa2761ff..e70617bfff99e 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1042,6 +1042,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false) // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false) +// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false) +// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call i32 
@llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false) // CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false) // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2) @@ -1068,6 +1070,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32 %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16 %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16 + %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32 + %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32 %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 From ae25d7357b40f38caa599af78f41db97516005b7 Mon Sep 17 00:00:00 2001 From: Yi Qian Date: Wed, 19 Mar 2025 04:45:36 +0000 Subject: [PATCH 2/3] Allow amdgpu.ext_packed_fp8 to return a scalar or vector type --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 11 ++-- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 26 +++++++--- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 23 +++++---- .../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 48 ++++++++++++++---- .../AMDGPUToROCDL/8-bit-floats.mlir | 48 ++++++++++++++---- .../ArithToAMDGPU/8-bit-floats-ocp.mlir | 45 ++++++++--------- .../ArithToAMDGPU/8-bit-floats.mlir | 50 +++++++++---------- 7 files changed, 159 insertions(+), 92 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 3ed6e84d19044..c0b3e5540b1df 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -85,12 +85,13 @@ def AMDGPU_ExtPackedFp8Op : AMDGPU_Op<"ext_packed_fp8", [Pure]>, Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, 
F8E4M3FNUZ, F8E5M2, F8E4M3FN, VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source, - ConfinedAttr]>:$wordIndex)>, - Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> { - let summary = "Extend a vector of packed fp8 values to two floats"; + ConfinedAttr]>:$index)>, + Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> { + let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats"; let description = [{ - Extend the two 8-bit floats in `source[wordrIndex]` to two 32-bit floats and return them. + Extend one or two 8-bit floats in `source[index]` to a 32-bit float or + two floats and return them. This rather unusual signature arises from the fact that AMD GPUs cannot easily work with sub 32-bit quantities, so the compiler intrinsics for @@ -102,7 +103,7 @@ def AMDGPU_ExtPackedFp8Op : undefined values as needed. }]; let assemblyFormat = [{ - attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res) + attr-dict $source `[` $index `]` `:` type($source) `to` type($res) }]; } diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 768d21384412d..3acd470cff7f5 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Value source = adaptor.getSource(); auto sourceVecType = dyn_cast(op.getSource().getType()); + auto resultVecType = dyn_cast(op.getResult().getType()); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { @@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( source = longVec; } Value i32Source = rewriter.create(loc, i32, source); - Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); - if 
(typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); - } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); + if (resultVecType) { + Value wordSel = createI1Constant(rewriter, loc, op.getIndex()); + if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } + } else { + Value byteSel = createI32Constant(rewriter, loc, op.getIndex()); + if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + byteSel); + } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + byteSel); + } } return success(); } diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index f9b685d1e90f6..3596b3235a631 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -113,10 +113,9 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Type outElemType = getElementTypeOrSelf(op.getOut().getType()); VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { - Value asFloats = - rewriter.create(loc, extResType, in, 0); - Value resFloat = rewriter.create(loc, asFloats, 0); - Value result = castF32To(outElemType, resFloat, loc, rewriter); + Value asFloat = rewriter.create( + loc, rewriter.getF32Type(), in, 0); + Value result = castF32To(outElemType, asFloat, loc, rewriter); rewriter.replaceOp(op, result); return success(); } @@ -154,15 +153,17 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Value inSlice = rewriter.create( loc, in, i, elemsThisOp, 1); for (int64_t j 
= 0; j < elemsThisOp; j += 2) { - Value asFloats = rewriter.create(loc, extResType, - inSlice, j / 2); - Type desType = VectorType::get(2, outElemType); - Value asType = castF32To(desType, asFloats, loc, rewriter); - if (i + j + 1 < numElements) + if (i + j + 1 < numElements) { // Convert two 8-bit elements + Value asFloats = rewriter.create( + loc, extResType, inSlice, j / 2); + Type desType = VectorType::get(2, outElemType); + Value asType = castF32To(desType, asFloats, loc, rewriter); result = rewriter.create( loc, asType, result, i + j, 1); - else { - asType = rewriter.create(loc, asType, 0); + } else { // Convert an 8-bit element + Value asFloat = rewriter.create( + loc, rewriter.getF32Type(), inSlice, j / 2 * 2); + Value asType = castF32To(outElemType, asFloat, loc, rewriter); result = rewriter.create(loc, asType, result, i + j); } } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir index 0fb03ff13b558..eb483b0880294 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir @@ -7,12 +7,12 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32> -// CHECK: return [[EXT]] -func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> { - %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32> - func.return %ret : vector<2xf32> +// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_scalar(%v: f8E5M2) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32 + func.return %ret :
f32 } // CHECK-LABEL: func @ext_short_vec @@ -25,22 +25,50 @@ func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_full_vec( +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FN> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_packed_2xfp8 +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8> +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8> +// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8> +// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8> +// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> +// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { +func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32> func.return %ret : vector<2xf32> } -// CHECK-LABEL: func @ext_full_vec( +// CHECK-LABEL: func @ext_packed_4xfp8 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 // CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> // CHECK: return [[EXT]] : vector<2xf32> - -func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { +func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32> func.return %ret : vector<2xf32> } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir index 0a4a960d59ce8..4029d14650d7f 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -6,12 +6,12 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32> -// CHECK: return [[EXT]] -func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> { - %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32> - func.return %ret : vector<2xf32> +// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : 
i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32 + func.return %ret : f32 } // CHECK-LABEL: func @ext_short_vec @@ -24,22 +24,50 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_full_vec +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 +// CHECK: return [[EXT]] : f32 +func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FNUZ> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_packed_2xfp8 +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8> +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8> +// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8> +// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8> +// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: 
[[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> +// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 // CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> // CHECK: return [[EXT]] -func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { +func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32> func.return %ret : vector<2xf32> } -// CHECK-LABEL: func @ext_full_vec( +// CHECK-LABEL: func @ext_packed_4xfp8( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 // CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> // CHECK: return [[EXT]] : vector<2xf32> - -func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { +func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32> func.return %ret : vector<2xf32> } diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir index b75b69c1b5d27..7fb5fbfe0c89e 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir @@ -3,9 +3,8 @@ // CHECK-LABEL: func.func @scalar_ext // CHECK-SAME: ([[V:%.+]]: f8E5M2) -// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32> -// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32> -// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16 +// CHECK: [[FLOAT:%.+]] = 
amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32 +// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 // CHECK: return [[W]] func.func @scalar_ext(%v: f8E5M2) -> f16 { %w = arith.extf %v : f8E5M2 to f16 @@ -43,9 +42,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { // CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32> // CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> // CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32> -// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32> -// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32> // CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32> @@ -131,28 +129,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> { // ----- // CHECK-LABEL: func.func @vector_ext_long_2d -// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>) -// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> -// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN> +// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FN>) +// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FN> to vector<11xf8E4M3FN> // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} // CHECK: [[F0:%.+]] = 
amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32> -// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32> // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32> -// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32> -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FN> to vector<4xf8E4M3FN> // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32> -// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32> // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32> -// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32> -// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32> -// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32> -// CHECK: 
[[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FN> to vector<3xf8E4M3FN> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FN> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FN> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32> // CHECK: return [[CAST]] -func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> { - %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32> - return %w : vector<1x9xf32> +func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FN>) -> vector<1x11xf32> { + %w = arith.extf %v : vector<1x11xf8E4M3FN> to vector<1x11xf32> + return %w : vector<1x11xf32> } diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir index 2ed3f47e8ab73..59ed6bd95ae8b 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -2,9 +2,8 @@ // CHECK-LABEL: func.func @scalar_ext // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ) -// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32> -// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32> -// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16 +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32 +// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 // CHECK: return [[W]] func.func 
@scalar_ext(%v: f8E5M2FNUZ) -> f16 { %w = arith.extf %v : f8E5M2FNUZ to f16 @@ -17,9 +16,8 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector) -> vector // CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector // CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector -// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32> -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32> -// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector +// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32 +// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector // CHECK: return %[[RESULT]] : vector func.func @vector_zero_d(%v: vector) -> vector { %w = arith.extf %v : vector to vector @@ -54,9 +52,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { // CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> // CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> // CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32> -// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32> // CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32> 
@@ -142,28 +139,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> { // ----- // CHECK-LABEL: func.func @vector_ext_long_2d -// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>) -// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> -// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ> +// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FNUZ>) +// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FNUZ> to vector<11xf8E4M3FNUZ> // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} // CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32> // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32> -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: 
[[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32> // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32> -// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32> -// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32> -// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<3xf8E4M3FNUZ> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FNUZ> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32> // CHECK: return [[CAST]] -func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> { - %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32> - return %w : vector<1x9xf32> +func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FNUZ>) -> vector<1x11xf32> { + %w = arith.extf %v : vector<1x11xf8E4M3FNUZ> to vector<1x11xf32> + return %w : 
vector<1x11xf32> } From 8ef233ffa33881bc401baca20237b18b77ec5010 Mon Sep 17 00:00:00 2001 From: Yi Qian Date: Wed, 19 Mar 2025 05:00:41 +0000 Subject: [PATCH 3/3] Update test cases --- .../Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir | 4 ++-- .../test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir | 4 ++-- mlir/test/Dialect/AMDGPU/ops.mlir | 13 ++++++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir index eb483b0880294..ea0c3afbd9021 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir @@ -36,11 +36,11 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { // CHECK-LABEL: func @ext_full_vec( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { - %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FN> to f32 + %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32 func.return %ret : f32 } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir index 4029d14650d7f..219f822ca9a1c 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -35,11 +35,11 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { // CHECK-LABEL: func @ext_full_vec // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: 
[[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
 // CHECK: return [[EXT]] : f32
 func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FNUZ> to f32
+  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
   func.return %ret : f32
 }
 
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index bf312ead32712..665674f2a7873 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -4,9 +4,16 @@
 // Verify the generic form can be parsed.
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
-// CHECK-LABEL: func @ext_packed_fp8
-// CHECK: amdgpu.ext_packed_fp8
-func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+// CHECK-LABEL: func @ext_packed_fp8_s
+// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to f32
+func.func @ext_packed_fp8_s(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_packed_fp8_v
+// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to vector<2xf32>
+func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
   %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
   func.return %ret : vector<2xf32>
 }