diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b25cb128bce9f..f8a5ccc3023a4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16229,6 +16229,68 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //   vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+
+  // Check if its first operand is a vp.load.
+  auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+  if (!VPLoad)
+    return SDValue();
+
+  EVT LoadVT = VPLoad->getValueType(0);
+  // We do not have a strided_load version for masks, and the evl of vp.reverse
+  // and vp.load should always be the same.
+  if (!LoadVT.getVectorElementType().isByteSized() ||
+      N->getOperand(2) != VPLoad->getVectorLength() ||
+      !N->getOperand(0).hasOneUse())
+    return SDValue();
+
+  // Check that the mask of the outer vp.reverse is all ones.
+  if (!isOneOrOneSplat(N->getOperand(1)))
+    return SDValue();
+
+  SDValue LoadMask = VPLoad->getMask();
+  // If Mask is all ones, then load is unmasked and can be reversed.
+  if (!isOneOrOneSplat(LoadMask)) {
+    // If the mask is not all ones, we can reverse the load if the mask was
+    // also reversed by an unmasked vp.reverse with the same EVL.
+    if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+        LoadMask.getOperand(2) != VPLoad->getVectorLength())
+      return SDValue();
+    LoadMask = LoadMask.getOperand(0);
+  }
+
+  // Base = LoadAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPLoad->getVectorLength();
+  uint64_t ElemWidthByte = VPLoad->getValueType(0).getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(-ElemWidthByte, DL, XLenVT);
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPLoad->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPLoad->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPLoad->getAlign());
+
+  SDValue Ret = DAG.getStridedLoadVP(
+      LoadVT, DL, VPLoad->getChain(), Base, Stride, LoadMask,
+      VPLoad->getVectorLength(), MMO, VPLoad->isExpandingLoad());
+
+  DAG.ReplaceAllUsesOfValueWith(SDValue(VPLoad, 1), Ret.getValue(1));
+
+  return Ret;
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18434,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
   }
+  case ISD::EXPERIMENTAL_VP_REVERSE:
+    return performVP_REVERSECombine(N, DAG, Subtarget);
   case ISD::BITCAST: {
     assert(Subtarget.useRVVForFixedLengthVectors());
     SDValue N0 = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
new file mode 100644
index 0000000000000..50e26bd141070
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 2 x float> @test_reverse_load_combiner(<vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_reverse_load_combiner:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_is_vp_reverse(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_is_vp_reverse:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_not_all_one(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_not_all_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    vid.v v8, v0.t
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vrsub.vx v10, v8, a1, v0.t
+; CHECK-NEXT:    vrgather.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %notallones, i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_different_evl(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, a1, -1
+; CHECK-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.i v9, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vrsub.vx v8, v8, a3
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
+; CHECK-NEXT:    vrgatherei16.vv v10, v9, v8
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    addi a2, a2, -1
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vrsub.vx v10, v8, a2
+; CHECK-NEXT:    vrgather.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl2)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl2)
+  ret <vscale x 2 x float> %rev
+}
+
+declare <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
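
Note (editorial, not part of the patch): the address arithmetic the combine emits, Base = ADDR + (EVL - 1) * ElemWidthByte with Stride = -ElemWidthByte, can be sanity-checked against the scalar model below. This is a minimal illustrative sketch assuming an all-ones mask and a 4-byte (float) element type; the names reversedLoad, Addr, and EVL are placeholders, not identifiers from the patch.

#include <vector>

// Models what the combined vp.strided.load with a negative stride reads for
// each active lane when the mask is all ones and the element type is float.
std::vector<float> reversedLoad(const float *Addr, unsigned EVL) {
  std::vector<float> Result(EVL);
  if (EVL == 0)
    return Result; // nothing active; avoid forming Addr - 1
  // Base points at the last active element, matching the combine's Base.
  const float *Base = Addr + (EVL - 1);
  for (unsigned I = 0; I < EVL; ++I)
    Result[I] = *(Base - I); // step back one element (-4 bytes) per lane
  return Result;
}

Lane I of Result is memory element EVL - 1 - I, which is exactly what vp.reverse(vp.load(Addr)) produces for the same EVL; this is why the fold only needs the two EVLs (and, when the load is masked, a matching reversed mask) to line up.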