[RISCV][RVV] Select unmasked TU RVV pseudos in a DAG post-process

Following D118810 that reduced the size of ISel table,
this patch optimizes allone-masked RVV pseudos with TU policy and
swap them out to their unmasked TU pseudos.

Since the UNDEF merge operand is not preserved, we turn it into TA
pseudo regardless of the policy operand.

Reviewed By: craig.topper, frasercrmck
Differential Revision: https://reviews.llvm.org/D121881
This commit is contained in:
ShihPo Hung 2022-03-21 19:58:46 -07:00
parent bcb2b86df6
commit 6b55f133fb
4 changed files with 45 additions and 29 deletions

View File

@ -2265,33 +2265,52 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode()); const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode());
bool IsTA = true;
if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) { if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
// The last operand of the pseudo is the policy op, but we're expecting a // The last operand of the pseudo is the policy op, but we might have a
// Glue operand last. We may also have a chain. // Glue operand last. We might also have a chain.
TailPolicyOpIdx = N->getNumOperands() - 1; TailPolicyOpIdx = N->getNumOperands() - 1;
if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue) if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue)
(*TailPolicyOpIdx)--; (*TailPolicyOpIdx)--;
if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other) if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other)
(*TailPolicyOpIdx)--; (*TailPolicyOpIdx)--;
// If the policy isn't TAIL_AGNOSTIC we can't perform this optimization. if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
if (N->getConstantOperandVal(*TailPolicyOpIdx) != RISCVII::TAIL_AGNOSTIC) RISCVII::TAIL_AGNOSTIC)) {
// Keep the true-masked instruction when there is no unmasked TU
// instruction
if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef())
return false; return false;
// We can't use TA if the tie-operand is not IMPLICIT_DEF
if (!N->getOperand(0).isUndef())
IsTA = false;
}
} }
const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo); if (IsTA) {
uint64_t TSFlags = TII->get(I->UnmaskedPseudo).TSFlags;
// Check that we're dropping the merge operand, the mask operand, and any // Check that we're dropping the merge operand, the mask operand, and any
// policy operand when we transform to this unmasked pseudo. // policy operand when we transform to this unmasked pseudo.
assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) && assert(!RISCVII::hasMergeOp(TSFlags) && RISCVII::hasDummyMaskOp(TSFlags) &&
RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) && !RISCVII::hasVecPolicyOp(TSFlags) &&
!RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) &&
"Unexpected pseudo to transform to"); "Unexpected pseudo to transform to");
(void)UnmaskedMCID; (void)TSFlags;
} else {
uint64_t TSFlags = TII->get(I->UnmaskedTUPseudo).TSFlags;
// Check that we're dropping the mask operand, and any policy operand
// when we transform to this unmasked tu pseudo.
assert(RISCVII::hasMergeOp(TSFlags) && RISCVII::hasDummyMaskOp(TSFlags) &&
!RISCVII::hasVecPolicyOp(TSFlags) &&
"Unexpected pseudo to transform to");
(void)TSFlags;
}
unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo;
SmallVector<SDValue, 8> Ops; SmallVector<SDValue, 8> Ops;
// Skip the merge operand at index 0. // Skip the merge operand at index 0 if IsTA
for (unsigned I = 1, E = N->getNumOperands(); I != E; I++) { for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) {
// Skip the mask, the policy, and the Glue. // Skip the mask, the policy, and the Glue.
SDValue Op = N->getOperand(I); SDValue Op = N->getOperand(I);
if (I == MaskOpIdx || I == TailPolicyOpIdx || if (I == MaskOpIdx || I == TailPolicyOpIdx ||
@ -2304,8 +2323,7 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
if (auto *TGlued = Glued->getGluedNode()) if (auto *TGlued = Glued->getGluedNode())
Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1)); Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));
SDNode *Result = SDNode *Result = CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);
CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops);
ReplaceUses(N, Result); ReplaceUses(N, Result);
return true; return true;

View File

@ -191,6 +191,7 @@ struct VLX_VSXPseudo {
struct RISCVMaskedPseudoInfo { struct RISCVMaskedPseudoInfo {
uint16_t MaskedPseudo; uint16_t MaskedPseudo;
uint16_t UnmaskedPseudo; uint16_t UnmaskedPseudo;
uint16_t UnmaskedTUPseudo;
uint8_t MaskOpIdx; uint8_t MaskOpIdx;
}; };

View File

@ -423,16 +423,17 @@ def RISCVVIntrinsicsTable : GenericTable {
let PrimaryKeyName = "getRISCVVIntrinsicInfo"; let PrimaryKeyName = "getRISCVVIntrinsicInfo";
} }
class RISCVMaskedPseudo<bits<4> MaskIdx> { class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true> {
Pseudo MaskedPseudo = !cast<Pseudo>(NAME); Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME)); Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
Pseudo UnmaskedTUPseudo = !if(HasTU, !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo);
bits<4> MaskOpIdx = MaskIdx; bits<4> MaskOpIdx = MaskIdx;
} }
def RISCVMaskedPseudosTable : GenericTable { def RISCVMaskedPseudosTable : GenericTable {
let FilterClass = "RISCVMaskedPseudo"; let FilterClass = "RISCVMaskedPseudo";
let CppTypeName = "RISCVMaskedPseudoInfo"; let CppTypeName = "RISCVMaskedPseudoInfo";
let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"]; let Fields = ["MaskedPseudo", "UnmaskedPseudo", "UnmaskedTUPseudo", "MaskOpIdx"];
let PrimaryKey = ["MaskedPseudo"]; let PrimaryKey = ["MaskedPseudo"];
let PrimaryKeyName = "getMaskedPseudoInfo"; let PrimaryKeyName = "getMaskedPseudoInfo";
} }
@ -1770,7 +1771,7 @@ multiclass VPseudoBinaryM<VReg RetClass,
let ForceTailAgnostic = true in let ForceTailAgnostic = true in
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMOutMask<RetClass, Op1Class, def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMOutMask<RetClass, Op1Class,
Op2Class, Constraint>, Op2Class, Constraint>,
RISCVMaskedPseudo</*MaskOpIdx*/ 3>; RISCVMaskedPseudo</*MaskOpIdx*/ 3, /*HasTU*/ false>;
} }
} }

View File

@ -31,14 +31,12 @@ entry:
ret <vscale x 1 x i8> %a ret <vscale x 1 x i8> %a
} }
; FIXME: Use an unmasked TAIL_AGNOSTIC instruction if the tie operand is IMPLICIT_DEF ; Use an unmasked TAIL_AGNOSTIC instruction if the tie operand is IMPLICIT_DEF
define <vscale x 1 x i8> @test1(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, iXLen %2) nounwind { define <vscale x 1 x i8> @test1(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, iXLen %2) nounwind {
; CHECK-LABEL: test1: ; CHECK-LABEL: test1:
; CHECK: # %bb.0: # %entry ; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu ; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu
; CHECK-NEXT: vmset.m v0 ; CHECK-NEXT: vadd.vv v8, v8, v9
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, tu, mu
; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t
; CHECK-NEXT: ret ; CHECK-NEXT: ret
entry: entry:
%allone = call <vscale x 1 x i1> @llvm.riscv.vmset.nxv1i1( %allone = call <vscale x 1 x i1> @llvm.riscv.vmset.nxv1i1(
@ -53,14 +51,12 @@ entry:
ret <vscale x 1 x i8> %a ret <vscale x 1 x i8> %a
} }
; FIXME: Use an unmasked TU instruction because of the policy operand ; Use an unmasked TU instruction because of the policy operand
define <vscale x 1 x i8> @test2(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, <vscale x 1 x i8> %2, iXLen %3) nounwind { define <vscale x 1 x i8> @test2(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, <vscale x 1 x i8> %2, iXLen %3) nounwind {
; CHECK-LABEL: test2: ; CHECK-LABEL: test2:
; CHECK: # %bb.0: # %entry ; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu ; CHECK-NEXT: vsetvli zero, a0, e8, mf8, tu, mu
; CHECK-NEXT: vmset.m v0 ; CHECK-NEXT: vadd.vv v8, v9, v10
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, tu, mu
; CHECK-NEXT: vadd.vv v8, v9, v10, v0.t
; CHECK-NEXT: ret ; CHECK-NEXT: ret
entry: entry:
%allone = call <vscale x 1 x i1> @llvm.riscv.vmset.nxv1i1( %allone = call <vscale x 1 x i1> @llvm.riscv.vmset.nxv1i1(