Skip to content

[llvm][AArch64][Assembly]: Add FP8FMA assembly and disassembly. #70134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ enum ArchExtKind : unsigned {
AEK_FPMR = 58, // FEAT_FPMR
AEK_FP8 = 59, // FEAT_FP8
AEK_FAMINMAX = 60, // FEAT_FAMINMAX
AEK_FP8FMA = 61, // FEAT_FP8FMA
AEK_SSVE_FP8FMA = 62, // FEAT_SSVE_FP8FMA
AEK_NUM_EXTENSIONS
};
using ExtensionBitset = Bitset<AEK_NUM_EXTENSIONS>;
Expand Down Expand Up @@ -273,6 +275,8 @@ inline constexpr ExtensionInfo Extensions[] = {
{"fpmr", AArch64::AEK_FPMR, "+fpmr", "-fpmr", FEAT_INIT, "", 0},
{"fp8", AArch64::AEK_FP8, "+fp8", "-fp8", FEAT_INIT, "+fpmr", 0},
{"faminmax", AArch64::AEK_FAMINMAX, "+faminmax", "-faminmax", FEAT_INIT, "", 0},
{"fp8fma", AArch64::AEK_FP8FMA, "+fp8fma", "-fp8fma", FEAT_INIT, "+fpmr", 0},
{"ssve-fp8fma", AArch64::AEK_SSVE_FP8FMA, "+ssve-fp8fma", "-ssve-fp8fma", FEAT_INIT, "+sme2", 0},
// Special cases
{"none", AArch64::AEK_NONE, {}, {}, FEAT_INIT, "", ExtensionInfo::MaxFMVPriority},
};
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/TargetParser/SubtargetFeature.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace llvm {
class raw_ostream;
class Triple;

const unsigned MAX_SUBTARGET_WORDS = 4;
const unsigned MAX_SUBTARGET_WORDS = 5;
const unsigned MAX_SUBTARGET_FEATURES = MAX_SUBTARGET_WORDS * 64;

/// Container class for subtarget features.
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/AArch64/AArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,12 @@ def FeatureSME2p1 : SubtargetFeature<"sme2p1", "HasSME2p1", "true",
def FeatureFAMINMAX: SubtargetFeature<"faminmax", "HasFAMINMAX", "true",
"Enable FAMIN and FAMAX instructions (FEAT_FAMINMAX)">;

def FeatureFP8FMA : SubtargetFeature<"fp8fma", "HasFP8FMA", "true",
"Enable fp8 multiply-add instructions (FEAT_FP8FMA)">;

def FeatureSSVE_FP8FMA : SubtargetFeature<"ssve-fp8fma", "HasSSVE_FP8FMA", "true",
"Enable SVE2 fp8 multiply-add instructions (FEAT_SSVE_FP8FMA)", [FeatureSME2]>;

def FeatureAppleA7SysReg : SubtargetFeature<"apple-a7-sysreg", "HasAppleA7SysReg", "true",
"Apple A7 (the CPU formerly known as Cyclone)">;

Expand Down Expand Up @@ -747,7 +753,7 @@ let F = [HasSVE2p1, HasSVE2p1_or_HasSME2, HasSVE2p1_or_HasSME2p1] in
def SVE2p1Unsupported : AArch64Unsupported;

def SVE2Unsupported : AArch64Unsupported {
let F = !listconcat([HasSVE2, HasSVE2orSME,
let F = !listconcat([HasSVE2, HasSVE2orSME, HasSSVE_FP8FMA,
HasSVE2AES, HasSVE2SHA3, HasSVE2SM4, HasSVE2BitPerm],
SVE2p1Unsupported.F);
}
Expand All @@ -761,7 +767,7 @@ let F = [HasSME2p1, HasSVE2p1_or_HasSME2p1] in
def SME2p1Unsupported : AArch64Unsupported;

def SME2Unsupported : AArch64Unsupported {
let F = !listconcat([HasSME2, HasSVE2p1_or_HasSME2],
let F = !listconcat([HasSME2, HasSVE2p1_or_HasSME2, HasSSVE_FP8FMA],
SME2p1Unsupported.F);
}

Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -6055,6 +6055,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
v4f32, v8f16, OpNode>;
}

multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for using multiclasses if they contain a single record each?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you are correct.
But in my view, it's more clear to use a multiclass in this case because the class is called Dot while the multiclass instance is for MLA, so it's more clear.
What do you think ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use the class BaseSIMDThreeSameVectorDot when defining , but she is following the pattern of the other instructions(despite the fact they have more than one definition).
Also we can hide the size of the instructions in the multiclass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to define a new class which inherits from BaseSIMDThreeSameVectorDot and handles the common values for all the definitions. For example:

class SIMDThreeSameVectorMLA<bit Q, string asm>
    : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
                                                              V128, v8f16, v16i8, null_frag>;

Using multiclasses for consistency sounds reasonable, though, if you prefer to keep the current implementation.

def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
V128, v8f16, v16i8, null_frag>;
}

multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
V128, v4f32, v16i8, null_frag>;
}

// FP8 assembly/disassembly classes

Expand Down Expand Up @@ -8521,6 +8530,31 @@ class BF16ToSinglePrecision<string asm>
}
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0

//----------------------------------------------------------------------------
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
string asm, string dst_kind,
RegisterOperand RegType,
RegisterOperand RegType_lo>
: BaseSIMDIndexedTied<Q, U, 0b0, sz, opc,
RegType, RegType, RegType_lo, VectorIndexB,
asm, "", dst_kind, ".16b", ".b", []> {

// idx = H:L:M
bits<4> idx;
let Inst{11} = idx{3};
let Inst{21-19} = idx{2-0};
}

multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm> {
def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h",
V128, V128_0to7>;
}

multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm> {
def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s",
V128, V128_0to7>;
}

//----------------------------------------------------------------------------
// Armv8.6 Matrix Multiply Extension
//----------------------------------------------------------------------------
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def HasFP8 : Predicate<"Subtarget->hasFP8()">,
AssemblerPredicateWithAll<(all_of FeatureFP8), "fp8">;
def HasFAMINMAX : Predicate<"Subtarget->hasFAMINMAX()">,
AssemblerPredicateWithAll<(all_of FeatureFAMINMAX), "faminmax">;
def HasFP8FMA : Predicate<"Subtarget->hasFP8FMA()">,
AssemblerPredicateWithAll<(all_of FeatureFP8FMA), "fp8fma">;
def HasSSVE_FP8FMA : Predicate<"Subtarget->SSVE_FP8FMA() || "
"(Subtarget->hasSVE2() && Subtarget->hasFP8FMA())">,
AssemblerPredicateWithAll<(any_of FeatureSSVE_FP8FMA,
(all_of FeatureSVE2, FeatureFP8FMA)),
"ssve-fp8fma or (sve2 and fp8fma)">;

// A subset of SVE(2) instructions are legal in Streaming SVE execution mode,
// they should be enabled if either has been specified.
Expand Down Expand Up @@ -9283,6 +9290,21 @@ let Predicates = [HasFAMINMAX] in {
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", null_frag>;
} // End let Predicates = [HasFAMAXMIN]

let Predicates = [HasFP8FMA] in {
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;

defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
} // End let Predicates = [HasFP8FMA]

include "AArch64InstrAtomics.td"
include "AArch64SVEInstrInfo.td"
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,8 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
case AArch64::FPR64_loRegClassID:
case AArch64::FPR16_loRegClassID:
return 16;
case AArch64::FPR128_0to7RegClassID:
return 8;
}
}

Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,13 @@ def FPR128_lo : RegisterClass<"AArch64",
v8bf16],
128, (trunc FPR128, 16)>;

// The lower 8 vector registers. Some instructions can only take registers
// in this range.
def FPR128_0to7 : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
v8bf16],
128, (trunc FPR128, 8)>;

// Pairs, triples, and quads of 64-bit vector registers.
def DSeqPairs : RegisterTuples<[dsub0, dsub1], [(rotl FPR64, 0), (rotl FPR64, 1)]>;
def DSeqTriples : RegisterTuples<[dsub0, dsub1, dsub2],
Expand Down Expand Up @@ -534,6 +541,15 @@ def V128_lo : RegisterOperand<FPR128_lo, "printVRegOperand"> {
let ParserMatchClass = VectorRegLoAsmOperand;
}

def VectorReg0to7AsmOperand : AsmOperandClass {
let Name = "VectorReg0to7";
let PredicateMethod = "isNeonVectorReg0to7";
}

def V128_0to7 : RegisterOperand<FPR128_0to7, "printVRegOperand"> {
let ParserMatchClass = VectorReg0to7AsmOperand;
}

class TypedVecListAsmOperand<int count, string vecty, int lanes, int eltsize>
: AsmOperandClass {
let Name = "TypedVectorList" # count # "_" # lanes # eltsize;
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -4029,3 +4029,22 @@ let Predicates = [HasSVE2orSME2, HasFAMINMAX] in {
defm FAMIN_ZPmZ : sve_fp_2op_p_zds<0b1111, "famin", "", null_frag, DestructiveOther>;
defm FAMAX_ZPmZ : sve_fp_2op_p_zds<0b1110, "famax", "", null_frag, DestructiveOther>;
} // End HasSVE2orSME2, HasFAMINMAX

let Predicates = [HasSSVE_FP8FMA] in {
// FP8 Widening Multiply-Add Long - Indexed Group
def FMLALB_ZZZI : sve2_fp8_mla_long_by_indexed_elem<0b0, "fmlalb">;
def FMLALT_ZZZI : sve2_fp8_mla_long_by_indexed_elem<0b1, "fmlalt">;
// FP8 Widening Multiply-Add Long Group
def FMLALB_ZZZ : sve2_fp8_mla<0b100, ZPR16, "fmlalb">;
def FMLALT_ZZZ : sve2_fp8_mla<0b101, ZPR16, "fmlalt">;
// FP8 Widening Multiply-Add Long Long - Indexed Group
def FMLALLBB_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b00, "fmlallbb">;
def FMLALLBT_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b01, "fmlallbt">;
def FMLALLTB_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b10, "fmlalltb">;
def FMLALLTT_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b11, "fmlalltt">;
// FP8 Widening Multiply-Add Long Long Group
def FMLALLBB_ZZZ : sve2_fp8_mla<0b000, ZPR32, "fmlallbb">;
def FMLALLBT_ZZZ : sve2_fp8_mla<0b001, ZPR32, "fmlallbt">;
def FMLALLTB_ZZZ : sve2_fp8_mla<0b010, ZPR32, "fmlalltb">;
def FMLALLTT_ZZZ : sve2_fp8_mla<0b011, ZPR32, "fmlalltt">;
} // End HasSSVE_FP8FMA
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64SchedA64FX.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def A64FXModel : SchedMachineModel {
list<Predicate> UnsupportedFeatures =
[HasSVE2, HasSVE2AES, HasSVE2SM4, HasSVE2SHA3, HasSVE2BitPerm, HasPAuth,
HasSVE2orSME, HasMTE, HasMatMulInt8, HasBF16, HasSME2, HasSME2p1, HasSVE2p1,
HasSVE2p1_or_HasSME2p1, HasSMEF16F16];
HasSVE2p1_or_HasSME2p1, HasSMEF16F16, HasSSVE_FP8FMA];

let FullInstRWOverlapCheck = 0;
}
Expand Down
63 changes: 38 additions & 25 deletions llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,12 @@ class AArch64Operand : public MCParsedAsmOperand {
Reg.RegNum));
}

bool isNeonVectorReg0to7() const {
return Kind == k_Register && Reg.Kind == RegKind::NeonVector &&
(AArch64MCRegisterClasses[AArch64::FPR128_0to7RegClassID].contains(
Reg.RegNum));
}

bool isMatrix() const { return Kind == k_MatrixRegister; }
bool isMatrixTileList() const { return Kind == k_MatrixTileList; }

Expand Down Expand Up @@ -1766,6 +1772,11 @@ class AArch64Operand : public MCParsedAsmOperand {
Inst.addOperand(MCOperand::createReg(getReg()));
}

void addVectorReg0to7Operands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
Inst.addOperand(MCOperand::createReg(getReg()));
}

enum VecListIndexType {
VecListIdx_DReg = 0,
VecListIdx_QReg = 1,
Expand Down Expand Up @@ -2598,31 +2609,31 @@ static std::optional<std::pair<int, int>> parseVectorKind(StringRef Suffix,

switch (VectorKind) {
case RegKind::NeonVector:
Res =
StringSwitch<std::pair<int, int>>(Suffix.lower())
.Case("", {0, 0})
.Case(".1d", {1, 64})
.Case(".1q", {1, 128})
// '.2h' needed for fp16 scalar pairwise reductions
.Case(".2h", {2, 16})
.Case(".2s", {2, 32})
.Case(".2d", {2, 64})
// '.4b' is another special case for the ARMv8.2a dot product
// operand
.Case(".4b", {4, 8})
.Case(".4h", {4, 16})
.Case(".4s", {4, 32})
.Case(".8b", {8, 8})
.Case(".8h", {8, 16})
.Case(".16b", {16, 8})
// Accept the width neutral ones, too, for verbose syntax. If those
// aren't used in the right places, the token operand won't match so
// all will work out.
.Case(".b", {0, 8})
.Case(".h", {0, 16})
.Case(".s", {0, 32})
.Case(".d", {0, 64})
.Default({-1, -1});
Res = StringSwitch<std::pair<int, int>>(Suffix.lower())
.Case("", {0, 0})
.Case(".1d", {1, 64})
.Case(".1q", {1, 128})
// '.2h' needed for fp16 scalar pairwise reductions
.Case(".2h", {2, 16})
.Case(".2b", {2, 8})
.Case(".2s", {2, 32})
.Case(".2d", {2, 64})
// '.4b' is another special case for the ARMv8.2a dot product
// operand
.Case(".4b", {4, 8})
.Case(".4h", {4, 16})
.Case(".4s", {4, 32})
.Case(".8b", {8, 8})
.Case(".8h", {8, 16})
.Case(".16b", {16, 8})
// Accept the width neutral ones, too, for verbose syntax. If
// those aren't used in the right places, the token operand won't
// match so all will work out.
.Case(".b", {0, 8})
.Case(".h", {0, 16})
.Case(".s", {0, 32})
.Case(".d", {0, 64})
.Default({-1, -1});
break;
case RegKind::SVEPredicateAsCounter:
case RegKind::SVEPredicateVector:
Expand Down Expand Up @@ -3641,6 +3652,8 @@ static const struct Extension {
{"fpmr", {AArch64::FeatureFPMR}},
{"fp8", {AArch64::FeatureFP8}},
{"faminmax", {AArch64::FeatureFAMINMAX}},
{"fp8fma", {AArch64::FeatureFP8FMA}},
{"ssve-fp8fma", {AArch64::FeatureSSVE_FP8FMA}},
};

static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) {
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ static DecodeStatus DecodeFPR128RegisterClass(MCInst &Inst, unsigned RegNo,
static DecodeStatus DecodeFPR128_loRegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Address,
const MCDisassembler *Decoder);
static DecodeStatus
DecodeFPR128_0to7RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address,
const MCDisassembler *Decoder);
static DecodeStatus DecodeFPR64RegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Address,
const MCDisassembler *Decoder);
Expand Down Expand Up @@ -437,6 +440,14 @@ DecodeFPR128_loRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
return DecodeFPR128RegisterClass(Inst, RegNo, Addr, Decoder);
}

static DecodeStatus
DecodeFPR128_0to7RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
const MCDisassembler *Decoder) {
if (RegNo > 7)
return Fail;
return DecodeFPR128RegisterClass(Inst, RegNo, Addr, Decoder);
}

static DecodeStatus DecodeFPR64RegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Addr,
const MCDisassembler *Decoder) {
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case AArch64::FPR32_with_hsub_in_FPR16_loRegClassID:
case AArch64::FPR32RegClassID:
case AArch64::FPR64RegClassID:
case AArch64::FPR64_loRegClassID:
case AArch64::FPR128RegClassID:
case AArch64::FPR64_loRegClassID:
case AArch64::FPR128_loRegClassID:
case AArch64::FPR128_0to7RegClassID:
case AArch64::DDRegClassID:
case AArch64::DDDRegClassID:
case AArch64::DDDDRegClassID:
Expand Down
Loading