diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 3cd923c0ba058..9a77a2ccfc989 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -18,12 +18,14 @@ #include "WebAssemblySubtarget.h" #include "WebAssemblyTargetMachine.h" #include "WebAssemblyUtilities.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -184,6 +186,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // Combine partial.reduce.add before legalization gets confused. setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); + // Combine EXTRACT VECTOR ELT of AND(AND(X, SHUFFLE(X)), SHUFFLE(...)), 0 + // to all_true + setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); + // Combine wide-vector muls, with extend inputs, to extmul_half. setTargetDAGCombine(ISD::MUL); @@ -3287,6 +3293,87 @@ static SDValue performSETCCCombine(SDNode *N, return SDValue(); } +static SmallVector buildMaskArrayByPower(unsigned FromPower, + unsigned NumElements) { + // Generate 1-index array of elements from 2^Power to 2^(Power+1) exclusive + // The rest is filled with -1. + // + // For example, with NumElements = 4: + // When Power = 1: <1 -1 -1 -1> + // When Power = 2: <2 3 -1 -1> + // When Power = 4: <4 5 6 7> + assert(FromPower <= 256); + unsigned ToPower = NextPowerOf2(FromPower); + assert(FromPower < NumElements && ToPower <= NumElements); + + SmallVector Res; + for (unsigned I = FromPower; I < ToPower; I++) + Res.push_back(I); + Res.resize(NumElements, -1); + + return Res; +} +static SDValue matchAndOfShuffle(SDNode *N, int Power = 1) { + // Matching on the case of + // + // Base case: A [bitcast for a] setcc(v, <0>, ne). + // Recursive case: N = and(X, shuffle(X, power mask)) where X is either + // recursive or base case. + using namespace llvm::SDPatternMatch; + + EVT VT = N->getValueType(0); + + SDValue LHS = N->getOperand(0); + int NumElements = VT.getVectorNumElements(); + + if (NumElements < Power) + return SDValue(); + + if (N->getOpcode() != ISD::AND && NumElements == Power) { + SDValue BitCast, Matched; + + // Try for a setcc first. + if (sd_match(N, m_c_SetCC(m_Value(Matched), m_Zero(), + m_SpecificCondCode(ISD::SETNE)))) + return Matched; + + // Now try for bitcast + if (!sd_match(N, m_BitCast(m_Value(BitCast)))) + return SDValue(); + + if (!sd_match(BitCast, m_c_SetCC(m_Value(Matched), m_Zero(), + m_SpecificCondCode(ISD::SETNE)))) + return SDValue(); + return Matched; + } + + SmallVector PowerIndices = buildMaskArrayByPower(Power, NumElements); + if (sd_match(N, m_And(m_Value(LHS), + m_Shuffle(m_Value(LHS), m_VectorVT(m_Opc(ISD::POISON)), + m_SpecificMask(PowerIndices))))) + return matchAndOfShuffle(LHS.getNode(), NextPowerOf2(Power)); + + return SDValue(); +} +static SDValue performExtractVecEltCombine(SDNode *N, SelectionDAG &DAG) { + using namespace llvm::SDPatternMatch; + + assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT); + SDLoc DL(N); + + SDValue And; + if (!sd_match(N, m_ExtractElt(m_VectorVT(m_Value(And)), m_Zero()))) + return SDValue(); + + if (SDValue Matched = matchAndOfShuffle(And.getNode())) + return DAG.getZExtOrTrunc( + DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32, + {DAG.getConstant(Intrinsic::wasm_alltrue, DL, MVT::i32), Matched}), + DL, N->getValueType(0)); + + return SDValue(); +} static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) { assert(N->getOpcode() == ISD::MUL); @@ -3402,6 +3489,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, return performTruncateCombine(N, DCI); case ISD::INTRINSIC_WO_CHAIN: return performLowerPartialReduction(N, DCI.DAG); + case ISD::EXTRACT_VECTOR_ELT: + return performExtractVecEltCombine(N, DCI.DAG); case ISD::MUL: return performMulCombine(N, DCI.DAG); } diff --git a/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll b/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll new file mode 100644 index 0000000000000..f494691941b64 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll @@ -0,0 +1,47 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s +target triple = "wasm64" + +define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) { +; CHECK-LABEL: reduce_and_to_all_true_16i8: +; CHECK: .functype reduce_and_to_all_true_16i8 (v128) -> (i32) +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: i8x16.all_true $push0=, $0 +; CHECK-NEXT: return $pop0 + %2 = icmp ne <16 x i8> %0, zeroinitializer + %3 = sext <16 x i1> %2 to <16 x i8> + %4 = bitcast <16 x i8> %3 to <4 x i32> + %5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4) + %6 = icmp ne i32 %5, 0 + ret i1 %6 +} + + +define i1 @reduce_and_to_all_true_4i32(<4 x i32> %0) { +; CHECK-LABEL: reduce_and_to_all_true_4i32: +; CHECK: .functype reduce_and_to_all_true_4i32 (v128) -> (i32) +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: i32x4.all_true $push0=, $0 +; CHECK-NEXT: return $pop0 + %2 = icmp ne <4 x i32> %0, zeroinitializer + %3 = sext <4 x i1> %2 to <4 x i32> + %4 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %3) + %5 = icmp ne i32 %4, 0 + ret i1 %5 +} + + + +define i1 @reduce_and_to_all_true_2i64(<2 x i64> %0) { +; CHECK-LABEL: reduce_and_to_all_true_2i64: +; CHECK: .functype reduce_and_to_all_true_2i64 (v128) -> (i32) +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: i32x4.all_true $push0=, $0 +; CHECK-NEXT: return $pop0 + %2 = bitcast <2 x i64> %0 to <4 x i32> + %3 = icmp ne <4 x i32> %2, zeroinitializer + %4 = sext <4 x i1> %3 to <4 x i32> + %5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4) + %6 = icmp ne i32 %5, 0 + ret i1 %6 +}