diff --git a/llvm/test/TableGen/GlobalISelEmitter.td b/llvm/test/TableGen/GlobalISelEmitter.td index 50d4af01061f..f0efcd0582dc 100644 --- a/llvm/test/TableGen/GlobalISelEmitter.td +++ b/llvm/test/TableGen/GlobalISelEmitter.td @@ -18,12 +18,57 @@ class I Pat> let Pattern = Pat; } +def complex : Operand, ComplexPattern { + let MIOperandInfo = (ops i32imm, i32imm); +} +def gi_complex : + GIComplexOperandMatcher, + GIComplexPatternEquiv; + //===- Test the function definition boilerplate. --------------------------===// // CHECK: bool MyTargetInstructionSelector::selectImpl(MachineInstr &I) const { // CHECK: MachineFunction &MF = *I.getParent()->getParent(); // CHECK: const MachineRegisterInfo &MRI = MF.getRegInfo(); +//===- Test a pattern with multiple ComplexPattern operands. --------------===// +// + +// CHECK-LABEL: if ([&]() { +// CHECK-NEXT: MachineInstr &MI0 = I; +// CHECK-NEXT: if (MI0.getNumOperands() < 4) +// CHECK-NEXT: return false; +// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_SELECT) && +// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1)))) && +// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(3).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(3), TempOp2, TempOp3))))) { +// CHECK-NEXT: // (select:i32 GPR32:i32:$src1, complex:i32:$src2, complex:i32:$src3) => (INSN2:i32 GPR32:i32:$src1, complex:i32:$src3, complex:i32:$src2) +// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN2)); +// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); +// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); +// CHECK-NEXT: MIB.add(TempOp2); +// CHECK-NEXT: MIB.add(TempOp3); +// CHECK-NEXT: MIB.add(TempOp0); +// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); +// CHECK-NEXT: I.eraseFromParent(); +// CHECK-NEXT: MachineInstr &NewI = *MIB; +// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); +// CHECK-NEXT: return true; +// CHECK-NEXT: } + +def : GINodeEquiv; +def INSN2 : I<(outs GPR32:$dst), (ins GPR32:$src1, complex:$src2, complex:$src3), []>; +def : Pat<(select GPR32:$src1, complex:$src2, complex:$src3), + (INSN2 GPR32:$src1, complex:$src3, complex:$src2)>; + //===- Test a simple pattern with regclass operands. ----------------------===// // CHECK-LABEL: if ([&]() { @@ -166,6 +211,38 @@ def MULADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2, GPR32:$src3), def MUL : I<(outs GPR32:$dst), (ins GPR32:$src2, GPR32:$src1), [(set GPR32:$dst, (mul GPR32:$src1, GPR32:$src2))]>; +//===- Test a pattern with ComplexPattern operands. -----------------------===// +// + +// CHECK-LABEL: if ([&]() { +// CHECK-NEXT: MachineInstr &MI0 = I; +// CHECK-NEXT: if (MI0.getNumOperands() < 3) +// CHECK-NEXT: return false; +// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_SUB) && +// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1))))) { +// CHECK-NEXT: // (sub:i32 GPR32:i32:$src1, complex:i32:$src2) => (INSN1:i32 GPR32:i32:$src1, complex:i32:$src2) +// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN1)); +// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); +// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); +// CHECK-NEXT: MIB.add(TempOp0); +// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); +// CHECK-NEXT: I.eraseFromParent(); +// CHECK-NEXT: MachineInstr &NewI = *MIB; +// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); +// CHECK-NEXT: return true; +// CHECK-NEXT: } + +def INSN1 : I<(outs GPR32:$dst), (ins GPR32:$src1, complex:$src2), []>; +def : Pat<(sub GPR32:$src1, complex:$src2), (INSN1 GPR32:$src1, complex:$src2)>; + //===- Test a simple pattern with constant immediate operands. ------------===// // // This must precede the 3-register variants because constant immediates have diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp index a2d1b1cbc919..7a8acc4ea152 100644 --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -332,26 +332,31 @@ public: /// Generates code to check that an operand is a particular target constant. class ComplexPatternOperandMatcher : public OperandPredicateMatcher { protected: + const OperandMatcher &Operand; const Record &TheDef; - /// The index of the first temporary operand to allocate to this - /// ComplexPattern. - unsigned BaseTemporaryID; unsigned getNumOperands() const { return TheDef.getValueAsDag("Operands")->getNumArgs(); } + unsigned getAllocatedTemporariesBaseID() const; + public: - ComplexPatternOperandMatcher(const Record &TheDef, unsigned BaseTemporaryID) - : OperandPredicateMatcher(OPM_ComplexPattern), TheDef(TheDef), - BaseTemporaryID(BaseTemporaryID) {} + ComplexPatternOperandMatcher(const OperandMatcher &Operand, + const Record &TheDef) + : OperandPredicateMatcher(OPM_ComplexPattern), Operand(Operand), + TheDef(TheDef) {} + + static bool classof(const OperandPredicateMatcher *P) { + return P->getKind() == OPM_ComplexPattern; + } void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, StringRef OperandExpr) const override { OS << TheDef.getValueAsString("MatcherFn") << "(" << OperandExpr; for (unsigned I = 0; I < getNumOperands(); ++I) { OS << ", "; - OperandPlaceholder::CreateTemporary(BaseTemporaryID + I) + OperandPlaceholder::CreateTemporary(getAllocatedTemporariesBaseID() + I) .emitCxxValueExpr(OS); } OS << ")"; @@ -425,10 +430,17 @@ protected: unsigned OpIdx; std::string SymbolicName; + /// The index of the first temporary variable allocated to this operand. The + /// number of allocated temporaries can be found with + /// countTemporaryOperands(). + unsigned AllocatedTemporariesBaseID; + public: OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx, - const std::string &SymbolicName) - : Insn(Insn), OpIdx(OpIdx), SymbolicName(SymbolicName) {} + const std::string &SymbolicName, + unsigned AllocatedTemporariesBaseID) + : Insn(Insn), OpIdx(OpIdx), SymbolicName(SymbolicName), + AllocatedTemporariesBaseID(AllocatedTemporariesBaseID) {} bool hasSymbolicName() const { return !SymbolicName.empty(); } const StringRef getSymbolicName() const { return SymbolicName; } @@ -509,8 +521,16 @@ public: return A + Predicate->countTemporaryOperands(); }); } + + unsigned getAllocatedTemporariesBaseID() const { + return AllocatedTemporariesBaseID; + } }; +unsigned ComplexPatternOperandMatcher::getAllocatedTemporariesBaseID() const { + return Operand.getAllocatedTemporariesBaseID(); +} + /// Generates code to check a predicate on an instruction. /// /// Typical predicates include: @@ -598,7 +618,7 @@ public: class InstructionMatcher : public PredicateListMatcher { protected: - typedef std::vector OperandVec; + typedef std::vector> OperandVec; /// The operands to match. All rendered operands must be present even if the /// condition is always true. @@ -606,18 +626,20 @@ protected: public: /// Add an operand to the matcher. - OperandMatcher &addOperand(unsigned OpIdx, const std::string &SymbolicName) { - Operands.emplace_back(*this, OpIdx, SymbolicName); - return Operands.back(); + OperandMatcher &addOperand(unsigned OpIdx, const std::string &SymbolicName, + unsigned AllocatedTemporariesBaseID) { + Operands.emplace_back(new OperandMatcher(*this, OpIdx, SymbolicName, + AllocatedTemporariesBaseID)); + return *Operands.back(); } OperandMatcher &getOperand(unsigned OpIdx) { auto I = std::find_if(Operands.begin(), Operands.end(), - [&OpIdx](const OperandMatcher &X) { - return X.getOperandIndex() == OpIdx; + [&OpIdx](const std::unique_ptr &X) { + return X->getOperandIndex() == OpIdx; }); if (I != Operands.end()) - return *I; + return **I; llvm_unreachable("Failed to lookup operand"); } @@ -625,7 +647,7 @@ public: getOptionalOperand(StringRef SymbolicName) const { assert(!SymbolicName.empty() && "Cannot lookup unnamed operand"); for (const auto &Operand : Operands) { - const auto &OM = Operand.getOptionalOperand(SymbolicName); + const auto &OM = Operand->getOptionalOperand(SymbolicName); if (OM.hasValue()) return OM.getValue(); } @@ -657,7 +679,7 @@ public: OS << "if (" << Expr << ".getNumOperands() < " << getNumOperands() << ")\n" << " return false;\n"; for (const auto &Operand : Operands) { - Operand.emitCxxCaptureStmts(OS, Rule, Operand.getOperandExpr(Expr)); + Operand->emitCxxCaptureStmts(OS, Rule, Operand->getOperandExpr(Expr)); } } @@ -668,7 +690,7 @@ public: emitCxxPredicateListExpr(OS, Rule, InsnVarName); for (const auto &Operand : Operands) { OS << " &&\n("; - Operand.emitCxxPredicateExpr(OS, Rule, InsnVarName); + Operand->emitCxxPredicateExpr(OS, Rule, InsnVarName); OS << ")"; } } @@ -691,9 +713,9 @@ public: } for (const auto &Operand : zip(Operands, B.Operands)) { - if (std::get<0>(Operand).isHigherPriorityThan(std::get<1>(Operand))) + if (std::get<0>(Operand)->isHigherPriorityThan(*std::get<1>(Operand))) return true; - if (std::get<1>(Operand).isHigherPriorityThan(std::get<0>(Operand))) + if (std::get<1>(Operand)->isHigherPriorityThan(*std::get<0>(Operand))) return false; } @@ -709,10 +731,11 @@ public: &Predicate) { return A + Predicate->countTemporaryOperands(); }) + - std::accumulate(Operands.begin(), Operands.end(), 0, - [](unsigned A, const OperandMatcher &Operand) { - return A + Operand.countTemporaryOperands(); - }); + std::accumulate( + Operands.begin(), Operands.end(), 0, + [](unsigned A, const std::unique_ptr &Operand) { + return A + Operand->countTemporaryOperands(); + }); } }; @@ -1174,8 +1197,7 @@ private: const InstructionMatcher &InsnMatcher) const; Error importExplicitUseRenderer(BuildMIAction &DstMIBuilder, TreePatternNode *DstChild, - const InstructionMatcher &InsnMatcher, - unsigned &TempOpIdx) const; + const InstructionMatcher &InsnMatcher) const; Error importImplicitDefRenderers(BuildMIAction &DstMIBuilder, const std::vector &ImplicitDefs) const; @@ -1237,6 +1259,7 @@ Expected GlobalISelEmitter::createAndImportSelDAGMatcher( InsnMatcher.addPredicate(&SrcGI); unsigned OpIdx = 0; + unsigned TempOpIdx = 0; for (const EEVT::TypeSet &Ty : Src->getExtTypes()) { auto OpTyOrNone = MVTToLLT(Ty.getConcrete()); @@ -1246,11 +1269,10 @@ Expected GlobalISelEmitter::createAndImportSelDAGMatcher( // Results don't have a name unless they are the root node. The caller will // set the name if appropriate. - OperandMatcher &OM = InsnMatcher.addOperand(OpIdx++, ""); + OperandMatcher &OM = InsnMatcher.addOperand(OpIdx++, "", TempOpIdx); OM.addPredicate(*OpTyOrNone); } - unsigned TempOpIdx = 0; // Match the used operands (i.e. the children of the operator). for (unsigned i = 0, e = Src->getNumChildren(); i != e; ++i) { if (auto Error = importChildMatcher(InsnMatcher, Src->getChild(i), OpIdx++, @@ -1265,7 +1287,8 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher, TreePatternNode *SrcChild, unsigned OpIdx, unsigned &TempOpIdx) const { - OperandMatcher &OM = InsnMatcher.addOperand(OpIdx, SrcChild->getName()); + OperandMatcher &OM = + InsnMatcher.addOperand(OpIdx, SrcChild->getName(), TempOpIdx); if (SrcChild->hasAnyPredicate()) return failedImport("Src pattern child has predicate"); @@ -1328,7 +1351,7 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher, "SelectionDAG ComplexPattern not mapped to GlobalISel"); const auto &Predicate = OM.addPredicate( - *ComplexPattern->second, TempOpIdx); + OM, *ComplexPattern->second); TempOpIdx += Predicate.countTemporaryOperands(); return Error::success(); } @@ -1342,7 +1365,7 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher, Error GlobalISelEmitter::importExplicitUseRenderer( BuildMIAction &DstMIBuilder, TreePatternNode *DstChild, - const InstructionMatcher &InsnMatcher, unsigned &TempOpIdx) const { + const InstructionMatcher &InsnMatcher) const { // The only non-leaf child we accept is 'bb': it's an operator because // BasicBlockSDNode isn't inline, but in MI it's just another operand. if (!DstChild->isLeaf()) { @@ -1389,13 +1412,10 @@ Error GlobalISelEmitter::importExplicitUseRenderer( "SelectionDAG ComplexPattern not mapped to GlobalISel"); SmallVector RenderedOperands; - for (unsigned I = 0; - I < - InsnMatcher.getOperand(DstChild->getName()).countTemporaryOperands(); - ++I) { - RenderedOperands.push_back(OperandPlaceholder::CreateTemporary(I)); - TempOpIdx++; - } + const OperandMatcher &OM = InsnMatcher.getOperand(DstChild->getName()); + for (unsigned I = 0; I < OM.countTemporaryOperands(); ++I) + RenderedOperands.push_back(OperandPlaceholder::CreateTemporary( + OM.getAllocatedTemporariesBaseID() + I)); DstMIBuilder.addRenderer( *ComplexPattern->second, RenderedOperands); return Error::success(); @@ -1425,10 +1445,9 @@ Expected GlobalISelEmitter::createAndImportInstructionRenderer( } // Render the explicit uses. - unsigned TempOpIdx = 0; for (unsigned i = 0, e = Dst->getNumChildren(); i != e; ++i) { if (auto Error = importExplicitUseRenderer(DstMIBuilder, Dst->getChild(i), - InsnMatcher, TempOpIdx)) + InsnMatcher)) return std::move(Error); }