Skip to content

Commit 0597138

Browse files
committed
[WebAssembly] [Backend] Optimize illegal bitmask
1 parent 9dd2a14 commit 0597138

File tree

2 files changed

+103
-596
lines changed

2 files changed

+103
-596
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
#include "WebAssemblySubtarget.h"
1919
#include "WebAssemblyTargetMachine.h"
2020
#include "WebAssemblyUtilities.h"
21+
#include "llvm/ADT/SmallVector.h"
2122
#include "llvm/CodeGen/CallingConvLower.h"
2223
#include "llvm/CodeGen/MachineFrameInfo.h"
2324
#include "llvm/CodeGen/MachineInstrBuilder.h"
2425
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2526
#include "llvm/CodeGen/MachineModuleInfo.h"
2627
#include "llvm/CodeGen/MachineRegisterInfo.h"
28+
#include "llvm/CodeGen/SDPatternMatch.h"
2729
#include "llvm/CodeGen/SelectionDAG.h"
2830
#include "llvm/CodeGen/SelectionDAGNodes.h"
2931
#include "llvm/IR/DiagnosticInfo.h"
@@ -3214,20 +3216,26 @@ static SDValue performTruncateCombine(SDNode *N,
32143216

32153217
static SDValue performBitcastCombine(SDNode *N,
32163218
TargetLowering::DAGCombinerInfo &DCI) {
3219+
using namespace llvm::SDPatternMatch;
32173220
auto &DAG = DCI.DAG;
32183221
SDLoc DL(N);
32193222
SDValue Src = N->getOperand(0);
32203223
EVT VT = N->getValueType(0);
32213224
EVT SrcVT = Src.getValueType();
32223225

3223-
// bitcast <N x i1> to iN
3226+
bool Vectorizable = DCI.isBeforeLegalize() && VT.isScalarInteger() &&
3227+
SrcVT.isFixedLengthVector() &&
3228+
SrcVT.getScalarType() == MVT::i1;
3229+
3230+
if (!Vectorizable)
3231+
return SDValue();
3232+
3233+
unsigned NumElts = SrcVT.getVectorNumElements();
3234+
EVT Width = MVT::getIntegerVT(128 / NumElts);
3235+
3236+
// bitcast <N x i1> to iN, where N = 2, 4, 8, 16 (legal)
32243237
// ==> bitmask
3225-
if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
3226-
SrcVT.isFixedLengthVector() && SrcVT.getScalarType() == MVT::i1) {
3227-
unsigned NumElts = SrcVT.getVectorNumElements();
3228-
if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
3229-
return SDValue();
3230-
EVT Width = MVT::getIntegerVT(128 / NumElts);
3238+
if (NumElts == 2 || NumElts == 4 || NumElts == 8 || NumElts == 16) {
32313239
return DAG.getZExtOrTrunc(
32323240
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
32333241
{DAG.getConstant(Intrinsic::wasm_bitmask, DL, MVT::i32),
@@ -3236,6 +3244,57 @@ static SDValue performBitcastCombine(SDNode *N,
32363244
DL, VT);
32373245
}
32383246

3247+
// bitcast <N x i1>(setcc ...) to concat iN, where N = 32 and 64 (illegal)
3248+
if (NumElts == 32 || NumElts == 64) {
3249+
// Strategy: We will setcc them seperately in v16i1
3250+
// Bitcast them to i16, extend them to either i32 or i64.
3251+
// Add them together, shifting left 1 by 1.
3252+
SDValue Concat, SetCCVector;
3253+
ISD::CondCode SetCond;
3254+
3255+
if (!sd_match(N, m_BitCast(m_c_SetCC(m_Value(Concat),
3256+
m_VectorVT(m_Value(SetCCVector)),
3257+
m_CondCode(SetCond)))))
3258+
return SDValue();
3259+
// COMMITTED at this point, SDValue() if match fails.
3260+
if (Concat.getOpcode() != ISD::CONCAT_VECTORS)
3261+
return SDValue();
3262+
// CHECK IF VECTOR is a constant, i.e all values are the same
3263+
if (!ISD::isBuildVectorOfConstantSDNodes(SetCCVector.getNode()))
3264+
return SDValue();
3265+
3266+
SmallVector<SDValue> Vec;
3267+
for (SDValue Const : SetCCVector->ops()) {
3268+
Vec.push_back(Const);
3269+
if (Vec.size() >= 16)
3270+
break;
3271+
}
3272+
3273+
// Build our own version of splat Vector.
3274+
SDValue SplitSetCCVec = DAG.getBuildVector(MVT::v16i8, DL, Vec);
3275+
3276+
SmallVector<SDValue> VectorsToShuffle;
3277+
for (SDValue V : Concat->ops())
3278+
VectorsToShuffle.push_back(DAG.getBitcast(
3279+
MVT::i16, DAG.getSetCC(DL, MVT::v16i1, V, SplitSetCCVec, SetCond)));
3280+
3281+
MVT ReturnType = VectorsToShuffle.size() == 2 ? MVT::i32 : MVT::i64;
3282+
SDValue ReturningInteger = DAG.getConstant(0, DL, ReturnType);
3283+
3284+
for (SDValue V : VectorsToShuffle) {
3285+
ReturningInteger = DAG.getNode(
3286+
ISD::SHL, DL, ReturnType,
3287+
{DAG.getShiftAmountConstant(16, ReturnType, DL), ReturningInteger});
3288+
3289+
SDValue ExtendedV = DAG.getZExtOrTrunc(V, DL, ReturnType);
3290+
ReturningInteger =
3291+
DAG.getNode(ISD::ADD, DL, ReturnType, {ReturningInteger, ExtendedV});
3292+
}
3293+
3294+
// ReturningInteger->print(llvm::errs());
3295+
return ReturningInteger;
3296+
}
3297+
32393298
return SDValue();
32403299
}
32413300

0 commit comments

Comments
 (0)