[GlobalISel] Handle constant splat in funnel shift combine

This change adds the constant splat versions of m_ICst() (by using
getBuildVectorConstantSplat()) and uses it in
matchOrShiftToFunnelShift(). The getBuildVectorConstantSplat() name is
shortened to getIConstantSplatVal() so that the *SExtVal() version would
have a more compact name.

Differential Revision: https://reviews.llvm.org/D125516
This commit is contained in:
Abinav Puthan Purayil 2022-05-12 22:35:52 +05:30
parent f96d20450c
commit 485dd0b752
7 changed files with 112 additions and 23 deletions

View File

@ -94,6 +94,48 @@ inline ConstantMatch<int64_t> m_ICst(int64_t &Cst) {
return ConstantMatch<int64_t>(Cst);
}
template <typename ConstT>
inline Optional<ConstT> matchConstantSplat(Register,
const MachineRegisterInfo &);
template <>
inline Optional<APInt> matchConstantSplat(Register Reg,
const MachineRegisterInfo &MRI) {
return getIConstantSplatVal(Reg, MRI);
}
template <>
inline Optional<int64_t> matchConstantSplat(Register Reg,
const MachineRegisterInfo &MRI) {
return getIConstantSplatSExtVal(Reg, MRI);
}
template <typename ConstT> struct ICstOrSplatMatch {
ConstT &CR;
ICstOrSplatMatch(ConstT &C) : CR(C) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
if (auto MaybeCst = matchConstant<ConstT>(Reg, MRI)) {
CR = *MaybeCst;
return true;
}
if (auto MaybeCstSplat = matchConstantSplat<ConstT>(Reg, MRI)) {
CR = *MaybeCstSplat;
return true;
}
return false;
};
};
inline ICstOrSplatMatch<APInt> m_ICstOrSplat(APInt &Cst) {
return ICstOrSplatMatch<APInt>(Cst);
}
inline ICstOrSplatMatch<int64_t> m_ICstOrSplat(int64_t &Cst) {
return ICstOrSplatMatch<int64_t>(Cst);
}
struct GCstAndRegMatch {
Optional<ValueAndVReg> &ValReg;
GCstAndRegMatch(Optional<ValueAndVReg> &ValReg) : ValReg(ValReg) {}

View File

@ -373,9 +373,23 @@ public:
/// If \p MI is not a splat, returns None.
Optional<int> getSplatIndex(MachineInstr &MI);
/// Returns a scalar constant of a G_BUILD_VECTOR splat if it exists.
Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI);
/// \returns the scalar integral splat value of \p Reg if possible.
Optional<APInt> getIConstantSplatVal(const Register Reg,
const MachineRegisterInfo &MRI);
/// \returns the scalar integral splat value defined by \p MI if possible.
Optional<APInt> getIConstantSplatVal(const MachineInstr &MI,
const MachineRegisterInfo &MRI);
/// \returns the scalar sign extended integral splat value of \p Reg if
/// possible.
Optional<int64_t> getIConstantSplatSExtVal(const Register Reg,
const MachineRegisterInfo &MRI);
/// \returns the scalar sign extended integral splat value defined by \p MI if
/// possible.
Optional<int64_t> getIConstantSplatSExtVal(const MachineInstr &MI,
const MachineRegisterInfo &MRI);
/// Returns a floating point scalar constant of a build vector splat if it
/// exists. When \p AllowUndef == true some elements can be undef but not all.

View File

@ -2945,7 +2945,7 @@ bool CombinerHelper::matchNotCmp(MachineInstr &MI,
int64_t Cst;
if (Ty.isVector()) {
MachineInstr *CstDef = MRI.getVRegDef(CstReg);
auto MaybeCst = getBuildVectorConstantSplat(*CstDef, MRI);
auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
if (!MaybeCst)
return false;
if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
@ -4029,10 +4029,9 @@ bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
// Given constants C0 and C1 such that C0 + C1 is bit-width:
// (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
// TODO: Match constant splat.
int64_t CstShlAmt, CstLShrAmt;
if (mi_match(ShlAmt, MRI, m_ICst(CstShlAmt)) &&
mi_match(LShrAmt, MRI, m_ICst(CstLShrAmt)) &&
if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
CstShlAmt + CstLShrAmt == BitWidth) {
FshOpc = TargetOpcode::G_FSHR;
Amt = LShrAmt;

View File

@ -1071,15 +1071,38 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
AllowUndef);
}
Optional<int64_t>
llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI) {
Optional<APInt> llvm::getIConstantSplatVal(const Register Reg,
const MachineRegisterInfo &MRI) {
if (auto SplatValAndReg =
getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false)) {
Optional<ValueAndVReg> ValAndVReg =
getIConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI);
return ValAndVReg->Value;
}
return None;
}
Optional<APInt> getIConstantSplatVal(const MachineInstr &MI,
const MachineRegisterInfo &MRI) {
return getIConstantSplatVal(MI.getOperand(0).getReg(), MRI);
}
Optional<int64_t>
llvm::getIConstantSplatSExtVal(const Register Reg,
const MachineRegisterInfo &MRI) {
if (auto SplatValAndReg =
getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false))
return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
return None;
}
Optional<int64_t>
llvm::getIConstantSplatSExtVal(const MachineInstr &MI,
const MachineRegisterInfo &MRI) {
return getIConstantSplatSExtVal(MI.getOperand(0).getReg(), MRI);
}
Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
const MachineRegisterInfo &MRI,
bool AllowUndef) {
@ -1105,7 +1128,7 @@ Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
unsigned Opc = MI.getOpcode();
if (!isBuildVectorOp(Opc))
return None;
if (auto Splat = getBuildVectorConstantSplat(MI, MRI))
if (auto Splat = getIConstantSplatSExtVal(MI, MRI))
return RegOrConstant(*Splat);
auto Reg = MI.getOperand(1).getReg();
if (any_of(make_range(MI.operands_begin() + 2, MI.operands_end()),
@ -1176,7 +1199,7 @@ llvm::isConstantOrConstantSplatVector(MachineInstr &MI,
Register Def = MI.getOperand(0).getReg();
if (auto C = getIConstantVRegValWithLookThrough(Def, MRI))
return C->Value;
auto MaybeCst = getBuildVectorConstantSplat(MI, MRI);
auto MaybeCst = getIConstantSplatSExtVal(MI, MRI);
if (!MaybeCst)
return None;
const unsigned ScalarSize = MRI.getType(Def).getScalarSizeInBits();

View File

@ -143,13 +143,9 @@ body: |
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %b, %amt1(<2 x s32>)
; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHR %a, %b, %amt1(<2 x s32>)
; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
%b:_(<2 x s32>) = COPY $vgpr2_vgpr3

View File

@ -132,13 +132,9 @@ body: |
; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %a, %amt1(<2 x s32>)
; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTR %a, %amt1(<2 x s32>)
; CHECK-NEXT: $vgpr2_vgpr3 = COPY %or(<2 x s32>)
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
%scalar_amt0:_(s32) = G_CONSTANT i32 20

View File

@ -51,6 +51,25 @@ TEST_F(AArch64GISelMITest, MatchIntConstantRegister) {
EXPECT_EQ(Src0->VReg, MIBCst.getReg(0));
}
TEST_F(AArch64GISelMITest, MatchIntConstantSplat) {
setUp();
if (!TM)
return;
LLT s64 = LLT::scalar(64);
LLT v4s64 = LLT::fixed_vector(4, s64);
MachineInstrBuilder FortyTwoSplat =
B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
int64_t Cst;
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
EXPECT_EQ(Cst, 42);
MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
}
TEST_F(AArch64GISelMITest, MachineInstrPtrBind) {
setUp();
if (!TM)