diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h index 2b0ef6c3af57..bb5f55789a0e 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h @@ -206,6 +206,14 @@ public: } }; +/// Represents a G_IMPLICIT_DEF. +class GImplicitDef : public GenericMachineInstr { +public: + static bool classof(const MachineInstr *MI) { + return MI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF; + } +}; + } // namespace llvm #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H \ No newline at end of file diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index d8cebee063a4..e813d030eec3 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -99,6 +99,21 @@ inline GFCstAndRegMatch m_GFCst(Optional<FPValueAndVReg> &FPValReg) { return GFCstAndRegMatch(FPValReg); } +struct GFCstOrSplatGFCstMatch { + Optional<FPValueAndVReg> &FPValReg; + GFCstOrSplatGFCstMatch(Optional<FPValueAndVReg> &FPValReg) + : FPValReg(FPValReg) {} + bool match(const MachineRegisterInfo &MRI, Register Reg) { + return (FPValReg = getFConstantSplat(Reg, MRI)) || + (FPValReg = getFConstantVRegValWithLookThrough(Reg, MRI)); + }; +}; + +inline GFCstOrSplatGFCstMatch +m_GFCstOrSplat(Optional<FPValueAndVReg> &FPValReg) { + return GFCstOrSplatGFCstMatch(FPValReg); +} + /// Matcher for a specific constant value. 
struct SpecificConstantMatch { int64_t RequestedVal; diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index daaa09911548..a6e6e4942d22 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -357,15 +357,23 @@ Optional<int> getSplatIndex(MachineInstr &MI); Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI); +/// Returns a floating point scalar constant of a build vector splat if it +/// exists. When \p AllowUndef == true some elements can be undef but not all. +Optional<FPValueAndVReg> getFConstantSplat(Register VReg, + const MachineRegisterInfo &MRI, + bool AllowUndef = true); + /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef. bool isBuildVectorAllZeros(const MachineInstr &MI, - const MachineRegisterInfo &MRI); + const MachineRegisterInfo &MRI, + bool AllowUndef = false); /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are ~0 or undef. bool isBuildVectorAllOnes(const MachineInstr &MI, - const MachineRegisterInfo &MRI); + const MachineRegisterInfo &MRI, + bool AllowUndef = false); /// \returns a value when \p MI is a vector splat. The splat can be either a /// Register or a constant. 
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 177d4025bbb8..3c09df0b6970 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h" +#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" #include "llvm/CodeGen/GlobalISel/RegisterBankInfo.h" #include "llvm/CodeGen/MachineInstr.h" @@ -924,53 +925,81 @@ static bool isBuildVectorOp(unsigned Opcode) { Opcode == TargetOpcode::G_BUILD_VECTOR_TRUNC; } -// TODO: Handle mixed undef elements. -static bool isBuildVectorConstantSplat(const MachineInstr &MI, - const MachineRegisterInfo &MRI, - int64_t SplatValue) { - if (!isBuildVectorOp(MI.getOpcode())) - return false; +namespace { - const unsigned NumOps = MI.getNumOperands(); - for (unsigned I = 1; I != NumOps; ++I) { - Register Element = MI.getOperand(I).getReg(); - if (!mi_match(Element, MRI, m_SpecificICst(SplatValue))) - return false; +Optional<ValueAndVReg> getAnyConstantSplat(Register VReg, + const MachineRegisterInfo &MRI, + bool AllowUndef) { + MachineInstr *MI = getDefIgnoringCopies(VReg, MRI); + if (!MI) + return None; + + if (!isBuildVectorOp(MI->getOpcode())) + return None; + + Optional<ValueAndVReg> SplatValAndReg = None; + for (MachineOperand &Op : MI->uses()) { + Register Element = Op.getReg(); + auto ElementValAndReg = + getAnyConstantVRegValWithLookThrough(Element, MRI, true, true); + + // If AllowUndef, treat undef as value that will result in a constant splat. + if (!ElementValAndReg) { + if (AllowUndef && isa<GImplicitDef>(MRI.getVRegDef(Element))) + continue; + return None; + } + + // Record splat value + if (!SplatValAndReg) + SplatValAndReg = ElementValAndReg; + + // Different constant than the one already recorded, not a constant splat. 
+ if (SplatValAndReg->Value != ElementValAndReg->Value) + return None; } - return true; + return SplatValAndReg; } +bool isBuildVectorConstantSplat(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + int64_t SplatValue, bool AllowUndef) { + if (auto SplatValAndReg = + getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef)) + return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue)); + return false; +} + +} // end anonymous namespace + Optional<int64_t> llvm::getBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI) { - if (!isBuildVectorOp(MI.getOpcode())) - return None; + if (auto SplatValAndReg = + getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false)) + return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI); + return None; +} - const unsigned NumOps = MI.getNumOperands(); - Optional<int64_t> Scalar; - for (unsigned I = 1; I != NumOps; ++I) { - Register Element = MI.getOperand(I).getReg(); - int64_t ElementValue; - if (!mi_match(Element, MRI, m_ICst(ElementValue))) - return None; - if (!Scalar) - Scalar = ElementValue; - else if (*Scalar != ElementValue) - return None; - } - - return Scalar; +Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg, + const MachineRegisterInfo &MRI, + bool AllowUndef) { + if (auto SplatValAndReg = getAnyConstantSplat(VReg, MRI, AllowUndef)) + return getFConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI); + return None; } bool llvm::isBuildVectorAllZeros(const MachineInstr &MI, - const MachineRegisterInfo &MRI) { - return isBuildVectorConstantSplat(MI, MRI, 0); + const MachineRegisterInfo &MRI, + bool AllowUndef) { + return isBuildVectorConstantSplat(MI, MRI, 0, AllowUndef); } bool llvm::isBuildVectorAllOnes(const MachineInstr &MI, - const MachineRegisterInfo &MRI) { - return isBuildVectorConstantSplat(MI, MRI, -1); + const MachineRegisterInfo &MRI, + bool AllowUndef) { + return isBuildVectorConstantSplat(MI, MRI, -1, AllowUndef); } Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI, 
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp index 9ebb4b1cc54f..b5f4e2266b07 100644 --- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp @@ -574,6 +574,57 @@ TEST_F(AArch64GISelMITest, MatchFPOrIntConst) { EXPECT_EQ(FPOne, FValReg->VReg); } +TEST_F(AArch64GISelMITest, MatchConstantSplat) { + setUp(); + if (!TM) + return; + + LLT s64 = LLT::scalar(64); + LLT v4s64 = LLT::fixed_vector(4, 64); + + Register FPOne = B.buildFConstant(s64, 1.0).getReg(0); + Register FPZero = B.buildFConstant(s64, 0.0).getReg(0); + Register Undef = B.buildUndef(s64).getReg(0); + Optional<FPValueAndVReg> FValReg; + + // GFCstOrSplatGFCstMatch allows undef as part of splat. Undef often comes + // from padding to legalize into available operation and then ignore added + // elements e.g. v3s64 to v4s64. + + EXPECT_TRUE(mi_match(FPZero, *MRI, GFCstOrSplatGFCstMatch(FValReg))); + EXPECT_EQ(FPZero, FValReg->VReg); + + EXPECT_FALSE(mi_match(Undef, *MRI, GFCstOrSplatGFCstMatch(FValReg))); + + auto ZeroSplat = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPZero}); + EXPECT_TRUE( + mi_match(ZeroSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg))); + EXPECT_EQ(FPZero, FValReg->VReg); + + auto ZeroUndef = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Undef}); + EXPECT_TRUE( + mi_match(ZeroUndef.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg))); + EXPECT_EQ(FPZero, FValReg->VReg); + + // All undefs are not constant splat. 
+ auto UndefSplat = B.buildBuildVector(v4s64, {Undef, Undef, Undef, Undef}); + EXPECT_FALSE( + mi_match(UndefSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg))); + + auto ZeroOne = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPOne}); + EXPECT_FALSE( + mi_match(ZeroOne.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg))); + + auto NonConstantSplat = + B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]}); + EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, + GFCstOrSplatGFCstMatch(FValReg))); + + auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]}); + EXPECT_FALSE( + mi_match(Mixed.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg))); +} + TEST_F(AArch64GISelMITest, MatchNeg) { setUp(); if (!TM)