18
18
#include " WebAssemblySubtarget.h"
19
19
#include " WebAssemblyTargetMachine.h"
20
20
#include " WebAssemblyUtilities.h"
21
+ #include " llvm/ADT/SmallVector.h"
21
22
#include " llvm/CodeGen/CallingConvLower.h"
22
23
#include " llvm/CodeGen/MachineFrameInfo.h"
23
24
#include " llvm/CodeGen/MachineInstrBuilder.h"
24
25
#include " llvm/CodeGen/MachineJumpTableInfo.h"
25
26
#include " llvm/CodeGen/MachineModuleInfo.h"
26
27
#include " llvm/CodeGen/MachineRegisterInfo.h"
28
+ #include " llvm/CodeGen/SDPatternMatch.h"
27
29
#include " llvm/CodeGen/SelectionDAG.h"
28
30
#include " llvm/CodeGen/SelectionDAGNodes.h"
29
31
#include " llvm/IR/DiagnosticInfo.h"
@@ -3214,20 +3216,26 @@ static SDValue performTruncateCombine(SDNode *N,
3214
3216
3215
3217
static SDValue performBitcastCombine (SDNode *N,
3216
3218
TargetLowering::DAGCombinerInfo &DCI) {
3219
+ using namespace llvm ::SDPatternMatch;
3217
3220
auto &DAG = DCI.DAG ;
3218
3221
SDLoc DL (N);
3219
3222
SDValue Src = N->getOperand (0 );
3220
3223
EVT VT = N->getValueType (0 );
3221
3224
EVT SrcVT = Src.getValueType ();
3222
3225
3223
- // bitcast <N x i1> to iN
3226
+ bool Vectorizable = DCI.isBeforeLegalize () && VT.isScalarInteger () &&
3227
+ SrcVT.isFixedLengthVector () &&
3228
+ SrcVT.getScalarType () == MVT::i1;
3229
+
3230
+ if (!Vectorizable)
3231
+ return SDValue ();
3232
+
3233
+ unsigned NumElts = SrcVT.getVectorNumElements ();
3234
+ EVT Width = MVT::getIntegerVT (128 / NumElts);
3235
+
3236
+ // bitcast <N x i1> to iN, where N = 2, 4, 8, 16 (legal)
3224
3237
// ==> bitmask
3225
- if (DCI.isBeforeLegalize () && VT.isScalarInteger () &&
3226
- SrcVT.isFixedLengthVector () && SrcVT.getScalarType () == MVT::i1) {
3227
- unsigned NumElts = SrcVT.getVectorNumElements ();
3228
- if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16 )
3229
- return SDValue ();
3230
- EVT Width = MVT::getIntegerVT (128 / NumElts);
3238
+ if (NumElts == 2 || NumElts == 4 || NumElts == 8 || NumElts == 16 ) {
3231
3239
return DAG.getZExtOrTrunc (
3232
3240
DAG.getNode (ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32 ,
3233
3241
{DAG.getConstant (Intrinsic::wasm_bitmask, DL, MVT::i32 ),
@@ -3236,6 +3244,57 @@ static SDValue performBitcastCombine(SDNode *N,
3236
3244
DL, VT);
3237
3245
}
3238
3246
3247
+ // bitcast <N x i1>(setcc ...) to concat iN, where N = 32 and 64 (illegal)
3248
+ if (NumElts == 32 || NumElts == 64 ) {
3249
+ // Strategy: We will setcc them seperately in v16i1
3250
+ // Bitcast them to i16, extend them to either i32 or i64.
3251
+ // Add them together, shifting left 1 by 1.
3252
+ SDValue Concat, SetCCVector;
3253
+ ISD::CondCode SetCond;
3254
+
3255
+ if (!sd_match (N, m_BitCast (m_c_SetCC (m_Value (Concat),
3256
+ m_VectorVT (m_Value (SetCCVector)),
3257
+ m_CondCode (SetCond)))))
3258
+ return SDValue ();
3259
+ // COMMITTED at this point, SDValue() if match fails.
3260
+ if (Concat.getOpcode () != ISD::CONCAT_VECTORS)
3261
+ return SDValue ();
3262
+ // CHECK IF VECTOR is a constant, i.e all values are the same
3263
+ if (!ISD::isBuildVectorOfConstantSDNodes (SetCCVector.getNode ()))
3264
+ return SDValue ();
3265
+
3266
+ SmallVector<SDValue> Vec;
3267
+ for (SDValue Const : SetCCVector->ops ()) {
3268
+ Vec.push_back (Const);
3269
+ if (Vec.size () >= 16 )
3270
+ break ;
3271
+ }
3272
+
3273
+ // Build our own version of splat Vector.
3274
+ SDValue SplitSetCCVec = DAG.getBuildVector (MVT::v16i8, DL, Vec);
3275
+
3276
+ SmallVector<SDValue> VectorsToShuffle;
3277
+ for (SDValue V : Concat->ops ())
3278
+ VectorsToShuffle.push_back (DAG.getBitcast (
3279
+ MVT::i16 , DAG.getSetCC (DL, MVT::v16i1, V, SplitSetCCVec, SetCond)));
3280
+
3281
+ MVT ReturnType = VectorsToShuffle.size () == 2 ? MVT::i32 : MVT::i64 ;
3282
+ SDValue ReturningInteger = DAG.getConstant (0 , DL, ReturnType);
3283
+
3284
+ for (SDValue V : VectorsToShuffle) {
3285
+ ReturningInteger = DAG.getNode (
3286
+ ISD::SHL, DL, ReturnType,
3287
+ {DAG.getShiftAmountConstant (16 , ReturnType, DL), ReturningInteger});
3288
+
3289
+ SDValue ExtendedV = DAG.getZExtOrTrunc (V, DL, ReturnType);
3290
+ ReturningInteger =
3291
+ DAG.getNode (ISD::ADD, DL, ReturnType, {ReturningInteger, ExtendedV});
3292
+ }
3293
+
3294
+ // ReturningInteger->print(llvm::errs());
3295
+ return ReturningInteger;
3296
+ }
3297
+
3239
3298
return SDValue ();
3240
3299
}
3241
3300
0 commit comments