[TableGen] Support nested dag attributes arguments in the result pattern

Add support to create a new attribute from multiple attributes. It extended the
DagNode class to represent attribute creation dag. It also changed the
RewriterGen::emitOpCreate method to support this nested dag emit.

An unit test is added.

PiperOrigin-RevId: 238090229
This commit is contained in:
Feng Liu 2019-03-12 13:55:50 -07:00 committed by jpienaar
parent 6558f80c8d
commit c52a812700
5 changed files with 157 additions and 39 deletions

View File

@ -587,10 +587,20 @@ class mAttrAnyOf<list<AttrConstraint> attrs> :
!listconcat(prev, [attr.predicate]))>>;
// Attribute transformation. This is the base class to specify a transformation
// of a matched attribute. Used on the output attribute of a rewrite rule.
// of matched attributes. Used on the output attribute of a rewrite rule.
class tAttr<code transform> {
// Code to transform the attribute.
// Format: {0} represents the attribute.
// Code to transform the attributes.
// Format:
// - When it is used as a dag node, {0} represents the builder, {i}
// represents the (i-1)-th attribute argument when i >= 1. For example:
// def attr: tAttr<"{0}.compose({{{1}, {2}})"> for '(attr $a, $b)' will
// expand to '(builder.compose({foo, bar}))'.
// - When it is used as a dag leaf, {0} represents the attribute.
// For example:
// def attr: tAttr<"{0}.cast<FloatAttr>()"> for 'attr:$a' will expand to
// 'foo.cast<FloatAttr>()'.
// In both examples, `foo` and `bar` are the C++ bounded attribute variables
// of $a and $b.
code attrTransform = transform;
}

View File

@ -93,7 +93,7 @@ public:
std::string getConditionTemplate() const;
// Returns the transformation template inside this DAG leaf. Assumes the
// leaf is an attribute matcher and asserts otherwise.
// leaf is an attribute transformation and asserts otherwise.
std::string getTransformationTemplate() const;
private:
@ -169,6 +169,13 @@ public:
// constructor.
bool isNativeCodeBuilder() const;
// Returns true if this DAG construct is transforming attributes.
bool isAttrTransformer() const;
// Returns the transformation template inside this DAG construct.
// Precondition: isAttrTransformer.
std::string getTransformationTemplate() const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
};

View File

@ -42,7 +42,7 @@ bool tblgen::DagLeaf::isOperandMatcher() const {
bool tblgen::DagLeaf::isAttrMatcher() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
// Attribute matchers specify a type constraint.
// Attribute matchers specify an attribute constraint.
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
}
@ -90,6 +90,21 @@ std::string tblgen::DagLeaf::getTransformationTemplate() const {
.str();
}
bool tblgen::DagNode::isAttrTransformer() const {
auto op = node->getOperator();
if (!op || !isa<llvm::DefInit>(op))
return false;
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("tAttr");
}
std::string tblgen::DagNode::getTransformationTemplate() const {
assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
return cast<llvm::DefInit>(node->getOperator())
->getDef()
->getValueAsString("attrTransform")
.str();
}
llvm::StringRef tblgen::DagNode::getOpName() const {
return node->getNameStr();
}

View File

@ -0,0 +1,52 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
// Create a Type and Attribute.
def T : BuildableType<"buildT">;
def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
// Define ops to rewrite.
def U: Type<CPred<"true">, "U">;
def X_AddOp : Op<"x.add"> {
let arguments = (ins U, U);
}
def Y_AddOp : Op<"y.add"> {
let arguments = (ins U, U, T_Attr:$attrName);
}
def Z_AddOp : Op<"z.add"> {
let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2);
}
// Define rewrite pattern.
def : Pat<(Y_AddOp $lhs, $rhs, $attr1), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, T_Const_Attr:$attr2))>;
// CHECK: struct GeneratedConvert0 : public RewritePattern
// CHECK: RewritePattern("y.add", 1, context)
// CHECK: PatternMatchResult match(Instruction *
// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState>
// CHECK-NEXT: PatternRewriter &rewriter)
// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
// CHECK-NEXT: s.lhs,
// CHECK-NEXT: s.rhs,
// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT, attrValue)})
// CHECK-NEXT: );
// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
def : Pat<(Z_AddOp $lhs, $rhs, $attr1, $attr2), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, $attr2))>;
// CHECK: struct GeneratedConvert1 : public RewritePattern
// CHECK: RewritePattern("z.add", 1, context)
// CHECK: PatternMatchResult match(Instruction *
// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState>
// CHECK-NEXT: PatternRewriter &rewriter)
// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
// CHECK-NEXT: s.lhs,
// CHECK-NEXT: s.rhs,
// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2})
// CHECK-NEXT: );
// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
// CHECK: void populateWithGenerated
// CHECK: patterns->push_back(std::make_unique<GeneratedConvert0>(context))
// CHECK: patterns->push_back(std::make_unique<GeneratedConvert1>(context))

View File

@ -40,6 +40,7 @@
using namespace llvm;
using namespace mlir;
using mlir::tblgen::DagLeaf;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::Operator;
@ -64,9 +65,6 @@ private:
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits the value of constant attribute to `os`.
void emitConstantAttr(tblgen::ConstantAttr constAttr);
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`.
void emitOpMatch(DagNode tree, int depth);
@ -106,6 +104,16 @@ private:
// result value name.
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
// Returns the string value of constant attribute as an argument.
std::string handleConstantAttr(tblgen::ConstantAttr constAttr);
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern.
std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
// Returns the C++ expression to build an argument from the given DAG `tree`.
std::string handleOpArgument(DagNode tree);
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@ -126,7 +134,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
os(os) {}
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
std::string PatternEmitter::handleConstantAttr(tblgen::ConstantAttr constAttr) {
auto attr = constAttr.getAttribute();
if (!attr.isConstBuildable())
@ -134,8 +142,8 @@ void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
constAttr.getConstantValue());
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
constAttr.getConstantValue());
}
static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
@ -313,9 +321,10 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
} else if (constraint.isNativeConstraint()) {
os.indent(4) << "if (!" << constraint.getNativeConstraintFunction()
<< "(";
interleave(constraint.name_begin(), constraint.name_end(),
[&](const std::string &name) { os << deduceName(name); },
[&]() { os << ", "; });
interleave(
constraint.name_begin(), constraint.name_end(),
[&](const std::string &name) { os << deduceName(name); },
[&]() { os << ", "; });
os << ")) return matchFailure();\n";
} else {
llvm_unreachable(
@ -439,6 +448,40 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
<< ")->use_empty()) return matchFailure();\n";
}
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
llvm::StringRef argName) {
if (leaf.isConstantAttr()) {
return handleConstantAttr(leaf.getAsConstantAttr());
}
pattern.ensureArgBoundInSourcePattern(argName);
std::string result = boundArgNameInRewrite(argName).str();
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return result;
}
if (leaf.isAttrTransformer()) {
return formatv(leaf.getTransformationTemplate().c_str(), result);
}
PrintFatalError(loc, "unhandled case when rewriting op");
}
std::string PatternEmitter::handleOpArgument(DagNode tree) {
if (!tree.isAttrTransformer()) {
PrintFatalError(loc, "only tAttr is supported in nested dag attribute");
}
auto tempStr = tree.getTransformationTemplate();
// TODO(fengliuai): replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) {
PrintFatalError(loc, "unsupported tAttr argument numbers: " +
Twine(tree.getNumArgs()));
}
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
}
return formatv(tempStr.c_str(), "rewriter", attrs[0], attrs[1], attrs[2],
attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]);
}
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth) {
Operator &resultOp = tree.getDialectOp(opMap);
@ -538,34 +581,25 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
for (int e = tree.getNumArgs(); i != e; ++i) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(i);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
pattern.ensureArgBoundInSourcePattern(patArgName);
std::string result = boundArgNameInRewrite(patArgName).str();
os << formatv("/*{0}=*/{1}", opArgName, result);
} else if (leaf.isAttrTransformer()) {
pattern.ensureArgBoundInSourcePattern(patArgName);
std::string result = boundArgNameInRewrite(patArgName).str();
result = formatv(leaf.getTransformationTemplate().c_str(), result);
os << formatv("/*{0}=*/{1}", opArgName, result);
} else if (leaf.isConstantAttr()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
emitConstantAttr(leaf.getAsConstantAttr());
// TODO(jpienaar): verify types
if (auto subTree = tree.getArgAsNestedDag(i)) {
os << formatv("/*{0}=*/{1}", opArgName, handleOpArgument(subTree));
} else {
PrintFatalError(loc, "unhandled case when rewriting op");
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
if (leaf.isConstantAttr()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
} else {
os << "/*" << opArgName << "=*/";
}
os << handleOpArgument(leaf, patArgName);
}
}
os << "\n );\n";