forked from OSchip/llvm-project
[GlobalISel] Handle constant splat in funnel shift combine
This change adds the constant splat versions of m_ICst() (by using getBuildVectorConstantSplat()) and uses it in matchOrShiftToFunnelShift(). The getBuildVectorConstantSplat() name is shortened to getIConstantSplatVal() so that the *SExtVal() version would have a more compact name. Differential Revision: https://reviews.llvm.org/D125516
This commit is contained in:
parent
f96d20450c
commit
485dd0b752
|
@ -94,6 +94,48 @@ inline ConstantMatch<int64_t> m_ICst(int64_t &Cst) {
|
|||
return ConstantMatch<int64_t>(Cst);
|
||||
}
|
||||
|
||||
template <typename ConstT>
|
||||
inline Optional<ConstT> matchConstantSplat(Register,
|
||||
const MachineRegisterInfo &);
|
||||
|
||||
template <>
|
||||
inline Optional<APInt> matchConstantSplat(Register Reg,
|
||||
const MachineRegisterInfo &MRI) {
|
||||
return getIConstantSplatVal(Reg, MRI);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Optional<int64_t> matchConstantSplat(Register Reg,
|
||||
const MachineRegisterInfo &MRI) {
|
||||
return getIConstantSplatSExtVal(Reg, MRI);
|
||||
}
|
||||
|
||||
template <typename ConstT> struct ICstOrSplatMatch {
|
||||
ConstT &CR;
|
||||
ICstOrSplatMatch(ConstT &C) : CR(C) {}
|
||||
bool match(const MachineRegisterInfo &MRI, Register Reg) {
|
||||
if (auto MaybeCst = matchConstant<ConstT>(Reg, MRI)) {
|
||||
CR = *MaybeCst;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto MaybeCstSplat = matchConstantSplat<ConstT>(Reg, MRI)) {
|
||||
CR = *MaybeCstSplat;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
};
|
||||
|
||||
/// Match a scalar constant or a constant splat, yielding the value as APInt.
inline ICstOrSplatMatch<APInt> m_ICstOrSplat(APInt &Cst) {
  return ICstOrSplatMatch<APInt>(Cst);
}
|
||||
|
||||
/// Match a scalar constant or a constant splat, yielding the sign-extended
/// value as int64_t.
inline ICstOrSplatMatch<int64_t> m_ICstOrSplat(int64_t &Cst) {
  return ICstOrSplatMatch<int64_t>(Cst);
}
|
||||
|
||||
struct GCstAndRegMatch {
|
||||
Optional<ValueAndVReg> &ValReg;
|
||||
GCstAndRegMatch(Optional<ValueAndVReg> &ValReg) : ValReg(ValReg) {}
|
||||
|
|
|
@ -373,9 +373,23 @@ public:
|
|||
/// If \p MI is not a splat, returns None.
|
||||
Optional<int> getSplatIndex(MachineInstr &MI);
|
||||
|
||||
/// Returns a scalar constant of a G_BUILD_VECTOR splat if it exists.
|
||||
Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
|
||||
const MachineRegisterInfo &MRI);
|
||||
/// \returns the scalar integral splat value of \p Reg if possible.
|
||||
Optional<APInt> getIConstantSplatVal(const Register Reg,
|
||||
const MachineRegisterInfo &MRI);
|
||||
|
||||
/// \returns the scalar integral splat value defined by \p MI if possible.
|
||||
Optional<APInt> getIConstantSplatVal(const MachineInstr &MI,
|
||||
const MachineRegisterInfo &MRI);
|
||||
|
||||
/// \returns the scalar sign extended integral splat value of \p Reg if
|
||||
/// possible.
|
||||
Optional<int64_t> getIConstantSplatSExtVal(const Register Reg,
|
||||
const MachineRegisterInfo &MRI);
|
||||
|
||||
/// \returns the scalar sign extended integral splat value defined by \p MI if
|
||||
/// possible.
|
||||
Optional<int64_t> getIConstantSplatSExtVal(const MachineInstr &MI,
|
||||
const MachineRegisterInfo &MRI);
|
||||
|
||||
/// Returns a floating point scalar constant of a build vector splat if it
|
||||
/// exists. When \p AllowUndef == true some elements can be undef but not all.
|
||||
|
|
|
@ -2945,7 +2945,7 @@ bool CombinerHelper::matchNotCmp(MachineInstr &MI,
|
|||
int64_t Cst;
|
||||
if (Ty.isVector()) {
|
||||
MachineInstr *CstDef = MRI.getVRegDef(CstReg);
|
||||
auto MaybeCst = getBuildVectorConstantSplat(*CstDef, MRI);
|
||||
auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
|
||||
if (!MaybeCst)
|
||||
return false;
|
||||
if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
|
||||
|
@ -4029,10 +4029,9 @@ bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
|
|||
|
||||
// Given constants C0 and C1 such that C0 + C1 is bit-width:
|
||||
// (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
|
||||
// TODO: Match constant splat.
|
||||
int64_t CstShlAmt, CstLShrAmt;
|
||||
if (mi_match(ShlAmt, MRI, m_ICst(CstShlAmt)) &&
|
||||
mi_match(LShrAmt, MRI, m_ICst(CstLShrAmt)) &&
|
||||
if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
|
||||
mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
|
||||
CstShlAmt + CstLShrAmt == BitWidth) {
|
||||
FshOpc = TargetOpcode::G_FSHR;
|
||||
Amt = LShrAmt;
|
||||
|
|
|
@ -1071,15 +1071,38 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
|
|||
AllowUndef);
|
||||
}
|
||||
|
||||
/// \returns the scalar integral splat value of \p Reg if possible, i.e. the
/// constant shared by every lane of the build-vector-style splat defining
/// \p Reg. Undef lanes are rejected (AllowUndef is false).
Optional<APInt> llvm::getIConstantSplatVal(const Register Reg,
                                           const MachineRegisterInfo &MRI) {
  if (auto SplatValAndReg =
          getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false)) {
    // Guard the lookup: the original dereferenced the Optional without
    // checking, which would crash if the splatted vreg is not an integer
    // constant reachable through copies.
    if (Optional<ValueAndVReg> ValAndVReg =
            getIConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI))
      return ValAndVReg->Value;
  }

  return None;
}
|
||||
|
||||
/// \returns the scalar integral splat value defined by \p MI if possible.
/// Operand 0 is the destination register of the splat.
//
// Fix: the definition was missing the llvm:: qualifier, so it introduced a
// brand-new function in the global namespace instead of defining the
// llvm::getIConstantSplatVal declared in the header (unused-function
// warning and unresolved reference for callers of the declared overload).
Optional<APInt> llvm::getIConstantSplatVal(const MachineInstr &MI,
                                           const MachineRegisterInfo &MRI) {
  return getIConstantSplatVal(MI.getOperand(0).getReg(), MRI);
}
|
||||
|
||||
/// \returns the scalar sign-extended integral splat value of \p Reg if
/// possible. Undef lanes are rejected.
Optional<int64_t>
llvm::getIConstantSplatSExtVal(const Register Reg,
                               const MachineRegisterInfo &MRI) {
  auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false);
  if (!SplatValAndReg)
    return None;
  return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
}
|
||||
|
||||
/// \returns the scalar sign-extended integral splat value defined by \p MI
/// if possible. Operand 0 is the destination register of the splat.
Optional<int64_t>
llvm::getIConstantSplatSExtVal(const MachineInstr &MI,
                               const MachineRegisterInfo &MRI) {
  return getIConstantSplatSExtVal(MI.getOperand(0).getReg(), MRI);
}
|
||||
|
||||
Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
|
||||
const MachineRegisterInfo &MRI,
|
||||
bool AllowUndef) {
|
||||
|
@ -1105,7 +1128,7 @@ Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
|
|||
unsigned Opc = MI.getOpcode();
|
||||
if (!isBuildVectorOp(Opc))
|
||||
return None;
|
||||
if (auto Splat = getBuildVectorConstantSplat(MI, MRI))
|
||||
if (auto Splat = getIConstantSplatSExtVal(MI, MRI))
|
||||
return RegOrConstant(*Splat);
|
||||
auto Reg = MI.getOperand(1).getReg();
|
||||
if (any_of(make_range(MI.operands_begin() + 2, MI.operands_end()),
|
||||
|
@ -1176,7 +1199,7 @@ llvm::isConstantOrConstantSplatVector(MachineInstr &MI,
|
|||
Register Def = MI.getOperand(0).getReg();
|
||||
if (auto C = getIConstantVRegValWithLookThrough(Def, MRI))
|
||||
return C->Value;
|
||||
auto MaybeCst = getBuildVectorConstantSplat(MI, MRI);
|
||||
auto MaybeCst = getIConstantSplatSExtVal(MI, MRI);
|
||||
if (!MaybeCst)
|
||||
return None;
|
||||
const unsigned ScalarSize = MRI.getType(Def).getScalarSizeInBits();
|
||||
|
|
|
@ -143,13 +143,9 @@ body: |
|
|||
; CHECK-NEXT: {{ $}}
|
||||
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
|
||||
; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
|
||||
; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
|
||||
; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
|
||||
; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
|
||||
; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
|
||||
; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
|
||||
; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %b, %amt1(<2 x s32>)
|
||||
; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
|
||||
; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHR %a, %b, %amt1(<2 x s32>)
|
||||
; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
|
||||
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
|
||||
%b:_(<2 x s32>) = COPY $vgpr2_vgpr3
|
||||
|
|
|
@ -132,13 +132,9 @@ body: |
|
|||
; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3
|
||||
; CHECK-NEXT: {{ $}}
|
||||
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
|
||||
; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
|
||||
; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
|
||||
; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
|
||||
; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
|
||||
; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
|
||||
; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %a, %amt1(<2 x s32>)
|
||||
; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
|
||||
; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTR %a, %amt1(<2 x s32>)
|
||||
; CHECK-NEXT: $vgpr2_vgpr3 = COPY %or(<2 x s32>)
|
||||
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
|
||||
%scalar_amt0:_(s32) = G_CONSTANT i32 20
|
||||
|
|
|
@ -51,6 +51,25 @@ TEST_F(AArch64GISelMITest, MatchIntConstantRegister) {
|
|||
EXPECT_EQ(Src0->VReg, MIBCst.getReg(0));
|
||||
}
|
||||
|
||||
TEST_F(AArch64GISelMITest, MatchIntConstantSplat) {
  setUp();
  if (!TM)
    return;

  LLT s64 = LLT::scalar(64);
  LLT v4s64 = LLT::fixed_vector(4, s64);

  // A splat vector of G_CONSTANT 42 must match and bind 42.
  MachineInstrBuilder FortyTwoSplat =
      B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
  int64_t Cst;
  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
  EXPECT_EQ(Cst, 42);

  // A splat of a non-constant register must not match.
  MachineInstrBuilder NonConstantSplat =
      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
}
|
||||
|
||||
TEST_F(AArch64GISelMITest, MachineInstrPtrBind) {
|
||||
setUp();
|
||||
if (!TM)
|
||||
|
|
Loading…
Reference in New Issue