From 8bc71856681c235a3192813947308a19577c9236 Mon Sep 17 00:00:00 2001
From: Petar Avramovic <Petar.Avramovic@amd.com>
Date: Tue, 21 Sep 2021 11:54:12 +0200
Subject: [PATCH] GlobalISel/Utils: Refactor constant splat match functions

Add generic helper function that matches constant splat. It has an
option to match constant splat with undef (some elements can be undef
but not all). Add util function and matcher for G_FCONSTANT splat.

Differential Revision: https://reviews.llvm.org/D104410
---
 .../CodeGen/GlobalISel/GenericMachineInstrs.h |  8 ++
 .../llvm/CodeGen/GlobalISel/MIPatternMatch.h  | 15 +++
 llvm/include/llvm/CodeGen/GlobalISel/Utils.h  | 12 ++-
 llvm/lib/CodeGen/GlobalISel/Utils.cpp         | 93 ++++++++++++-------
 .../CodeGen/GlobalISel/PatternMatchTest.cpp   | 51 ++++++++++
 5 files changed, 145 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 2b0ef6c3af57..bb5f55789a0e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -206,6 +206,14 @@ public:
   }
 };
 
+/// Represents a G_IMPLICIT_DEF.
+class GImplicitDef : public GenericMachineInstr {
+public:
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF;
+  }
+};
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index d8cebee063a4..e813d030eec3 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -99,6 +99,21 @@ inline GFCstAndRegMatch m_GFCst(Optional<FPValueAndVReg> &FPValReg) {
   return GFCstAndRegMatch(FPValReg);
 }
 
+struct GFCstOrSplatGFCstMatch {
+  Optional<FPValueAndVReg> &FPValReg;
+  GFCstOrSplatGFCstMatch(Optional<FPValueAndVReg> &FPValReg)
+      : FPValReg(FPValReg) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    return (FPValReg = getFConstantSplat(Reg, MRI)) ||
+           (FPValReg = getFConstantVRegValWithLookThrough(Reg, MRI));
+  };
+};
+
+inline GFCstOrSplatGFCstMatch
+m_GFCstOrSplat(Optional<FPValueAndVReg> &FPValReg) {
+  return GFCstOrSplatGFCstMatch(FPValReg);
+}
+
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   int64_t RequestedVal;
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index daaa09911548..a6e6e4942d22 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -357,15 +357,23 @@ Optional<int64_t> getSplatIndex(MachineInstr &MI);
 Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
                                               const MachineRegisterInfo &MRI);
 
+/// Returns a floating point scalar constant of a build vector splat if it
+/// exists. When \p AllowUndef == true some elements can be undef but not all.
+Optional<FPValueAndVReg> getFConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef = true);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 bool isBuildVectorAllZeros(const MachineInstr &MI,
-                           const MachineRegisterInfo &MRI);
+                           const MachineRegisterInfo &MRI,
+                           bool AllowUndef = false);
 
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are ~0 or undef.
 bool isBuildVectorAllOnes(const MachineInstr &MI,
-                          const MachineRegisterInfo &MRI);
+                          const MachineRegisterInfo &MRI,
+                          bool AllowUndef = false);
 
 /// \returns a value when \p MI is a vector splat. The splat can be either a
 /// Register or a constant.
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 177d4025bbb8..3c09df0b6970 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
 #include "llvm/CodeGen/GlobalISel/RegisterBankInfo.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -924,53 +925,81 @@ static bool isBuildVectorOp(unsigned Opcode) {
          Opcode == TargetOpcode::G_BUILD_VECTOR_TRUNC;
 }
 
-// TODO: Handle mixed undef elements.
-static bool isBuildVectorConstantSplat(const MachineInstr &MI,
-                                       const MachineRegisterInfo &MRI,
-                                       int64_t SplatValue) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return false;
+namespace {
 
-  const unsigned NumOps = MI.getNumOperands();
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    if (!mi_match(Element, MRI, m_SpecificICst(SplatValue)))
-      return false;
+Optional<ValueAndVReg> getAnyConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef) {
+  MachineInstr *MI = getDefIgnoringCopies(VReg, MRI);
+  if (!MI)
+    return None;
+
+  if (!isBuildVectorOp(MI->getOpcode()))
+    return None;
+
+  Optional<ValueAndVReg> SplatValAndReg = None;
+  for (MachineOperand &Op : MI->uses()) {
+    Register Element = Op.getReg();
+    auto ElementValAndReg =
+        getAnyConstantVRegValWithLookThrough(Element, MRI, true, true);
+
+    // If AllowUndef, treat undef as value that will result in a constant splat.
+    if (!ElementValAndReg) {
+      if (AllowUndef && isa<GImplicitDef>(MRI.getVRegDef(Element)))
+        continue;
+      return None;
+    }
+
+    // Record splat value
+    if (!SplatValAndReg)
+      SplatValAndReg = ElementValAndReg;
+
+    // Different constant than the one already recorded, not a constant splat.
+    if (SplatValAndReg->Value != ElementValAndReg->Value)
+      return None;
   }
-  return true;
+  return SplatValAndReg;
 }
 
+bool isBuildVectorConstantSplat(const MachineInstr &MI,
+                                const MachineRegisterInfo &MRI,
+                                int64_t SplatValue, bool AllowUndef) {
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef))
+    return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
+  return false;
+}
+
+} // end anonymous namespace
+
 Optional<int64_t>
 llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
                                   const MachineRegisterInfo &MRI) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return None;
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
+    return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
+  return None;
+}
 
-  const unsigned NumOps = MI.getNumOperands();
-  Optional<int64_t> Scalar;
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    int64_t ElementValue;
-    if (!mi_match(Element, MRI, m_ICst(ElementValue)))
-      return None;
-    if (!Scalar)
-      Scalar = ElementValue;
-    else if (*Scalar != ElementValue)
-      return None;
-  }
-
-  return Scalar;
+Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
+                                                 const MachineRegisterInfo &MRI,
+                                                 bool AllowUndef) {
+  if (auto SplatValAndReg = getAnyConstantSplat(VReg, MRI, AllowUndef))
+    return getFConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI);
+  return None;
 }
 
 bool llvm::isBuildVectorAllZeros(const MachineInstr &MI,
-                                 const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, 0);
+                                 const MachineRegisterInfo &MRI,
+                                 bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, 0, AllowUndef);
 }
 
 bool llvm::isBuildVectorAllOnes(const MachineInstr &MI,
-                                const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, -1);
+                                const MachineRegisterInfo &MRI,
+                                bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, -1, AllowUndef);
 }
 
 Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index 9ebb4b1cc54f..b5f4e2266b07 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -574,6 +574,57 @@ TEST_F(AArch64GISelMITest, MatchFPOrIntConst) {
   EXPECT_EQ(FPOne, FValReg->VReg);
 }
 
+TEST_F(AArch64GISelMITest, MatchConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, 64);
+
+  Register FPOne = B.buildFConstant(s64, 1.0).getReg(0);
+  Register FPZero = B.buildFConstant(s64, 0.0).getReg(0);
+  Register Undef = B.buildUndef(s64).getReg(0);
+  Optional<FPValueAndVReg> FValReg;
+
+  // GFCstOrSplatGFCstMatch allows undef as part of splat. Undef often comes
+  // from padding to legalize into available operation and then ignore added
+  // elements e.g. v3s64 to v4s64.
+
+  EXPECT_TRUE(mi_match(FPZero, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  EXPECT_FALSE(mi_match(Undef, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroSplat = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPZero});
+  EXPECT_TRUE(
+      mi_match(ZeroSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  auto ZeroUndef = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Undef});
+  EXPECT_TRUE(
+      mi_match(ZeroUndef.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  // All undefs are not constant splat.
+  auto UndefSplat = B.buildBuildVector(v4s64, {Undef, Undef, Undef, Undef});
+  EXPECT_FALSE(
+      mi_match(UndefSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroOne = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPOne});
+  EXPECT_FALSE(
+      mi_match(ZeroOne.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]});
+  EXPECT_FALSE(
+      mi_match(Mixed.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+}
+
 TEST_F(AArch64GISelMITest, MatchNeg) {
   setUp();
   if (!TM)