-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[WebAssembly] [Backend] Combine and(X, shuffle(X, pow 2 mask)) to all true #145108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
For the generic DAG nodes like ISD::AND you need to tell SelectionDAG that your target wants to perform custom combines on them by calling |
it works now! thank you! |
This shows that reduceand is not well-optimized in WebAssembly. The long chain of shuffle should be turned to all_true.
86f26c1
to
430a54b
Compare
@llvm/pr-subscribers-backend-webassembly Author: jjasmine (badumbatish) ChangesI'm hooking up dagcombine for AND(AND(AND(...), SHUFFLE(...)), SHUFFLE(...)) to reduce it to all_true. SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
// N->print(llvm::errs());
// std::cout << "\n"; llvm ir Godbolt link: https://godbolt.org/z/qYEvPn1KW Local input to selection dag: Initial selection DAG: %bb.0 'bar:entry'
SelectionDAG has 21 nodes:
t2: v4i32 = WebAssemblyISD::ARGUMENT TargetConstant:i32<0>
t3: v16i8 = bitcast t2
t5: v16i8 = BUILD_VECTOR Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>
t7: v16i1 = setcc t3, t5, setne:ch
t8: v16i8 = sign_extend t7
t9: v4i32 = bitcast t8
t11: v4i32 = vector_shuffle<2,3,u,u> t9, poison:v4i32
t12: v4i32 = and t9, t11
t0: ch,glue = EntryToken
t13: v4i32 = vector_shuffle<1,u,u,u> t12, poison:v4i32
t14: v4i32 = and t12, t13
t17: i32 = extract_vector_elt t14, Constant:i64<0>
t18: i1 = setcc t17, Constant:i32<0>, setne:ch
t19: i32 = zero_extend t18
t20: ch = WebAssemblyISD::RETURN t0, t19
Combining: t20: ch = WebAssemblyISD::RETURN t0, t19
Combining: t19: i32 = zero_extend t18
Creating constant: t21: i1 = Constant<-1>
Creating constant: t22: i1 = Constant<0>
Combining: t18: i1 = setcc t17, Constant:i32<0>, setne:ch
Combining: t17: i32 = extract_vector_elt t14, Constant:i64<0>
Combining: t16: i64 = Constant<0>
Combining: t15: i32 = Constant<0>
Combining: t14: v4i32 = and t12, t13
Combining: t13: v4i32 = vector_shuffle<1,u,u,u> t12, poison:v4i32
Combining: t12: v4i32 = and t9, t11
Combining: t11: v4i32 = vector_shuffle<2,3,u,u> t9, poison:v4i32
Combining: t10: v4i32 = poison
Combining: t9: v4i32 = bitcast t8
Combining: t8: v16i8 = sign_extend t7
Creating new node: t23: v16i8 = setcc t3, t5, setne:ch
... into: t23: v16i8 = setcc t3, t5, setne:ch
Combining: t23: v16i8 = setcc t3, t5, setne:ch
Combining: t9: v4i32 = bitcast t23
Combining: t6: ch = setne
Combining: t5: v16i8 = BUILD_VECTOR Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>, Constant:i8<0>
Creating new node: t24: v16i8 = splat_vector Constant:i8<0>
... into: t24: v16i8 = splat_vector Constant:i8<0>
Combining: t24: v16i8 = splat_vector Constant:i8<0>
Combining: t23: v16i8 = setcc t3, t24, setne:ch
Combining: t4: i8 = Constant<0>
Combining: t3: v16i8 = bitcast t2
Combining: t2: v4i32 = WebAssemblyISD::ARGUMENT TargetConstant:i32<0>
Combining: t1: i32 = TargetConstant<0>
Combining: t0: ch,glue = EntryToken
Optimized lowered selection DAG: %bb.0 'bar:entry'
SelectionDAG has 20 nodes:
t2: v4i32 = WebAssemblyISD::ARGUMENT TargetConstant:i32<0>
t3: v16i8 = bitcast t2
t24: v16i8 = splat_vector Constant:i8<0>
t23: v16i8 = setcc t3, t24, setne:ch
t9: v4i32 = bitcast t23
t11: v4i32 = vector_shuffle<2,3,u,u> t9, poison:v4i32
t12: v4i32 = and t9, t11
t0: ch,glue = EntryToken
t13: v4i32 = vector_shuffle<1,u,u,u> t12, poison:v4i32
t14: v4i32 = and t12, t13
t17: i32 = extract_vector_elt t14, Constant:i64<0>
t18: i1 = setcc t17, Constant:i32<0>, setne:ch
t19: i32 = zero_extend t18
t20: ch = WebAssemblyISD::RETURN t0, t19 Full diff: https://github.com/llvm/llvm-project/pull/145108.diff 2 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 3cd923c0ba058..d9c2f789e2248 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -18,12 +18,14 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -184,6 +186,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+ // Combine EXTRACT VECTOR ELT of AND(AND(X, SHUFFLE(X)), SHUFFLE(...)), 0
+ // to all_true
+ setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
@@ -3287,6 +3293,85 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SmallVector<int> buildMaskArrayByPower(int Power, size_t NumElements) {
+ // Generate 1-index array of elements from 2^Power to 2^(Power+1) exclusive
+ // The rest is filled with -1.
+ //
+ // For example, with NumElements = 4:
+ // When Power = 0: <1 -1 -1 -1>
+ // When Power = 1: <2 3 -1 -1>
+ // When Power = 2: <4 5 6 7>
+
+ uint From = pow(2, Power), To = pow(2, Power + 1);
+ assert(From < NumElements && To <= NumElements);
+
+ SmallVector<int> Res;
+ for (uint I = From; I < To; I++)
+ Res.push_back(I);
+ Res.resize(NumElements, -1);
+
+ return Res;
+}
+static SDValue matchAndOfShuffle(SDNode *N, int Power) {
+ // Matching on the case of
+ //
+ // Base case: A [bitcast for a] setcc(v, <0>, ne).
+ // Recursive case: N = and(X, shuffle(X, power mask)) where X is either
+ // recursive or base case.
+ using namespace llvm::SDPatternMatch;
+
+ EVT VT = N->getValueType(0);
+
+ SDValue LHS = N->getOperand(0);
+ int NumElements = VT.getVectorNumElements();
+ if (NumElements < pow(2, Power))
+ return SDValue();
+
+ if (N->getOpcode() != ISD::AND && NumElements == pow(2, Power)) {
+ SDValue BitCast, Matched;
+
+ // Try for a setcc first.
+ if (sd_match(N, m_c_SetCC(m_Value(Matched), m_Zero(),
+ m_SpecificCondCode(ISD::SETNE))))
+ return Matched;
+
+ // Now try for bitcast
+ if (!sd_match(N, m_BitCast(m_Value(BitCast))))
+ return SDValue();
+
+ if (!sd_match(BitCast, m_c_SetCC(m_Value(Matched), m_Zero(),
+ m_SpecificCondCode(ISD::SETNE))))
+ return SDValue();
+ return Matched;
+ }
+
+ SmallVector<int> PowerIndices = buildMaskArrayByPower(Power, NumElements);
+ if (sd_match(N, m_And(m_Value(LHS),
+ m_Shuffle(m_Value(LHS), m_VectorVT(m_Opc(ISD::POISON)),
+ m_SpecificMask(PowerIndices)))))
+ return matchAndOfShuffle(LHS.getNode(), Power + 1);
+
+ return SDValue();
+}
+static SDValue performExtractVecEltCombine(SDNode *N, SelectionDAG &DAG) {
+ using namespace llvm::SDPatternMatch;
+
+ assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+ SDLoc DL(N);
+
+ SDValue And;
+ if (!sd_match(N, m_ExtractElt(m_VectorVT(m_Value(And)), m_Zero())))
+ return SDValue();
+
+ if (SDValue Matched = matchAndOfShuffle(And.getNode(), 0))
+ return DAG.getZExtOrTrunc(
+ DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
+ {DAG.getConstant(Intrinsic::wasm_alltrue, DL, MVT::i32), Matched}),
+ DL, N->getValueType(0));
+
+ return SDValue();
+}
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
@@ -3402,6 +3487,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
+ case ISD::EXTRACT_VECTOR_ELT:
+ return performExtractVecEltCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll b/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll
new file mode 100644
index 0000000000000..f494691941b64
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-reduceand.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+target triple = "wasm64"
+
+define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) {
+; CHECK-LABEL: reduce_and_to_all_true_16i8:
+; CHECK: .functype reduce_and_to_all_true_16i8 (v128) -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i8x16.all_true $push0=, $0
+; CHECK-NEXT: return $pop0
+ %2 = icmp ne <16 x i8> %0, zeroinitializer
+ %3 = sext <16 x i1> %2 to <16 x i8>
+ %4 = bitcast <16 x i8> %3 to <4 x i32>
+ %5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4)
+ %6 = icmp ne i32 %5, 0
+ ret i1 %6
+}
+
+
+define i1 @reduce_and_to_all_true_4i32(<4 x i32> %0) {
+; CHECK-LABEL: reduce_and_to_all_true_4i32:
+; CHECK: .functype reduce_and_to_all_true_4i32 (v128) -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.all_true $push0=, $0
+; CHECK-NEXT: return $pop0
+ %2 = icmp ne <4 x i32> %0, zeroinitializer
+ %3 = sext <4 x i1> %2 to <4 x i32>
+ %4 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %3)
+ %5 = icmp ne i32 %4, 0
+ ret i1 %5
+}
+
+
+
+define i1 @reduce_and_to_all_true_2i64(<2 x i64> %0) {
+; CHECK-LABEL: reduce_and_to_all_true_2i64:
+; CHECK: .functype reduce_and_to_all_true_2i64 (v128) -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.all_true $push0=, $0
+; CHECK-NEXT: return $pop0
+ %2 = bitcast <2 x i64> %0 to <4 x i32>
+ %3 = icmp ne <4 x i32> %2, zeroinitializer
+ %4 = sext <4 x i1> %3 to <4 x i32>
+ %5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4)
+ %6 = icmp ne i32 %5, 0
+ ret i1 %6
+}
|
430a54b
to
75e133f
Compare
Combine N = and(X, shuffle_vector(X, power of 2 mask)) to all true. Where X is either N or setcc(v, <0>, ne) or a bitcast of said setcc.
75e133f
to
9c7ce61
Compare
define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) { | ||
; CHECK-LABEL: reduce_and_to_all_true_16i8: | ||
; CHECK: .functype reduce_and_to_all_true_16i8 (v128) -> (i32) | ||
; CHECK-NEXT: # %bb.0: | ||
; CHECK-NEXT: i8x16.all_true $push0=, $0 | ||
; CHECK-NEXT: return $pop0 | ||
%2 = icmp ne <16 x i8> %0, zeroinitializer | ||
%3 = sext <16 x i1> %2 to <16 x i8> | ||
%4 = bitcast <16 x i8> %3 to <4 x i32> | ||
%5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4) | ||
%6 = icmp ne i32 %5, 0 | ||
ret i1 %6 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This IR is coming from https://godbolt.org/z/YMo1qqccT right? I'm surprised that we end up with v4i32 from the v16i8 type in the C. Do you know where this is being introduced? Perhaps the easiest fix here is to try and keep it in v16i8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, I'm not sure where it comes from but if i remove the bitcast, then it'll produce i8x16 all_true but generates a few more lines. Will investigate more
define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) {
; CHECK-LABEL: reduce_and_to_all_true_16i8:
; CHECK: .functype reduce_and_to_all_true_16i8 (v128) -> (i32)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i8x16.all_true $push0=, $0
; CHECK-NEXT: i32.const $push1=, 255
; CHECK-NEXT: i32.and $push2=, $pop0, $pop1
; CHECK-NEXT: i32.const $push3=, 0
; CHECK-NEXT: i32.ne $push4=, $pop2, $pop3
; CHECK-NEXT: return $pop4
%2 = icmp ne <16 x i8> %0, zeroinitializer
%3 = sext <16 x i1> %2 to <16 x i8>
%4 = tail call i8 @llvm.vector.reduce.and.v8i16(<16 x i8> %3)
%5 = icmp ne i8 %4, 0
ret i1 %5
}
From the godbolt initially reported in #129441, if we add explicit casts we get much more sensible LLVM IR with the correct type: https://godbolt.org/z/4Gf7xaf3x I.e. bool bar(__i16x8 a) {
__i16x8 zero = wasm_i8x16_splat(0);
return __builtin_reduce_and((__i16x8)wasm_i8x16_ne(a, zero));
} Gives define hidden zeroext i1 @bar(<8 x i16> noundef %a) local_unnamed_addr #0 !dbg !26 {
entry:
#dbg_value(<8 x i16> %a, !32, !DIExpression(), !34)
#dbg_value(<4 x i32> zeroinitializer, !33, !DIExpression(), !34)
%0 = bitcast <8 x i16> %a to <16 x i8>, !dbg !35
%cmp.i = icmp ne <16 x i8> %0, zeroinitializer, !dbg !35
%sext.i = sext <16 x i1> %cmp.i to <16 x i8>, !dbg !35
%1 = bitcast <16 x i8> %sext.i to <8 x i16>, !dbg !36
%rdx.and = tail call i16 @llvm.vector.reduce.and.v8i16(<8 x i16> %1), !dbg !37
%tobool = icmp ne i16 %rdx.and, 0, !dbg !37
ret i1 %tobool, !dbg !38
} However this still doesn't emit all_true and the vector.reduce.and gets expanded to a bunch of shuffles in ExpandReductions.cpp. I think there's two issues here: the intrinsics shouldn't be casting everything to v4i32 presumably? And separately we should be able to emit all_true for the LLVM IR above. Can we change this PR so that it solves the second issue, i.e. handling the above test case by implementing |
Combine and(X, shuffle(X, pow 2 mask)) to all true
Combine N = and(X, shuffle_vector(X, power of 2 mask)) to all true.
Where X is either N or setcc(v, <0>, ne) or a bitcast of said setcc.
Past edits:
I'm hooking up dagcombine for AND(AND(AND(...), SHUFFLE(...)), SHUFFLE(...)) to reduce it to all_true.
I'm unsure why the hook for AND is not triggered when there are AND nodes as input to SelectionDAG. No nodes showed up as AND in:
llvm ir Godbolt link: https://godbolt.org/z/qYEvPn1KW
Local input to selection dag: