TableGen: implement predicate tree and basic simplification

A recent change in TableGen definitions allowed arbitrary AND/OR predicate
compositions at the cost of removing known-true predicate simplification.
Introduce a more advanced simplification mechanism instead.

In particular, instead of folding predicate C++ expressions directly in
TableGen, keep them as is and build a predicate tree in TableGen C++ library.
The predicate expression-substitution mechanism, necessary to implement complex
predicates for nested classes such as `ContainerType`, is replaced by a
dedicated predicate.  This predicate appears in the predicate tree and can be
used for tree matching and separation.  More specifically, subtrees defined
below such predicate may be subject to different transformations than those
that appear above.  For example, a subtree known to be true above the
substitution predicate is not necessarily true below it.

Use the predicate tree structure to eliminate known-true and known-false
predicates before code emission, as well as to collapse AND and OR predicates
if their value can be deduced based on the value of one child.

PiperOrigin-RevId: 229605997
This commit is contained in:
Alex Zinenko 2019-01-16 12:36:10 -08:00 committed by jpienaar
parent 4b2b5f5267
commit 05b02bb98e
6 changed files with 410 additions and 38 deletions

View File

@ -27,33 +27,49 @@
// Predicates.
//===----------------------------------------------------------------------===//
// A logical predicate.
class Pred;
// Logical predicate wrapping a C expression.
class CPred<code pred> {
class CPred<code pred> : Pred {
code predCall = "(" # pred # ")";
}
// Kinds of combined logical predicates. These must closesly match the
// predicates implemented by the C++ backend (tblgen::PredCombinerKind).
class PredCombinerKind;
def PredCombinerAnd : PredCombinerKind;
def PredCombinerOr : PredCombinerKind;
def PredCombinerNot : PredCombinerKind;
def PredCombinerSubstLeaves : PredCombinerKind;
// A predicate that combines other predicates as defined by PredCombinerKind.
// Instantiated below.
class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
PredCombinerKind kind = k;
list<Pred> children = c;
}
// A predicate that holds if all of its children hold. Always holds for zero
// children.
class AllOf<list<CPred> children> : CPred<
!if(
!empty(children),
"true",
!foldl(!head(children).predCall, !tail(children), acc, elem,
!cast<code>(acc # " && " # elem.predCall))
)>;
class AllOf<list<Pred> children> : CombinedPred<PredCombinerAnd, children>;
// A predicate that holds if any of its children hold. Never holds for zero
// children.
class AnyOf<list<CPred> children> : CPred<
!if(
!empty(children),
"false",
!foldl(!head(children).predCall, !tail(children), acc, elem,
!cast<code>(acc # " || " # elem.predCall))
)>;
class AnyOf<list<Pred> children> : CombinedPred<PredCombinerOr, children>;
// A predicate that hold if its child does not.
class NotCPred<CPred child> : CPred<"!" # child>;
// A predicate that holds if its child does not.
class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>;
// A predicate that substitutes "pat" with "repl" in predicate calls of the
// leaves of the predicate tree (i.e., not CombinedPredicates). This is plain
// string substitution without regular expressions or captures, new predicates
// with more complex logical can be introduced should the need arise.
class SubstLeaves<string pat, string repl, Pred child>
: CombinedPred<PredCombinerSubstLeaves, [child]> {
string pattern = pat;
string replacement = repl;
}
//===----------------------------------------------------------------------===//
// Type predicates. ({0} is replaced by an instance of mlir::Type)
@ -75,17 +91,17 @@ def IsStaticShapeTensorTypePred :
// A constraint on types. This can be used to check the validity of
// instruction arguments.
class TypeConstraint<CPred condition, string descr = ""> {
class TypeConstraint<Pred condition, string descr = ""> {
// The predicates that this type satisfies.
// Format: {0} will be expanded to the type.
CPred predicate = condition;
Pred predicate = condition;
// User-readable description used, e.g., for error reporting. If empty, a
// generic message will be used instead.
string description = descr;
}
// A type, carries type constraints, but accepts any type by default.
class Type<CPred condition = CPred<"true">, string descr = "">
class Type<Pred condition = CPred<"true">, string descr = "">
: TypeConstraint<condition, descr>;
// A type that can be constructed using MLIR::Builder.
@ -134,14 +150,13 @@ class F<int width>
def F32 : F<32>;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, CPred containerPred, code elementTypeCall,
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<AllOf<[containerPred,
CPred<!subst("{0}",
!cast<string>(elementTypeCall),
!cast<string>(etype.predicate.predCall))>]>,
SubstLeaves<"{0}", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # "<" # etype.description # ">" > {
// The type of elements in the container.
Type elementType = etype;

View File

@ -24,36 +24,93 @@
#include "mlir/Support/LLVM.h"
#include <vector>
namespace llvm {
class Init;
class ListInit;
class Record;
class SMLoc;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// A logical predicate.
// A logical predicate. This class must closely follow the definition of
// TableGen class 'Pred'.
class Pred {
public:
// Construct a Predicate from a record.
explicit Pred(const llvm::Record *def);
explicit Pred(const llvm::Record *record);
// Construct a Predicate from an initializer.
explicit Pred(const llvm::Init *init);
// Get the predicate condition. The predicate must not be null.
StringRef getCondition() const;
// Check if the predicate is defined. Callers may use this to interpret the
// missing predicate as either true (e.g. in filters) or false (e.g. in
// precondition verification).
bool isNull() const { return def == nullptr; }
private:
// Get the predicate condition. This may dispatch to getConditionImpl() of
// the underlying predicate type.
std::string getCondition() const;
// Whether the predicate is a combination of other predicates, i.e. an
// record of type CombinedPred.
bool isCombined() const;
// Records are pointer-comparable.
bool operator==(const Pred &other) const { return def == other.def; }
// Get the location of the predicate.
ArrayRef<llvm::SMLoc> getLoc() const;
protected:
// The TableGen definition of this predicate.
const llvm::Record *def;
};
// A logical predicate wrapping a C expression. This class must closely follow
// the definition of TableGen class 'CPred'.
class CPred : public Pred {
public:
// Construct a CPred from a record.
explicit CPred(const llvm::Record *record);
// Construct a CPred an initializer.
explicit CPred(const llvm::Init *init);
// Get the predicate condition.
std::string getConditionImpl() const;
};
// A logical predicate that is a combination of other predicates. This class
// must closely follow the definition of TableGen class 'CombinedPred'.
class CombinedPred : public Pred {
public:
// Construct a CombinedPred from a record.
explicit CombinedPred(const llvm::Record *record);
// Construct a CombinedPred from an initializer.
explicit CombinedPred(const llvm::Init *init);
// Get the predicate condition.
std::string getConditionImpl() const;
// Get the definition of the combiner used in this predicate.
const llvm::Record *getCombinerDef() const;
// Get the predicates that are combined by this predicate.
const std::vector<llvm::Record *> getChildren() const;
};
// A combined predicate that requires all child predicates of 'CPred' type to
// have their expression rewritten with a simple string substitution rule.
class SubstLeavesPred : public CombinedPred {
public:
// Get the replacement pattern.
StringRef getPattern() const;
// Get the string used to replace the pattern.
StringRef getReplacement() const;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -48,7 +48,7 @@ public:
// Returns the condition template that can be used to check if a type
// satisfies this type constraint. The template may contain "{0}" that must
// be substituted with an expression returning an mlir::Type.
StringRef getConditionTemplate() const;
std::string getConditionTemplate() const;
// Returns the user-readable description of the constraint. If the
// description is not provided, returns an empty string.

View File

@ -21,6 +21,7 @@
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
@ -29,7 +30,7 @@
using namespace mlir;
// Construct a Predicate from a record.
tblgen::Pred::Pred(const llvm::Record *def) : def(def) {
tblgen::Pred::Pred(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("Pred") &&
"must be a subclass of TableGen 'Pred' class");
}
@ -40,8 +41,306 @@ tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
def = defInit->getDef();
}
// Get condition of the Predicate.
StringRef tblgen::Pred::getCondition() const {
std::string tblgen::Pred::getCondition() const {
// Static dispatch to subclasses.
if (def->isSubClassOf("CombinedPred"))
return static_cast<const CombinedPred *>(this)->getConditionImpl();
if (def->isSubClassOf("CPred"))
return static_cast<const CPred *>(this)->getConditionImpl();
llvm_unreachable("Pred::getCondition must be overridden in subclasses");
}
bool tblgen::Pred::isCombined() const {
return def && def->isSubClassOf("CombinedPred");
}
ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); }
tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CPred") &&
"must be a subclass of Tablegen 'CPred' class");
}
tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CPred")) &&
"must be a subclass of Tablegen 'CPred' class");
}
// Get condition of the C Predicate.
std::string tblgen::CPred::getConditionImpl() const {
assert(!isNull() && "null predicate does not have a condition");
return def->getValueAsString("predCall");
}
tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CombinedPred") &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CombinedPred")) &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
const llvm::Record *tblgen::CombinedPred::getCombinerDef() const {
assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
return def->getValueAsDef("kind");
}
const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const {
assert(def->getValue("children") &&
"CombinedPred must have a value 'children'");
return def->getValueAsListOfDefs("children");
}
namespace {
// Kinds of nodes in a logical predicate tree.
enum class PredCombinerKind {
Leaf,
And,
Or,
Not,
SubstLeaves,
// Special kinds that are used in simplification.
False,
True
};
// A node in a logical predicate tree.
struct PredNode {
PredCombinerKind kind;
const tblgen::Pred *predicate;
SmallVector<PredNode *, 4> children;
std::string expr;
};
} // end anonymous namespace
// Get a predicate tree node kind based on the kind used in the predicate
// TableGen record.
static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) {
if (!pred.isCombined())
return PredCombinerKind::Leaf;
const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred);
return llvm::StringSwitch<PredCombinerKind>(
combinedPred.getCombinerDef()->getName())
.Case("PredCombinerAnd", PredCombinerKind::And)
.Case("PredCombinerOr", PredCombinerKind::Or)
.Case("PredCombinerNot", PredCombinerKind::Not)
.Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves);
}
namespace {
// Substitution<pattern, replacement>.
using Subst = std::pair<StringRef, StringRef>;
} // end anonymous namespace
// Build the predicate tree starting from the top-level predicate, which may
// have children, and perform leaf substitutions inplace. Note that after
// substitution, nodes are still pointing to the original TableGen record.
// All nodes are created within "allocator".
static PredNode *buildPredicateTree(const tblgen::Pred &root,
llvm::BumpPtrAllocator &allocator,
ArrayRef<Subst> substitutions) {
auto *rootNode = allocator.Allocate<PredNode>();
new (rootNode) PredNode;
rootNode->kind = getPredCombinerKind(root);
rootNode->predicate = &root;
if (!root.isCombined()) {
rootNode->expr = root.getCondition();
// Apply all parent substitutions from innermost to outermost.
for (const auto &subst : llvm::reverse(substitutions)) {
size_t start = 0;
while (auto pos =
rootNode->expr.find(subst.first, start) != std::string::npos) {
rootNode->expr.replace(pos, subst.first.size(), subst.second);
start = pos + subst.second.size();
}
}
return rootNode;
}
// If the current combined predicate is a leaf substitution, append it to the
// list before contiuing.
auto allSubstitutions = llvm::to_vector<4>(substitutions);
if (rootNode->kind == PredCombinerKind::SubstLeaves) {
const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root);
allSubstitutions.push_back(
{substPred.getPattern(), substPred.getReplacement()});
}
// Build child subtrees.
auto combined = static_cast<const tblgen::CombinedPred &>(root);
for (const auto *record : combined.getChildren()) {
auto childTree =
buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions);
rootNode->children.push_back(childTree);
}
return rootNode;
}
// Simplify a predicate tree rooted at "node" using the predicates that are
// known to be true(false). For AND(OR) combined predicates, if any of the
// children is known to be false(true), the result is also false(true).
// Furthermore, for AND(OR) combined predicates, children that are known to be
// true(false) don't have to be checked dynamically.
static PredNode *propagateGroundTruth(
PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds,
const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) {
// If the current predicate is known to be true or false, change the kind of
// the node and return immediately.
if (knownTruePreds.count(node->predicate) != 0) {
node->kind = PredCombinerKind::True;
node->children.clear();
return node;
}
if (knownFalsePreds.count(node->predicate) != 0) {
node->kind = PredCombinerKind::False;
node->children.clear();
return node;
}
// If the current node is a substitution, stop recursion now.
// The expressions in the leaves below this node were rewritten, but the nodes
// still point to the original predicate records. While the original
// predicate may be known to be true or false, it is not necessarily the case
// after rewriting.
// TODO(zinenko,jpienaar): we can support ground truth for rewritten
// predicates by either (a) having our own unique'ing of the predicates
// instead of relying on TableGen record pointers or (b) taking ground truth
// values optinally prefixed with a list of substitutions to apply, e.g.
// "predX is true by itself as well as predSubY leaf substitution had been
// applied to it".
if (node->kind == PredCombinerKind::SubstLeaves) {
return node;
}
// Otherwise, look at child nodes.
// Move child nodes into some local variable so that they can be optimized
// separately and re-added if necessary.
llvm::SmallVector<PredNode *, 4> children;
std::swap(node->children, children);
for (auto &child : children) {
// First, simplify the child. This maintains the predicate as it was.
auto simplifiedChild =
propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
// Just add the child if we don't know how to simplify the current node.
if (node->kind != PredCombinerKind::And &&
node->kind != PredCombinerKind::Or) {
node->children.push_back(simplifiedChild);
continue;
}
// Second, based on the type define which known values of child predicates
// immediately collapse this predicate to a known value, and which others
// may be safely ignored.
// OR(..., True, ...) = True
// OR(..., False, ...) = OR(..., ...)
// AND(..., False, ...) = False
// AND(..., True, ...) = AND(..., ...)
auto collapseKind = node->kind == PredCombinerKind::And
? PredCombinerKind::False
: PredCombinerKind::True;
auto eraseKind = node->kind == PredCombinerKind::And
? PredCombinerKind::True
: PredCombinerKind::False;
const auto &collapseList =
node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
const auto &eraseList =
node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
if (simplifiedChild->kind == collapseKind ||
collapseList.count(simplifiedChild->predicate) != 0) {
node->kind = collapseKind;
node->children.clear();
return node;
} else if (simplifiedChild->kind == eraseKind ||
eraseList.count(simplifiedChild->predicate) != 0) {
continue;
}
node->children.push_back(simplifiedChild);
}
return node;
}
// Combine a list of predicate expressions using a binary combiner. If a list
// is empty, return "init".
static std::string combineBinary(ArrayRef<std::string> children,
std::string combiner, std::string init) {
if (children.empty())
return init;
auto size = children.size();
if (size == 1)
return children.front();
std::string str;
llvm::raw_string_ostream os(str);
os << '(' << children.front() << ')';
for (unsigned i = 1; i < size; ++i) {
os << ' ' << combiner << " (" << children[i] << ')';
}
return os.str();
}
// Prepend negation to the only condition in the predicate expression list.
static std::string combineNot(ArrayRef<std::string> children) {
assert(children.size() == 1 && "expected exactly one child predicate of Neg");
return (Twine("!(") + children.front() + Twine(')')).str();
}
// Recursively traverse the predicate tree in depth-first post-order and build
// the final expression.
static std::string getCombinedCondition(const PredNode &root) {
// Immediately return for non-combiner predicates that don't have children.
if (root.kind == PredCombinerKind::Leaf)
return root.expr;
if (root.kind == PredCombinerKind::True)
return "true";
if (root.kind == PredCombinerKind::False)
return "false";
// Recurse into children.
llvm::SmallVector<std::string, 4> childExpressions;
childExpressions.reserve(root.children.size());
for (const auto &child : root.children)
childExpressions.push_back(getCombinedCondition(*child));
// Combine the expressions based on the predicate node kind.
if (root.kind == PredCombinerKind::And)
return combineBinary(childExpressions, "&&", "true");
if (root.kind == PredCombinerKind::Or)
return combineBinary(childExpressions, "||", "false");
if (root.kind == PredCombinerKind::Not)
return combineNot(childExpressions);
// Substitutions were applied before so just ignore them.
if (root.kind == PredCombinerKind::SubstLeaves) {
assert(childExpressions.size() == 1 &&
"substitution predicate must have one child");
return childExpressions[0];
}
llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
}
std::string tblgen::CombinedPred::getConditionImpl() const {
llvm::BumpPtrAllocator allocator;
auto predicateTree = buildPredicateTree(*this, allocator, {});
predicateTree = propagateGroundTruth(
predicateTree,
/*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(),
/*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>());
return getCombinedCondition(*predicateTree);
}
StringRef tblgen::SubstLeavesPred::getPattern() const {
return def->getValueAsString("pattern");
}
StringRef tblgen::SubstLeavesPred::getReplacement() const {
return def->getValueAsString("replacement");
}

View File

@ -32,13 +32,14 @@ tblgen::TypeConstraint::TypeConstraint(const llvm::Record &record)
tblgen::Pred tblgen::TypeConstraint::getPredicate() const {
auto *val = def.getValue("predicate");
assert(val && "TableGen 'Type' class should have 'predicate' field");
assert(val &&
"TableGen 'TypeConstraint' class should have 'predicate' field");
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
llvm::StringRef tblgen::TypeConstraint::getConditionTemplate() const {
std::string tblgen::TypeConstraint::getConditionTemplate() const {
return getPredicate().getCondition();
}

View File

@ -180,7 +180,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
auto constraint = tblgen::TypeConstraint(*defInit);
os.indent(indent)
<< "if (!("
<< formatv(constraint.getConditionTemplate().str().c_str(),
<< formatv(constraint.getConditionTemplate().c_str(),
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
}
@ -193,7 +193,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
os.indent(indent)
<< "if (!("
<< formatv(pred.getCondition().str().c_str(),
<< formatv(pred.getCondition().c_str(),
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
namedAttr->attr.getStorageType(),
namedAttr->getName()))