[globalisel][tablegen] Fix patterns involving multiple ComplexPatterns.

Summary:
Temporaries are now allocated to operands instead of predicates and this
allocation is used to correctly pair up the rendered operands with the
matched operands.

Previously, ComplexPatterns were allocated temporaries independently in the
Src Pattern and Dst Pattern, leading to mismatches. Additionally, the Dst
Pattern failed to account for the allocated index and therefore always used
temporary 0, 1, ... when it should have used base+0, base+1, ...

Thanks to Aditya Nandakumar for noticing the bug.

Depends on D30539

Reviewers: ab, t.p.northover, qcolombet, rovka, aditya_nandakumar

Reviewed By: rovka

Subscribers: igorb, dberris, kristof.beyls, llvm-commits

Differential Revision: https://reviews.llvm.org/D31054

llvm-svn: 299538
This commit is contained in:
Daniel Sanders 2017-04-05 13:14:03 +00:00
parent 3bccec5da7
commit 4f3eb249cf
2 changed files with 137 additions and 41 deletions

View File

@ -18,12 +18,57 @@ class I<dag OOps, dag IOps, list<dag> Pat>
let Pattern = Pat; let Pattern = Pat;
} }
def complex : Operand<i32>, ComplexPattern<i32, 2, "SelectComplexPattern", []> {
let MIOperandInfo = (ops i32imm, i32imm);
}
def gi_complex :
GIComplexOperandMatcher<s32, (ops i32imm, i32imm), "selectComplexPattern">,
GIComplexPatternEquiv<complex>;
//===- Test the function definition boilerplate. --------------------------===// //===- Test the function definition boilerplate. --------------------------===//
// CHECK: bool MyTargetInstructionSelector::selectImpl(MachineInstr &I) const { // CHECK: bool MyTargetInstructionSelector::selectImpl(MachineInstr &I) const {
// CHECK: MachineFunction &MF = *I.getParent()->getParent(); // CHECK: MachineFunction &MF = *I.getParent()->getParent();
// CHECK: const MachineRegisterInfo &MRI = MF.getRegInfo(); // CHECK: const MachineRegisterInfo &MRI = MF.getRegInfo();
//===- Test a pattern with multiple ComplexPattern operands. --------------===//
//
// CHECK-LABEL: if ([&]() {
// CHECK-NEXT: MachineInstr &MI0 = I;
// CHECK-NEXT: if (MI0.getNumOperands() < 4)
// CHECK-NEXT: return false;
// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_SELECT) &&
// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1)))) &&
// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(3).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(3), TempOp2, TempOp3))))) {
// CHECK-NEXT: // (select:i32 GPR32:i32:$src1, complex:i32:$src2, complex:i32:$src3) => (INSN2:i32 GPR32:i32:$src1, complex:i32:$src3, complex:i32:$src2)
// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN2));
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/);
// CHECK-NEXT: MIB.add(TempOp2);
// CHECK-NEXT: MIB.add(TempOp3);
// CHECK-NEXT: MIB.add(TempOp0);
// CHECK-NEXT: MIB.add(TempOp1);
// CHECK-NEXT: for (const auto *FromMI : {&MI0, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
// CHECK-NEXT: return true;
// CHECK-NEXT: }
def : GINodeEquiv<G_SELECT, select>;
def INSN2 : I<(outs GPR32:$dst), (ins GPR32:$src1, complex:$src2, complex:$src3), []>;
def : Pat<(select GPR32:$src1, complex:$src2, complex:$src3),
(INSN2 GPR32:$src1, complex:$src3, complex:$src2)>;
//===- Test a simple pattern with regclass operands. ----------------------===// //===- Test a simple pattern with regclass operands. ----------------------===//
// CHECK-LABEL: if ([&]() { // CHECK-LABEL: if ([&]() {
@ -166,6 +211,38 @@ def MULADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2, GPR32:$src3),
def MUL : I<(outs GPR32:$dst), (ins GPR32:$src2, GPR32:$src1), def MUL : I<(outs GPR32:$dst), (ins GPR32:$src2, GPR32:$src1),
[(set GPR32:$dst, (mul GPR32:$src1, GPR32:$src2))]>; [(set GPR32:$dst, (mul GPR32:$src1, GPR32:$src2))]>;
//===- Test a pattern with ComplexPattern operands. -----------------------===//
//
// CHECK-LABEL: if ([&]() {
// CHECK-NEXT: MachineInstr &MI0 = I;
// CHECK-NEXT: if (MI0.getNumOperands() < 3)
// CHECK-NEXT: return false;
// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_SUB) &&
// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1))))) {
// CHECK-NEXT: // (sub:i32 GPR32:i32:$src1, complex:i32:$src2) => (INSN1:i32 GPR32:i32:$src1, complex:i32:$src2)
// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN1));
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/);
// CHECK-NEXT: MIB.add(TempOp0);
// CHECK-NEXT: MIB.add(TempOp1);
// CHECK-NEXT: for (const auto *FromMI : {&MI0, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
// CHECK-NEXT: return true;
// CHECK-NEXT: }
def INSN1 : I<(outs GPR32:$dst), (ins GPR32:$src1, complex:$src2), []>;
def : Pat<(sub GPR32:$src1, complex:$src2), (INSN1 GPR32:$src1, complex:$src2)>;
//===- Test a simple pattern with constant immediate operands. ------------===// //===- Test a simple pattern with constant immediate operands. ------------===//
// //
// This must precede the 3-register variants because constant immediates have // This must precede the 3-register variants because constant immediates have

View File

@ -332,26 +332,31 @@ public:
/// Generates code to check that an operand is a particular target constant. /// Generates code to check that an operand is a particular target constant.
class ComplexPatternOperandMatcher : public OperandPredicateMatcher { class ComplexPatternOperandMatcher : public OperandPredicateMatcher {
protected: protected:
const OperandMatcher &Operand;
const Record &TheDef; const Record &TheDef;
/// The index of the first temporary operand to allocate to this
/// ComplexPattern.
unsigned BaseTemporaryID;
unsigned getNumOperands() const { unsigned getNumOperands() const {
return TheDef.getValueAsDag("Operands")->getNumArgs(); return TheDef.getValueAsDag("Operands")->getNumArgs();
} }
unsigned getAllocatedTemporariesBaseID() const;
public: public:
ComplexPatternOperandMatcher(const Record &TheDef, unsigned BaseTemporaryID) ComplexPatternOperandMatcher(const OperandMatcher &Operand,
: OperandPredicateMatcher(OPM_ComplexPattern), TheDef(TheDef), const Record &TheDef)
BaseTemporaryID(BaseTemporaryID) {} : OperandPredicateMatcher(OPM_ComplexPattern), Operand(Operand),
TheDef(TheDef) {}
static bool classof(const OperandPredicateMatcher *P) {
return P->getKind() == OPM_ComplexPattern;
}
void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule,
StringRef OperandExpr) const override { StringRef OperandExpr) const override {
OS << TheDef.getValueAsString("MatcherFn") << "(" << OperandExpr; OS << TheDef.getValueAsString("MatcherFn") << "(" << OperandExpr;
for (unsigned I = 0; I < getNumOperands(); ++I) { for (unsigned I = 0; I < getNumOperands(); ++I) {
OS << ", "; OS << ", ";
OperandPlaceholder::CreateTemporary(BaseTemporaryID + I) OperandPlaceholder::CreateTemporary(getAllocatedTemporariesBaseID() + I)
.emitCxxValueExpr(OS); .emitCxxValueExpr(OS);
} }
OS << ")"; OS << ")";
@ -425,10 +430,17 @@ protected:
unsigned OpIdx; unsigned OpIdx;
std::string SymbolicName; std::string SymbolicName;
/// The index of the first temporary variable allocated to this operand. The
/// number of allocated temporaries can be found with
/// countTemporaryOperands().
unsigned AllocatedTemporariesBaseID;
public: public:
OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx, OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx,
const std::string &SymbolicName) const std::string &SymbolicName,
: Insn(Insn), OpIdx(OpIdx), SymbolicName(SymbolicName) {} unsigned AllocatedTemporariesBaseID)
: Insn(Insn), OpIdx(OpIdx), SymbolicName(SymbolicName),
AllocatedTemporariesBaseID(AllocatedTemporariesBaseID) {}
bool hasSymbolicName() const { return !SymbolicName.empty(); } bool hasSymbolicName() const { return !SymbolicName.empty(); }
const StringRef getSymbolicName() const { return SymbolicName; } const StringRef getSymbolicName() const { return SymbolicName; }
@ -509,8 +521,16 @@ public:
return A + Predicate->countTemporaryOperands(); return A + Predicate->countTemporaryOperands();
}); });
} }
unsigned getAllocatedTemporariesBaseID() const {
return AllocatedTemporariesBaseID;
}
}; };
unsigned ComplexPatternOperandMatcher::getAllocatedTemporariesBaseID() const {
return Operand.getAllocatedTemporariesBaseID();
}
/// Generates code to check a predicate on an instruction. /// Generates code to check a predicate on an instruction.
/// ///
/// Typical predicates include: /// Typical predicates include:
@ -598,7 +618,7 @@ public:
class InstructionMatcher class InstructionMatcher
: public PredicateListMatcher<InstructionPredicateMatcher> { : public PredicateListMatcher<InstructionPredicateMatcher> {
protected: protected:
typedef std::vector<OperandMatcher> OperandVec; typedef std::vector<std::unique_ptr<OperandMatcher>> OperandVec;
/// The operands to match. All rendered operands must be present even if the /// The operands to match. All rendered operands must be present even if the
/// condition is always true. /// condition is always true.
@ -606,18 +626,20 @@ protected:
public: public:
/// Add an operand to the matcher. /// Add an operand to the matcher.
OperandMatcher &addOperand(unsigned OpIdx, const std::string &SymbolicName) { OperandMatcher &addOperand(unsigned OpIdx, const std::string &SymbolicName,
Operands.emplace_back(*this, OpIdx, SymbolicName); unsigned AllocatedTemporariesBaseID) {
return Operands.back(); Operands.emplace_back(new OperandMatcher(*this, OpIdx, SymbolicName,
AllocatedTemporariesBaseID));
return *Operands.back();
} }
OperandMatcher &getOperand(unsigned OpIdx) { OperandMatcher &getOperand(unsigned OpIdx) {
auto I = std::find_if(Operands.begin(), Operands.end(), auto I = std::find_if(Operands.begin(), Operands.end(),
[&OpIdx](const OperandMatcher &X) { [&OpIdx](const std::unique_ptr<OperandMatcher> &X) {
return X.getOperandIndex() == OpIdx; return X->getOperandIndex() == OpIdx;
}); });
if (I != Operands.end()) if (I != Operands.end())
return *I; return **I;
llvm_unreachable("Failed to lookup operand"); llvm_unreachable("Failed to lookup operand");
} }
@ -625,7 +647,7 @@ public:
getOptionalOperand(StringRef SymbolicName) const { getOptionalOperand(StringRef SymbolicName) const {
assert(!SymbolicName.empty() && "Cannot lookup unnamed operand"); assert(!SymbolicName.empty() && "Cannot lookup unnamed operand");
for (const auto &Operand : Operands) { for (const auto &Operand : Operands) {
const auto &OM = Operand.getOptionalOperand(SymbolicName); const auto &OM = Operand->getOptionalOperand(SymbolicName);
if (OM.hasValue()) if (OM.hasValue())
return OM.getValue(); return OM.getValue();
} }
@ -657,7 +679,7 @@ public:
OS << "if (" << Expr << ".getNumOperands() < " << getNumOperands() << ")\n" OS << "if (" << Expr << ".getNumOperands() < " << getNumOperands() << ")\n"
<< " return false;\n"; << " return false;\n";
for (const auto &Operand : Operands) { for (const auto &Operand : Operands) {
Operand.emitCxxCaptureStmts(OS, Rule, Operand.getOperandExpr(Expr)); Operand->emitCxxCaptureStmts(OS, Rule, Operand->getOperandExpr(Expr));
} }
} }
@ -668,7 +690,7 @@ public:
emitCxxPredicateListExpr(OS, Rule, InsnVarName); emitCxxPredicateListExpr(OS, Rule, InsnVarName);
for (const auto &Operand : Operands) { for (const auto &Operand : Operands) {
OS << " &&\n("; OS << " &&\n(";
Operand.emitCxxPredicateExpr(OS, Rule, InsnVarName); Operand->emitCxxPredicateExpr(OS, Rule, InsnVarName);
OS << ")"; OS << ")";
} }
} }
@ -691,9 +713,9 @@ public:
} }
for (const auto &Operand : zip(Operands, B.Operands)) { for (const auto &Operand : zip(Operands, B.Operands)) {
if (std::get<0>(Operand).isHigherPriorityThan(std::get<1>(Operand))) if (std::get<0>(Operand)->isHigherPriorityThan(*std::get<1>(Operand)))
return true; return true;
if (std::get<1>(Operand).isHigherPriorityThan(std::get<0>(Operand))) if (std::get<1>(Operand)->isHigherPriorityThan(*std::get<0>(Operand)))
return false; return false;
} }
@ -709,9 +731,10 @@ public:
&Predicate) { &Predicate) {
return A + Predicate->countTemporaryOperands(); return A + Predicate->countTemporaryOperands();
}) + }) +
std::accumulate(Operands.begin(), Operands.end(), 0, std::accumulate(
[](unsigned A, const OperandMatcher &Operand) { Operands.begin(), Operands.end(), 0,
return A + Operand.countTemporaryOperands(); [](unsigned A, const std::unique_ptr<OperandMatcher> &Operand) {
return A + Operand->countTemporaryOperands();
}); });
} }
}; };
@ -1174,8 +1197,7 @@ private:
const InstructionMatcher &InsnMatcher) const; const InstructionMatcher &InsnMatcher) const;
Error importExplicitUseRenderer(BuildMIAction &DstMIBuilder, Error importExplicitUseRenderer(BuildMIAction &DstMIBuilder,
TreePatternNode *DstChild, TreePatternNode *DstChild,
const InstructionMatcher &InsnMatcher, const InstructionMatcher &InsnMatcher) const;
unsigned &TempOpIdx) const;
Error Error
importImplicitDefRenderers(BuildMIAction &DstMIBuilder, importImplicitDefRenderers(BuildMIAction &DstMIBuilder,
const std::vector<Record *> &ImplicitDefs) const; const std::vector<Record *> &ImplicitDefs) const;
@ -1237,6 +1259,7 @@ Expected<InstructionMatcher &> GlobalISelEmitter::createAndImportSelDAGMatcher(
InsnMatcher.addPredicate<InstructionOpcodeMatcher>(&SrcGI); InsnMatcher.addPredicate<InstructionOpcodeMatcher>(&SrcGI);
unsigned OpIdx = 0; unsigned OpIdx = 0;
unsigned TempOpIdx = 0;
for (const EEVT::TypeSet &Ty : Src->getExtTypes()) { for (const EEVT::TypeSet &Ty : Src->getExtTypes()) {
auto OpTyOrNone = MVTToLLT(Ty.getConcrete()); auto OpTyOrNone = MVTToLLT(Ty.getConcrete());
@ -1246,11 +1269,10 @@ Expected<InstructionMatcher &> GlobalISelEmitter::createAndImportSelDAGMatcher(
// Results don't have a name unless they are the root node. The caller will // Results don't have a name unless they are the root node. The caller will
// set the name if appropriate. // set the name if appropriate.
OperandMatcher &OM = InsnMatcher.addOperand(OpIdx++, ""); OperandMatcher &OM = InsnMatcher.addOperand(OpIdx++, "", TempOpIdx);
OM.addPredicate<LLTOperandMatcher>(*OpTyOrNone); OM.addPredicate<LLTOperandMatcher>(*OpTyOrNone);
} }
unsigned TempOpIdx = 0;
// Match the used operands (i.e. the children of the operator). // Match the used operands (i.e. the children of the operator).
for (unsigned i = 0, e = Src->getNumChildren(); i != e; ++i) { for (unsigned i = 0, e = Src->getNumChildren(); i != e; ++i) {
if (auto Error = importChildMatcher(InsnMatcher, Src->getChild(i), OpIdx++, if (auto Error = importChildMatcher(InsnMatcher, Src->getChild(i), OpIdx++,
@ -1265,7 +1287,8 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher,
TreePatternNode *SrcChild, TreePatternNode *SrcChild,
unsigned OpIdx, unsigned OpIdx,
unsigned &TempOpIdx) const { unsigned &TempOpIdx) const {
OperandMatcher &OM = InsnMatcher.addOperand(OpIdx, SrcChild->getName()); OperandMatcher &OM =
InsnMatcher.addOperand(OpIdx, SrcChild->getName(), TempOpIdx);
if (SrcChild->hasAnyPredicate()) if (SrcChild->hasAnyPredicate())
return failedImport("Src pattern child has predicate"); return failedImport("Src pattern child has predicate");
@ -1328,7 +1351,7 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher,
"SelectionDAG ComplexPattern not mapped to GlobalISel"); "SelectionDAG ComplexPattern not mapped to GlobalISel");
const auto &Predicate = OM.addPredicate<ComplexPatternOperandMatcher>( const auto &Predicate = OM.addPredicate<ComplexPatternOperandMatcher>(
*ComplexPattern->second, TempOpIdx); OM, *ComplexPattern->second);
TempOpIdx += Predicate.countTemporaryOperands(); TempOpIdx += Predicate.countTemporaryOperands();
return Error::success(); return Error::success();
} }
@ -1342,7 +1365,7 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher,
Error GlobalISelEmitter::importExplicitUseRenderer( Error GlobalISelEmitter::importExplicitUseRenderer(
BuildMIAction &DstMIBuilder, TreePatternNode *DstChild, BuildMIAction &DstMIBuilder, TreePatternNode *DstChild,
const InstructionMatcher &InsnMatcher, unsigned &TempOpIdx) const { const InstructionMatcher &InsnMatcher) const {
// The only non-leaf child we accept is 'bb': it's an operator because // The only non-leaf child we accept is 'bb': it's an operator because
// BasicBlockSDNode isn't inline, but in MI it's just another operand. // BasicBlockSDNode isn't inline, but in MI it's just another operand.
if (!DstChild->isLeaf()) { if (!DstChild->isLeaf()) {
@ -1389,13 +1412,10 @@ Error GlobalISelEmitter::importExplicitUseRenderer(
"SelectionDAG ComplexPattern not mapped to GlobalISel"); "SelectionDAG ComplexPattern not mapped to GlobalISel");
SmallVector<OperandPlaceholder, 2> RenderedOperands; SmallVector<OperandPlaceholder, 2> RenderedOperands;
for (unsigned I = 0; const OperandMatcher &OM = InsnMatcher.getOperand(DstChild->getName());
I < for (unsigned I = 0; I < OM.countTemporaryOperands(); ++I)
InsnMatcher.getOperand(DstChild->getName()).countTemporaryOperands(); RenderedOperands.push_back(OperandPlaceholder::CreateTemporary(
++I) { OM.getAllocatedTemporariesBaseID() + I));
RenderedOperands.push_back(OperandPlaceholder::CreateTemporary(I));
TempOpIdx++;
}
DstMIBuilder.addRenderer<RenderComplexPatternOperand>( DstMIBuilder.addRenderer<RenderComplexPatternOperand>(
*ComplexPattern->second, RenderedOperands); *ComplexPattern->second, RenderedOperands);
return Error::success(); return Error::success();
@ -1425,10 +1445,9 @@ Expected<BuildMIAction &> GlobalISelEmitter::createAndImportInstructionRenderer(
} }
// Render the explicit uses. // Render the explicit uses.
unsigned TempOpIdx = 0;
for (unsigned i = 0, e = Dst->getNumChildren(); i != e; ++i) { for (unsigned i = 0, e = Dst->getNumChildren(); i != e; ++i) {
if (auto Error = importExplicitUseRenderer(DstMIBuilder, Dst->getChild(i), if (auto Error = importExplicitUseRenderer(DstMIBuilder, Dst->getChild(i),
InsnMatcher, TempOpIdx)) InsnMatcher))
return std::move(Error); return std::move(Error);
} }