From 4c0faef943985e517e8c7402b4a742cea8342168 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 11 Jan 2019 07:41:12 -0800 Subject: [PATCH] Avoid redundant predicate checking in type matching. Expand type matcher template generator to consider a set of predicates that are known to hold. This avoids inserting redundant checking for trivially true predicates (for example predicate that hold according to the op definition). This only targets predicates that trivially holds and does not attempt any logic equivalence proof. PiperOrigin-RevId: 228880468 --- mlir/include/mlir/TableGen/Predicate.h | 2 +- mlir/lib/TableGen/Operator.cpp | 3 ++- mlir/lib/TableGen/Predicate.cpp | 19 ++++++++++++++++++- mlir/tools/mlir-tblgen/RewriterGen.cpp | 4 ++-- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h index a89fbf15e483..a667d92788c3 100644 --- a/mlir/include/mlir/TableGen/Predicate.h +++ b/mlir/include/mlir/TableGen/Predicate.h @@ -60,7 +60,7 @@ public: // Returns the template string to construct the matcher corresponding to this // predicate CNF. The string uses '{0}' to represent the type. - std::string createTypeMatcherTemplate() const; + std::string createTypeMatcherTemplate(PredCNF predsKnownToHold) const; private: // The TableGen definition of this predicate CNF. nullptr means an empty diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 595cf8a59b75..61b6d1745b5d 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -163,5 +163,6 @@ bool tblgen::Operator::Operand::hasMatcher() const { } std::string tblgen::Operator::Operand::createTypeMatcherTemplate() const { - return tblgen::Type(defInit).getPredicate().createTypeMatcherTemplate(); + return tblgen::Type(defInit).getPredicate().createTypeMatcherTemplate( + PredCNF()); } diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp index 88d08f565c9d..b297fff80f5e 100644 --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Predicate.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" @@ -42,15 +43,29 @@ const llvm::ListInit *tblgen::PredCNF::getConditions() const { return def->getValueAsListInit("conditions"); } -std::string tblgen::PredCNF::createTypeMatcherTemplate() const { +std::string +tblgen::PredCNF::createTypeMatcherTemplate(PredCNF predsKnownToHold) const { const auto *conjunctiveList = getConditions(); if (!conjunctiveList) return "true"; + // Create a set of all the disjunctive conditions that hold. This is taking + // advantage of uniquieing of lists to discard based on the pointer + // below. This is not perfect but this will also be moved to FSM matching in + // future and gets rid of trivial redundant checking. + llvm::SmallSetVector existingConditions; + auto existingList = predsKnownToHold.getConditions(); + if (existingList) { + for (auto disjunctiveInit : *existingList) + existingConditions.insert(disjunctiveInit); + } + std::string outString; llvm::raw_string_ostream ss(outString); bool firstDisjunctive = true; for (auto disjunctiveInit : *conjunctiveList) { + if (existingConditions.count(disjunctiveInit) != 0) + continue; ss << (firstDisjunctive ? "(" : " && ("); firstDisjunctive = false; bool firstConjunctive = true; @@ -63,6 +78,8 @@ std::string tblgen::PredCNF::createTypeMatcherTemplate() const { } ss << ")"; } + if (firstDisjunctive) + return "true"; ss.flush(); return outString; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index b4e38954ec79..615b08fa1571 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -175,10 +175,10 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, "type argument required for operand"); auto pred = tblgen::Type(defInit).getPredicate(); - + auto opPred = tblgen::Type(operand->defInit).getPredicate(); os.indent(indent) << "if (!(" - << formatv(pred.createTypeMatcherTemplate().c_str(), + << formatv(pred.createTypeMatcherTemplate(opPred).c_str(), formatv("op{0}->getOperand({1})->getType()", depth, i)) << ")) return matchFailure();\n"; }