Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/coreclr/jit/codegenarm64test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5067,7 +5067,7 @@ void CodeGen::genArm64EmitterUnitTestsSve()
theEmitter->emitIns_R_R(INS_sve_ptest, EA_SCALABLE, REG_P2, REG_P14, INS_OPTS_SCALABLE_B); // PTEST <Pg>, <Pn>.B

// IF_SVE_DK_3A
theEmitter->emitIns_R_R_R(INS_sve_cntp, EA_8BYTE, REG_R29, REG_P0, REG_P15,
theEmitter->emitIns_R_R_R(INS_sve_cntp, EA_SCALABLE, REG_R29, REG_P0, REG_P15,
INS_OPTS_SCALABLE_D); // CNTP <Xd>, <Pg>, <Pn>.<T>

// IF_SVE_GE_4A
Expand Down Expand Up @@ -6351,21 +6351,21 @@ void CodeGen::genArm64EmitterUnitTestsSve()
INS_OPTS_SCALABLE_B); // UQSHRNT <Zd>.<T>, <Zn>.<Tb>, #<const>

// IF_SVE_DL_2A
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R0, REG_P0, INS_OPTS_SCALABLE_B,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R0, REG_P0, INS_OPTS_SCALABLE_B,
INS_SCALABLE_OPTS_VL_2X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R1, REG_P1, INS_OPTS_SCALABLE_B,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R1, REG_P1, INS_OPTS_SCALABLE_B,
INS_SCALABLE_OPTS_VL_4X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R2, REG_P2, INS_OPTS_SCALABLE_H,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R2, REG_P2, INS_OPTS_SCALABLE_H,
INS_SCALABLE_OPTS_VL_2X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R3, REG_P3, INS_OPTS_SCALABLE_H,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R3, REG_P3, INS_OPTS_SCALABLE_H,
INS_SCALABLE_OPTS_VL_4X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R4, REG_P4, INS_OPTS_SCALABLE_S,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R4, REG_P4, INS_OPTS_SCALABLE_S,
INS_SCALABLE_OPTS_VL_2X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R5, REG_P5, INS_OPTS_SCALABLE_S,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R5, REG_P5, INS_OPTS_SCALABLE_S,
INS_SCALABLE_OPTS_VL_4X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R6, REG_P6, INS_OPTS_SCALABLE_D,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R6, REG_P6, INS_OPTS_SCALABLE_D,
INS_SCALABLE_OPTS_VL_2X); // CNTP <Xd>, <PNn>.<T>, <vl>
theEmitter->emitIns_R_R(INS_sve_cntp, EA_8BYTE, REG_R7, REG_P7, INS_OPTS_SCALABLE_D,
theEmitter->emitIns_R_R(INS_sve_cntp, EA_SCALABLE, REG_R7, REG_P7, INS_OPTS_SCALABLE_D,
INS_SCALABLE_OPTS_VL_4X); // CNTP <Xd>, <PNn>.<T>, <vl>

// IF_SVE_DM_2A
Expand Down
13 changes: 9 additions & 4 deletions src/coreclr/jit/emitarm64sve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,7 @@ void emitter::emitInsSve_R_R(instruction ins,
break;

case INS_sve_cntp:
assert(isScalableVectorSize(size));
assert(insOptsScalableStandard(opt));
assert(insScalableOptsWithVectorLength(sopt)); // l
assert(isGeneralRegister(reg1)); // ddddd
Expand Down Expand Up @@ -3918,7 +3919,7 @@ void emitter::emitInsSve_R_R_R(instruction ins,
break;

case INS_sve_cntp:
assert(size == EA_8BYTE);
assert(isScalableVectorSize(size));
assert(isGeneralRegister(reg1)); // ddddd
assert(isPredicateRegister(reg2)); // gggg
assert(isPredicateRegister(reg3)); // NNNN
Expand Down Expand Up @@ -13077,7 +13078,7 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
break;

case IF_SVE_DK_3A: // ........xx...... ..gggg.NNNNddddd -- SVE predicate count
assert(id->idOpSize() == EA_8BYTE);
assert(isScalableVectorSize(id->idOpSize()));
assert(insOptsScalableStandard(id->idInsOpt()));
assert(isGeneralRegister(id->idReg1())); // ddddd
assert(isPredicateRegister(id->idReg2())); // gggg
Expand Down Expand Up @@ -13338,9 +13339,13 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
break;

case IF_SVE_DL_2A: // ........xx...... .....l.NNNNddddd -- SVE predicate count (predicate-as-counter)
assert(id->idOpSize() == EA_8BYTE);
assert(insOptsScalableStandard(id->idInsOpt()));
assert(isValidVectorElemsize(optGetSveElemsize(id->idInsOpt()))); // xx
assert(isGeneralRegister(id->idReg1())); // ddddd
assert(isPredicateRegister(id->idReg2())); // NNNN
assert(isScalableVectorSize(id->idOpSize()));
break;

FALLTHROUGH;
case IF_SVE_DO_2A: // ........xx...... .....X.MMMMddddd -- SVE saturating inc/dec register by predicate count
case IF_SVE_DM_2A: // ........xx...... .......MMMMddddd -- SVE inc/dec register by predicate count
assert(insOptsScalableStandard(id->idInsOpt()));
Expand Down
24 changes: 24 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,30 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_GetActiveElementCount:
{
assert(sig->numArgs == 2);

CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg2, &argClass)));
op2 = getArgForHWIntrinsic(argType, argClass);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg1, &argClass)));
op1 = impPopStack().val;

// Op1 and Op2 are masks. Op1 mask handling will be handled via IsExplicitMaskedOperation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What cannot we handle op2 handling next to the place where we handle op1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would require an if (id ==NI_Sve_GetActiveElementCount) in hwintrinsic.cpp. I'm trying to avoid adding any Arm specific code in that file.

Copy link
Contributor

@kunalspathak kunalspathak May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently we have this:

    if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrinsic))
    {
        assert(numArgs > 0);
        GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
        if (intrinsic == NI_Sve_ConditionalSelect)
        {
            if (op1->IsVectorAllBitsSet() || op1->IsMaskAllBitsSet())
            {
                return retNode->AsHWIntrinsic()->Op(2);
            }
            else if (op1->IsVectorZero())
            {
                return retNode->AsHWIntrinsic()->Op(3);
            }
        }

        if (!varTypeIsMask(op1))
        {
            // Op1 input is a vector. HWInstrinsic requires a mask.
            retNode->AsHWIntrinsic()->Op(1) = gtNewSimdConvertVectorToMaskNode(retType, op1, simdBaseJitType, simdSize);
        }
}

we can have varTypeIsMask(op2) after the op1, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but would still require a NI_Sve_GetActiveElementCount check, as we don't want to convert for other nodes.

Copy link
Contributor

@kunalspathak kunalspathak May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hhm, are there many intrinsic that will have op2 as mask? we can add assert or something to make sure that it is NI_Sve_GetActiveElementCount.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simplified this handling. There's now just a small extra bit in IsExplicitMaskedOperation()

if (!varTypeIsMask(op2))
{
op2 = gtNewSimdConvertVectorToMaskNode(retType, op2, simdBaseJitType, simdSize);
}

retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);
break;
}

case NI_Sve_StoreAndZip:
{
assert(sig->numArgs == 3);
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated,
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, GetActiveElementCount, -1, 2, true, {INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation|HW_Flag_SpecialImport)
HARDWARE_INTRINSIC(Sve, LeadingSignCount, -1, -1, false, {INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, LeadingZeroCount, -1, -1, false, {INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, LoadVector, -1, 2, true, {INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1h, INS_sve_ld1h, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,69 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }


/// Count set predicate bits

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<byte> mask, Vector<byte> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<double> mask, Vector<double> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<short> mask, Vector<short> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<int> mask, Vector<int> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<long> mask, Vector<long> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<sbyte> mask, Vector<sbyte> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<float> mask, Vector<float> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b16(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.H
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<ushort> mask, Vector<ushort> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b32(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.S
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<uint> mask, Vector<uint> from) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svcntp_b64(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.D
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<ulong> mask, Vector<ulong> from) { throw new PlatformNotSupportedException(); }


/// Count leading sign bits

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,69 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) => FusedMultiplySubtractNegated(minuend, left, right);


/// Count set predicate bits

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<byte> mask, Vector<byte> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<double> mask, Vector<double> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<short> mask, Vector<short> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<int> mask, Vector<int> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<long> mask, Vector<long> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<sbyte> mask, Vector<sbyte> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b8(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.B
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<float> mask, Vector<float> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b16(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.H
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<ushort> mask, Vector<ushort> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b32(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.S
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<uint> mask, Vector<uint> from) => GetActiveElementCount(mask, from);

/// <summary>
/// uint64_t svcntp_b64(svbool_t pg, svbool_t op)
/// CNTP Xresult, Pg, Pop.D
/// </summary>
public static unsafe ulong GetActiveElementCount(Vector<ulong> mask, Vector<ulong> from) => GetActiveElementCount(mask, from);


/// LeadingSignCount : Count leading sign bits

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4365,6 +4365,17 @@ internal Arm64() { }
public static System.Numerics.Vector<double> FusedMultiplySubtractNegated(System.Numerics.Vector<double> minuend, System.Numerics.Vector<double> left, System.Numerics.Vector<double> right) { throw null; }
public static System.Numerics.Vector<float> FusedMultiplySubtractNegated(System.Numerics.Vector<float> minuend, System.Numerics.Vector<float> left, System.Numerics.Vector<float> right) { throw null; }

public static ulong GetActiveElementCount(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<double> mask, System.Numerics.Vector<double> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<short> mask, System.Numerics.Vector<short> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<int> mask, System.Numerics.Vector<int> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<long> mask, System.Numerics.Vector<long> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<sbyte> mask, System.Numerics.Vector<sbyte> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<float> mask, System.Numerics.Vector<float> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> from) { throw null; }
public static ulong GetActiveElementCount(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> from) { throw null; }

public static System.Numerics.Vector<byte> LeadingSignCount(System.Numerics.Vector<sbyte> value) { throw null; }
public static System.Numerics.Vector<ushort> LeadingSignCount(System.Numerics.Vector<short> value) { throw null; }
public static System.Numerics.Vector<uint> LeadingSignCount(System.Numerics.Vector<int> value) { throw null; }
Expand Down
Loading