[TableGen] Use tblgen::DagLeaf to model DAG arguments

This CL added a tblgen::DagLeaf wrapper class with several helper methods for handling
DAG arguments. It helps to refactor the rewriter generation logic to be more higher
level.

This CL also added a tblgen::ConstantAttr wrapper class for constant attributes.

PiperOrigin-RevId: 232050683
This commit is contained in:
Lei Zhang 2019-02-01 15:40:22 -08:00 committed by jpienaar
parent 70e3873e86
commit e0774c008f
7 changed files with 292 additions and 144 deletions

View File

@ -111,6 +111,24 @@ public:
StringRef getDerivedCodeBody() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute
// defined in TableGen. This class should closely reflect what is defined as
// class `ConstantAttr` in TableGen.
class ConstantAttr {
public:
explicit ConstantAttr(const llvm::DefInit *init);
// Returns the attribute kind.
Attribute getAttribute() const;
// Returns the constant value.
StringRef getConstantValue() const;
private:
// The TableGen definition of this constant attribute.
const llvm::Record *def;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -23,6 +23,7 @@
#ifndef MLIR_TABLEGEN_PATTERN_H_
#define MLIR_TABLEGEN_PATTERN_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
@ -30,9 +31,9 @@
#include "llvm/TableGen/Error.h"
namespace llvm {
class Record;
class Init;
class DagInit;
class Init;
class Record;
class StringRef;
} // end namespace llvm
@ -42,20 +43,62 @@ namespace tblgen {
// Mapping from TableGen Record to Operator wrapper object
using RecordOperatorMap = llvm::DenseMap<const llvm::Record *, Operator>;
// Wrapper around DAG argument.
struct DagArg {
DagArg(Argument arg, llvm::Init *constraint)
: arg(arg), constraint(constraint) {}
// Returns true if this DAG argument concerns an operation attribute.
bool isAttr() const;
Argument arg;
llvm::Init *constraint;
};
class Pattern;
// Wrapper class providing helper methods for accessing TableGen DAG leaves
// used inside Patterns. This class is lightweight and designed to be used like
// values.
//
// A TableGen DAG construct is of the syntax
// `(operator, arg0, arg1, ...)`.
//
// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
// for handy helper methods. It only works on `arg*`s that are not nested DAG
// constructs.
class DagLeaf {
public:
explicit DagLeaf(const llvm::Init *def) : def(def) {}
// Returns true if this DAG leaf is not specified in the pattern. That is, it
// places no further constraints/transforms and just carries over the original
// value.
bool isUnspecified() const;
// Returns true if this DAG leaf is matching an operand. That is, it specifies
// a type constraint.
bool isOperandMatcher() const;
// Returns true if this DAG leaf is matching an attribute. That is, it
// specifies an attribute constraint.
bool isAttrMatcher() const;
// Returns true if this DAG leaf is transforming an attribute.
bool isAttrTransformer() const;
// Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const;
// Returns this DAG leaf as a type constraint. Asserts if fails.
TypeConstraint getAsTypeConstraint() const;
// Returns this DAG leaf as an attribute constraint. Asserts if fails.
AttrConstraint getAsAttrConstraint() const;
// Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const;
// Returns the matching condition template inside this DAG leaf. Assumes the
// leaf is an operand/attribute matcher and asserts otherwise.
std::string getConditionTemplate() const;
// Returns the transformation template inside this DAG leaf. Assumes the
// leaf is an attribute matcher and asserts otherwise.
std::string getTransformationTemplate() const;
private:
const llvm::Init *def;
};
// Wrapper class providing helper methods for accessing TableGen DAG constructs
// used inside Patterns. This class is lightweight and designed to be used like
// values.
@ -96,10 +139,9 @@ public:
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
// null DagNode otherwise.
DagNode getArgAsNestedDag(unsigned index) const;
// Gets the `index`-th argument as a TableGen DefInit* if possible. Returns
// nullptr otherwise.
// TODO: This method is exposing raw TableGen object and should be changed.
llvm::DefInit *getArgAsDefInit(unsigned index) const;
// Gets the `index`-th argument as a DAG leaf.
DagLeaf getArgAsLeaf(unsigned index) const;
// Returns the specified name of the `index`-th argument.
llvm::StringRef getArgName(unsigned index) const;
@ -146,7 +188,7 @@ public:
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
// Returns a reference to all the bound arguments in the source pattern.
llvm::StringMap<DagArg> &getSourcePatternBoundArgs();
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@ -159,8 +201,10 @@ private:
// The TableGen definition of this pattern.
const llvm::Record &def;
RecordOperatorMap *recordOpMap; // All operators
llvm::StringMap<DagArg> boundArguments; // All bound arguments
// All operators
RecordOperatorMap *recordOpMap;
// All bound arguments
llvm::StringMap<Argument> boundArguments;
};
} // end namespace tblgen

View File

@ -42,6 +42,7 @@ public:
explicit TypeConstraint(const llvm::DefInit &init);
bool operator==(const TypeConstraint &that) { return def == that.def; }
bool operator!=(const TypeConstraint &that) { return def != that.def; }
// Returns the predicate that can be used to check if a type satisfies this
// type constraint.

View File

@ -133,3 +133,17 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body");
}
tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init)
: def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");
}
tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
return Attribute(def->getValueAsDef("attr"));
}
StringRef tblgen::ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}

View File

@ -28,8 +28,66 @@ using namespace mlir;
using mlir::tblgen::Operator;
bool tblgen::DagArg::isAttr() const {
return arg.is<tblgen::NamedAttribute *>();
bool tblgen::DagLeaf::isUnspecified() const {
return !def || isa<llvm::UnsetInit>(def);
}
bool tblgen::DagLeaf::isOperandMatcher() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
// Operand matchers specify a type constraint.
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
}
bool tblgen::DagLeaf::isAttrMatcher() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
// Attribute matchers specify a type constraint.
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
}
bool tblgen::DagLeaf::isAttrTransformer() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
}
bool tblgen::DagLeaf::isConstantAttr() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
}
tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const {
assert(isOperandMatcher() && "the DAG leaf must be operand");
return TypeConstraint(*cast<llvm::DefInit>(def)->getDef());
}
tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const {
assert(isAttrMatcher() && "the DAG leaf must be attribute");
return AttrConstraint(cast<llvm::DefInit>(def)->getDef());
}
tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
return ConstantAttr(cast<llvm::DefInit>(def));
}
std::string tblgen::DagLeaf::getConditionTemplate() const {
assert((isOperandMatcher() || isAttrMatcher()) &&
"the DAG leaf must be operand/attribute matcher");
if (isOperandMatcher()) {
return getAsTypeConstraint().getConditionTemplate();
}
return getAsAttrConstraint().getConditionTemplate();
}
std::string tblgen::DagLeaf::getTransformationTemplate() const {
assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
return cast<llvm::DefInit>(def)
->getDef()
->getValueAsString("attrTransform")
.str();
}
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
@ -56,8 +114,9 @@ tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
}
llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const {
return dyn_cast<llvm::DefInit>(node->getArg(index));
tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
assert(!isNestedDagArg(index));
return DagLeaf(node->getArg(index));
}
StringRef tblgen::DagNode::getArgName(unsigned index) const {
@ -81,7 +140,7 @@ static void collectBoundArguments(const llvm::DagInit *tree,
if (name.empty())
continue;
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg);
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i));
}
}
@ -131,7 +190,8 @@ void tblgen::Pattern::ensureArgBoundInSourcePattern(
Twine("referencing unbound variable '") + name + "'");
}
llvm::StringMap<tblgen::DagArg> &tblgen::Pattern::getSourcePatternBoundArgs() {
llvm::StringMap<tblgen::Argument> &
tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
}

View File

@ -3,24 +3,21 @@
include "mlir/IR/op_base.td"
// Create a Type and Attribute.
def YT : BuildableType<"buildYT">;
def Y_Attr : TypeBasedAttr<YT, "Attribute", "attribute of Y type">;
def Y_Const_Attr {
Attr attr = Y_Attr;
string value = "attrValue";
}
def T : BuildableType<"buildT">;
def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
// Define ops to rewrite.
def T1: Type<CPred<"true">, "T1">;
def U: Type<CPred<"true">, "U">;
def X_AddOp : Op<"x.add"> {
let arguments = (ins T1, T1);
let arguments = (ins U, U);
}
def Y_AddOp : Op<"y.add"> {
let arguments = (ins T1, T1, Y_Attr:$attrName);
let arguments = (ins U, U, T_Attr:$attrName);
}
// Define rewrite pattern.
def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>;
def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
// CHECK: struct GeneratedConvert0 : public RewritePattern
// CHECK: RewritePattern("x.add", 1, context)

View File

@ -39,31 +39,11 @@
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Argument;
using mlir::tblgen::Attribute;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::Operand;
using mlir::tblgen::Operator;
using mlir::tblgen::Pattern;
using mlir::tblgen::RecordOperatorMap;
using mlir::tblgen::Type;
namespace {
// Wrapper around DAG argument.
struct DagArg {
DagArg(Argument arg, Init *constraintInit)
: arg(arg), constraintInit(constraintInit) {}
bool isAttr();
Argument arg;
Init *constraintInit;
};
} // end namespace
bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
namespace {
class PatternEmitter {
@ -93,12 +73,19 @@ private:
void emitReplaceWithNativeBuilder(DagNode resultTree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(Record *constAttr);
void emitConstantAttr(tblgen::ConstantAttr constAttr);
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`.
void emitOpMatch(DagNode tree, int depth);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an operand.
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@ -107,14 +94,13 @@ private:
// Op's TableGen Record to wrapper object
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
Pattern pattern;
tblgen::Pattern pattern;
raw_ostream &os;
};
} // end namespace
void PatternEmitter::emitAttributeValue(Record *constAttr) {
Attribute attr(constAttr->getValueAsDef("attr"));
auto value = constAttr->getValue("value");
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
auto attr = constAttr.getAttribute();
if (!attr.isConstBuildable())
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
@ -122,7 +108,7 @@ void PatternEmitter::emitAttributeValue(Record *constAttr) {
// TODO(jpienaar): Verify the constants here
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
value->getValue()->getAsUnquotedString());
constAttr.getConstantValue());
}
// Helper function to match patterns.
@ -137,13 +123,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
op.getQualCppClassName());
}
if (tree.getNumArgs() != op.getNumArgs())
PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
op.getOperationName() +
"' in pattern and op's definition");
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
"pattern vs. {2} in definition",
op.getOperationName(), tree.getNumArgs(),
op.getNumArgs()));
}
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
@ -154,50 +144,78 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
continue;
}
// Verify arguments.
if (auto defInit = tree.getArgAsDefInit(i)) {
// Verify operands.
// Next handle DAG leaf: operand or attribute
if (auto *operand = opArg.dyn_cast<Operand *>()) {
// Skip verification where not needed due to definition of op.
if (operand->type == Type(defInit))
goto StateCapture;
emitOperandMatch(tree, i, depth, indent);
} else if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
emitAttributeMatch(tree, i, depth, indent);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
}
}
if (!defInit->getDef()->isSubClassOf("Type"))
PrintFatalError(loc, "type argument required for operand");
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(index).get<Operand *>();
auto matcher = tree.getArgAsLeaf(index);
auto constraint = tblgen::TypeConstraint(*defInit);
os.indent(indent)
<< "if (!("
<< formatv(constraint.getConditionTemplate().c_str(),
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
if (!matcher.isUnspecified()) {
if (!matcher.isOperandMatcher()) {
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
op.getOperationName(), index + 1));
}
// TODO(jpienaar): Verify attributes.
if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
auto constraint = tblgen::AttrConstraint(defInit);
std::string condition = formatv(
constraint.getConditionTemplate().c_str(),
// Only need to verify if the matcher's type is different from the one
// of op definition.
if (static_cast<tblgen::TypeConstraint>(operand->type) !=
matcher.getAsTypeConstraint()) {
os.indent(indent) << "if (!("
<< formatv(matcher.getConditionTemplate().c_str(),
formatv("op{0}->getOperand({1})->getType()",
depth, index))
<< ")) return matchFailure();\n";
}
}
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << index << ");\n";
}
}
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
auto matcher = tree.getArgAsLeaf(index);
if (!matcher.isUnspecified() && !matcher.isAttrMatcher()) {
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
op.getOperationName(), index + 1));
}
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
std::string condition =
formatv(matcher.getConditionTemplate().c_str(),
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
namedAttr->attr.getStorageType(), namedAttr->getName()));
os.indent(indent) << "if (!(" << condition
<< ")) return matchFailure();\n";
}
}
os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n";
StateCapture:
auto name = tree.getArgName(i);
if (name.empty())
continue;
if (opArg.is<Operand *>())
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << i << ");\n";
if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getAttrOfType<"
<< namedAttr->attr.getStorageType() << ">(\""
<< namedAttr->getName() << "\");\n";
}
<< "->getAttrOfType<" << namedAttr->attr.getStorageType()
<< ">(\"" << namedAttr->getName() << "\");\n";
}
}
@ -234,11 +252,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
auto fieldName = arg.first();
if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
<< ";\n";
} else {
os.indent(4) << "Value* " << arg.first() << ";\n";
os.indent(4) << "Value* " << fieldName << ";\n";
}
}
os << " };\n";
@ -285,10 +304,10 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.getCppClassName());
if (numOpArgs != resultTree.getNumArgs()) {
PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
Twine(numOpArgs) +
") and arguments provided for rewrite (" +
Twine(resultTree.getNumArgs()) + Twine(')'));
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
resultOp.getOperationName(),
resultTree.getNumArgs(), numOpArgs));
}
// Create the builder call for the result.
@ -312,38 +331,33 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
auto leaf = resultTree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto argName = resultTree.getArgName(i);
auto opName = resultOp.getArgName(i);
auto *defInit = resultTree.getArgAsDefInit(i);
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value) {
pattern.ensureArgBoundInSourcePattern(argName);
auto result = "s." + argName;
os << "/*" << opName << "=*/";
if (defInit) {
auto transform = defInit->getDef();
if (transform->isSubClassOf("tAttr")) {
// TODO(jpienaar): move to helper class.
os << formatv(
transform->getValueAsString("attrTransform").str().c_str(),
result);
continue;
}
}
os << result;
continue;
}
auto patArgName = resultTree.getArgName(i);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(i);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
pattern.ensureArgBoundInSourcePattern(patArgName);
os << formatv("/*{0}=*/s.{1}", opArgName, patArgName);
} else if (leaf.isAttrTransformer()) {
pattern.ensureArgBoundInSourcePattern(patArgName);
std::string result = std::string("s.") + patArgName.str();
result = formatv(leaf.getTransformationTemplate().c_str(), result);
os << formatv("/*{0}=*/{1}", opArgName, result);
} else if (leaf.isConstantAttr()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
if (!argName.empty())
os << "/*" << argName << "=*/";
emitAttributeValue(defInit->getDef());
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
emitConstantAttr(leaf.getAsConstantAttr());
// TODO(jpienaar): verify types
} else {
PrintFatalError(loc, "unhandled case when rewriting op");
}
}
os << "\n );\n";
}
@ -367,7 +381,7 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
auto name = resultTree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
const auto &val = boundedValues.find(name);
if (val->second.isAttr() && !printingAttr) {
if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
os << "}, {";
first = true;
printingAttr = true;