@@ -32106,6 +32106,15 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
3210632106 return true;
3210732107}
3210832108
32109+ //------------------------------------------------------------------------
32110+ // gtFoldExprHWIntrinsic: Attempt to fold a HWIntrinsic
32111+ //
32112+ // Arguments:
32113+ // tree - HWIntrinsic to fold
32114+ //
32115+ // Return Value:
32116+ // folded expression if it could be folded, else the original tree
32117+ //
3210932118#if defined(FEATURE_HW_INTRINSICS)
3211032119GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3211132120{
@@ -32249,7 +32258,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3224932258 // We shouldn't find AND_NOT nodes since it should only be produced in lowering
3225032259 assert(oper != GT_AND_NOT);
3225132260
32252- #if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
32261+ #ifdef FEATURE_MASKED_HW_INTRINSICS
32262+ #ifdef TARGET_XARCH
3225332263 if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
3225432264 {
3225532265 // Comparisons that produce masks lead to more verbose trees than
@@ -32367,7 +32377,75 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3236732377 }
3236832378 }
3236932379 }
32370- #endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
32380+ #elif defined(TARGET_ARM64)
32381+ // Check if the tree can be folded into a mask variant
32382+ if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
32383+ {
32384+ NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
32385+
32386+ assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
32387+
32388+ // Check all operands are valid
32389+ bool canFold = true;
32390+ if (ni == NI_Sve_ConditionalSelect)
32391+ {
32392+ assert(varTypeIsMask(op1));
32393+ canFold = (op2->OperIsConvertMaskToVector() && op3->OperIsConvertMaskToVector());
32394+ }
32395+ else
32396+ {
32397+ for (size_t i = 1; i <= opCount && canFold; i++)
32398+ {
32399+ canFold &= tree->Op(i)->OperIsConvertMaskToVector();
32400+ }
32401+ }
32402+
32403+ if (canFold)
32404+ {
32405+ // Convert all the operands to masks
32406+ for (size_t i = 1; i <= opCount; i++)
32407+ {
32408+ if (tree->Op(i)->OperIsConvertMaskToVector())
32409+ {
32410+ // Replace with op1.
32411+ tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
32412+ }
32413+ else if (tree->Op(i)->IsVectorZero())
32414+ {
32415+ // Replace the vector of zeroes with a mask of zeroes.
32416+ tree->Op(i) = gtNewSimdFalseMaskByteNode();
32417+ tree->Op(i)->SetMorphed(this);
32418+ }
32419+ assert(varTypeIsMask(tree->Op(i)));
32420+ }
32421+
32422+ // Switch to the mask variant
32423+ switch (opCount)
32424+ {
32425+ case 1:
32426+ tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
32427+ break;
32428+ case 2:
32429+ tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
32430+ break;
32431+ case 3:
32432+ tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
32433+ break;
32434+ default:
32435+ unreached();
32436+ }
32437+
32438+ tree->gtType = TYP_MASK;
32439+ tree->SetMorphed(this);
32440+ tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
32441+ tree->SetMorphed(this);
32442+ op1 = tree->Op(1);
32443+ op2 = nullptr;
32444+ op3 = nullptr;
32445+ }
32446+ }
32447+ #endif // TARGET_ARM64
32448+ #endif // FEATURE_MASKED_HW_INTRINSICS
3237132449
3237232450 GenTree* cnsNode = nullptr;
3237332451 GenTree* otherNode = nullptr;
@@ -33754,7 +33832,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3375433832 // op2 = op2 & op1
3375533833 op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
3375633834
33757- // op3 = op2 & ~op1
33835+ // op3 = op3 & ~op1
3375833836 op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
3375933837
3376033838 // op2 = op2 | op3
@@ -33767,8 +33845,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3376733845
3376833846#if defined(TARGET_ARM64)
3376933847 case NI_Sve_ConditionalSelect:
33848+ case NI_Sve_ConditionalSelect_Predicates:
3377033849 {
33771- assert(!varTypeIsMask(retType));
3377233850 assert(varTypeIsMask(op1));
3377333851
3377433852 if (cnsNode != op1)
@@ -33797,10 +33875,11 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3379733875
3379833876 if (op2->IsCnsVec() && op3->IsCnsVec())
3379933877 {
33878+ assert(ni == NI_Sve_ConditionalSelect);
3380033879 assert(op2->gtType == TYP_SIMD16);
3380133880 assert(op3->gtType == TYP_SIMD16);
3380233881
33803- simd16_t op1SimdVal;
33882+ simd16_t op1SimdVal = {} ;
3380433883 EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
3380533884
3380633885 // op2 = op2 & op1
@@ -33809,7 +33888,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3380933888 op1SimdVal);
3381033889 op2->AsVecCon()->gtSimd16Val = result;
3381133890
33812- // op3 = op2 & ~op1
33891+ // op3 = op3 & ~op1
3381333892 result = {};
3381433893 EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
3381533894 op1SimdVal);
@@ -33820,6 +33899,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3382033899
3382133900 resultNode = op2;
3382233901 }
33902+ else if (op2->IsCnsMsk() && op3->IsCnsMsk())
33903+ {
33904+ assert(ni == NI_Sve_ConditionalSelect_Predicates);
33905+
33906+ // op2 = op2 & op1
33907+ simdmask_t result = {};
33908+ EvaluateBinaryMask<simd16_t>(GT_AND, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33909+ op1->AsMskCon()->gtSimdMaskVal);
33910+ op2->AsMskCon()->gtSimdMaskVal = result;
33911+
33912+ // op3 = op3 & ~op1
33913+ result = {};
33914+ EvaluateBinaryMask<simd16_t>(GT_AND_NOT, false, simdBaseType, &result,
33915+ op3->AsMskCon()->gtSimdMaskVal, op1->AsMskCon()->gtSimdMaskVal);
33916+ op3->AsMskCon()->gtSimdMaskVal = result;
33917+
33918+ // op2 = op2 | op3
33919+ result = {};
33920+ EvaluateBinaryMask<simd16_t>(GT_OR, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33921+ op3->AsMskCon()->gtSimdMaskVal);
33922+ op2->AsMskCon()->gtSimdMaskVal = result;
33923+
33924+ resultNode = op2;
33925+ }
3382333926 break;
3382433927 }
3382533928#endif // TARGET_ARM64
0 commit comments