From f7bd42b79c1f6f09147231abdf8fe96352cb4127 Mon Sep 17 00:00:00 2001
From: David Green
Date: Tue, 12 Nov 2024 10:11:13 +0000
Subject: [PATCH] [LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext))) patterns more reliably.

We would detect ext(mul(ext, ext)) patterns when looking up through the tree,
but not when looking down. This hopefully brings the cost model closer to the
vplan version, avoiding some asserts.
---
 .../Transforms/Vectorize/LoopVectorize.cpp | 15 +++++++-
 .../LoopVectorize/ARM/mve-reductions.ll    | 38 +++++++++----------
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f984390..568aeae2260f1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
   if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     RetI = RetI->user_back();
+  } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+             ((match(I, m_ZExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+              (match(I, m_SExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+             RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+    // This looks through ext(mul(ext, ext)), making sure that the extensions
+    // are the same sign.
+    RetI = RetI->user_back()->user_back();
   }
 
   // Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
     //
-    // For ARM, some of the instruction can folded into the reducion
+    // For ARM, some of the instructions can be folded into the reduction
     // instruction. So we need to mark all folded instructions free.
     // For example: We can fold reduce(mul(ext(A), ext(B))) into one
     // instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (IsZExtOrSExt(I->getOpcode())) {
+            ChainOpsAndOperands.insert(I);
+            I = dyn_cast<Instruction>(I->getOperand(0));
+          }
           if (I->getOpcode() == Instruction::Mul) {
             auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
             auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896..a4f96adccb64b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT: [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT: [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
 ; CHECK: vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT: [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
 ; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT: [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT: [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT: [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT: [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT: [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT: [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT: [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK: middle.block:
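
Note: for reference, the reduce(ext(mul(ext, ext))) chain that the cost-model change above looks for typically comes from a scalar loop of roughly the shape sketched below. This is only an illustrative sketch; the function name and signature are assumptions and are not taken from the patch or its tests, but the extension/multiply structure mirrors the chain checked in the updated test_fir_q15 CHECK lines (sext i16 to i32, mul, sext i32 to i64, vector.reduce.add).

#include <stdint.h>

// Illustrative only: each i16 element is sign-extended to i32 for the
// multiply, and the i32 product is sign-extended to i64 before being
// accumulated, so vectorization produces
// reduce.add(sext(mul(sext(a), sext(b)))) with same-sign extends.
int64_t dot_product_q15(const int16_t *a, const int16_t *b, int n) {
  int64_t sum = 0;
  for (int i = 0; i < n; i++)
    sum += (int64_t)((int32_t)a[i] * (int32_t)b[i]);
  return sum;
}

As the comment updated in precomputeCosts says, on ARM such a chain can be folded into the reduction instruction, which is why the folded extends and multiply are marked free when precomputing costs.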