forked from OSchip/llvm-project
[TableGen] Support nested dag attributes arguments in the result pattern
Add support to create a new attribute from multiple attributes. It extended the DagNode class to represent attribute creation dag. It also changed the RewriterGen::emitOpCreate method to support this nested dag emit. An unit test is added. PiperOrigin-RevId: 238090229
This commit is contained in:
parent
6558f80c8d
commit
c52a812700
|
@ -587,10 +587,20 @@ class mAttrAnyOf<list<AttrConstraint> attrs> :
|
|||
!listconcat(prev, [attr.predicate]))>>;
|
||||
|
||||
// Attribute transformation. This is the base class to specify a transformation
|
||||
// of a matched attribute. Used on the output attribute of a rewrite rule.
|
||||
// of matched attributes. Used on the output attribute of a rewrite rule.
|
||||
class tAttr<code transform> {
|
||||
// Code to transform the attribute.
|
||||
// Format: {0} represents the attribute.
|
||||
// Code to transform the attributes.
|
||||
// Format:
|
||||
// - When it is used as a dag node, {0} represents the builder, {i}
|
||||
// represents the (i-1)-th attribute argument when i >= 1. For example:
|
||||
// def attr: tAttr<"{0}.compose({{{1}, {2}})"> for '(attr $a, $b)' will
|
||||
// expand to '(builder.compose({foo, bar}))'.
|
||||
// - When it is used as a dag leaf, {0} represents the attribute.
|
||||
// For example:
|
||||
// def attr: tAttr<"{0}.cast<FloatAttr>()"> for 'attr:$a' will expand to
|
||||
// 'foo.cast<FloatAttr>()'.
|
||||
// In both examples, `foo` and `bar` are the C++ bounded attribute variables
|
||||
// of $a and $b.
|
||||
code attrTransform = transform;
|
||||
}
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ public:
|
|||
std::string getConditionTemplate() const;
|
||||
|
||||
// Returns the transformation template inside this DAG leaf. Assumes the
|
||||
// leaf is an attribute matcher and asserts otherwise.
|
||||
// leaf is an attribute transformation and asserts otherwise.
|
||||
std::string getTransformationTemplate() const;
|
||||
|
||||
private:
|
||||
|
@ -169,6 +169,13 @@ public:
|
|||
// constructor.
|
||||
bool isNativeCodeBuilder() const;
|
||||
|
||||
// Returns true if this DAG construct is transforming attributes.
|
||||
bool isAttrTransformer() const;
|
||||
|
||||
// Returns the transformation template inside this DAG construct.
|
||||
// Precondition: isAttrTransformer.
|
||||
std::string getTransformationTemplate() const;
|
||||
|
||||
private:
|
||||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
};
|
||||
|
|
|
@ -42,7 +42,7 @@ bool tblgen::DagLeaf::isOperandMatcher() const {
|
|||
bool tblgen::DagLeaf::isAttrMatcher() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
// Attribute matchers specify a type constraint.
|
||||
// Attribute matchers specify an attribute constraint.
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
|
||||
}
|
||||
|
||||
|
@ -90,6 +90,21 @@ std::string tblgen::DagLeaf::getTransformationTemplate() const {
|
|||
.str();
|
||||
}
|
||||
|
||||
bool tblgen::DagNode::isAttrTransformer() const {
|
||||
auto op = node->getOperator();
|
||||
if (!op || !isa<llvm::DefInit>(op))
|
||||
return false;
|
||||
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("tAttr");
|
||||
}
|
||||
|
||||
std::string tblgen::DagNode::getTransformationTemplate() const {
|
||||
assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
->getDef()
|
||||
->getValueAsString("attrTransform")
|
||||
.str();
|
||||
}
|
||||
|
||||
llvm::StringRef tblgen::DagNode::getOpName() const {
|
||||
return node->getNameStr();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// Create a Type and Attribute.
|
||||
def T : BuildableType<"buildT">;
|
||||
def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
|
||||
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
|
||||
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
|
||||
|
||||
// Define ops to rewrite.
|
||||
def U: Type<CPred<"true">, "U">;
|
||||
def X_AddOp : Op<"x.add"> {
|
||||
let arguments = (ins U, U);
|
||||
}
|
||||
def Y_AddOp : Op<"y.add"> {
|
||||
let arguments = (ins U, U, T_Attr:$attrName);
|
||||
}
|
||||
def Z_AddOp : Op<"z.add"> {
|
||||
let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2);
|
||||
}
|
||||
|
||||
// Define rewrite pattern.
|
||||
def : Pat<(Y_AddOp $lhs, $rhs, $attr1), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, T_Const_Attr:$attr2))>;
|
||||
// CHECK: struct GeneratedConvert0 : public RewritePattern
|
||||
// CHECK: RewritePattern("y.add", 1, context)
|
||||
// CHECK: PatternMatchResult match(Instruction *
|
||||
// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState>
|
||||
// CHECK-NEXT: PatternRewriter &rewriter)
|
||||
// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
|
||||
// CHECK-NEXT: s.lhs,
|
||||
// CHECK-NEXT: s.rhs,
|
||||
// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT, attrValue)})
|
||||
// CHECK-NEXT: );
|
||||
// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
|
||||
|
||||
def : Pat<(Z_AddOp $lhs, $rhs, $attr1, $attr2), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, $attr2))>;
|
||||
// CHECK: struct GeneratedConvert1 : public RewritePattern
|
||||
// CHECK: RewritePattern("z.add", 1, context)
|
||||
// CHECK: PatternMatchResult match(Instruction *
|
||||
// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState>
|
||||
// CHECK-NEXT: PatternRewriter &rewriter)
|
||||
// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
|
||||
// CHECK-NEXT: s.lhs,
|
||||
// CHECK-NEXT: s.rhs,
|
||||
// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2})
|
||||
// CHECK-NEXT: );
|
||||
// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
|
||||
|
||||
// CHECK: void populateWithGenerated
|
||||
// CHECK: patterns->push_back(std::make_unique<GeneratedConvert0>(context))
|
||||
// CHECK: patterns->push_back(std::make_unique<GeneratedConvert1>(context))
|
|
@ -40,6 +40,7 @@
|
|||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::DagLeaf;
|
||||
using mlir::tblgen::DagNode;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::Operator;
|
||||
|
@ -64,9 +65,6 @@ private:
|
|||
// Emits the rewrite() method.
|
||||
void emitRewriteMethod();
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitConstantAttr(tblgen::ConstantAttr constAttr);
|
||||
|
||||
// Emits C++ statements for matching the op constrained by the given DAG
|
||||
// `tree`.
|
||||
void emitOpMatch(DagNode tree, int depth);
|
||||
|
@ -106,6 +104,16 @@ private:
|
|||
// result value name.
|
||||
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
|
||||
|
||||
// Returns the string value of constant attribute as an argument.
|
||||
std::string handleConstantAttr(tblgen::ConstantAttr constAttr);
|
||||
|
||||
// Returns the C++ expression to build an argument from the given DAG `leaf`.
|
||||
// `patArgName` is used to bound the argument to the source pattern.
|
||||
std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
|
||||
|
||||
// Returns the C++ expression to build an argument from the given DAG `tree`.
|
||||
std::string handleOpArgument(DagNode tree);
|
||||
|
||||
private:
|
||||
// Pattern instantiation location followed by the location of multiclass
|
||||
// prototypes used. This is intended to be used as a whole to
|
||||
|
@ -126,7 +134,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
|||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
|
||||
os(os) {}
|
||||
|
||||
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
|
||||
std::string PatternEmitter::handleConstantAttr(tblgen::ConstantAttr constAttr) {
|
||||
auto attr = constAttr.getAttribute();
|
||||
|
||||
if (!attr.isConstBuildable())
|
||||
|
@ -134,8 +142,8 @@ void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
|
|||
" does not have the 'constBuilderCall' field");
|
||||
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
constAttr.getConstantValue());
|
||||
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
constAttr.getConstantValue());
|
||||
}
|
||||
|
||||
static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
|
||||
|
@ -313,9 +321,10 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
} else if (constraint.isNativeConstraint()) {
|
||||
os.indent(4) << "if (!" << constraint.getNativeConstraintFunction()
|
||||
<< "(";
|
||||
interleave(constraint.name_begin(), constraint.name_end(),
|
||||
[&](const std::string &name) { os << deduceName(name); },
|
||||
[&]() { os << ", "; });
|
||||
interleave(
|
||||
constraint.name_begin(), constraint.name_end(),
|
||||
[&](const std::string &name) { os << deduceName(name); },
|
||||
[&]() { os << ", "; });
|
||||
os << ")) return matchFailure();\n";
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
|
@ -439,6 +448,40 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
|
|||
<< ")->use_empty()) return matchFailure();\n";
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
||||
llvm::StringRef argName) {
|
||||
if (leaf.isConstantAttr()) {
|
||||
return handleConstantAttr(leaf.getAsConstantAttr());
|
||||
}
|
||||
pattern.ensureArgBoundInSourcePattern(argName);
|
||||
std::string result = boundArgNameInRewrite(argName).str();
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
return result;
|
||||
}
|
||||
if (leaf.isAttrTransformer()) {
|
||||
return formatv(leaf.getTransformationTemplate().c_str(), result);
|
||||
}
|
||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleOpArgument(DagNode tree) {
|
||||
if (!tree.isAttrTransformer()) {
|
||||
PrintFatalError(loc, "only tAttr is supported in nested dag attribute");
|
||||
}
|
||||
auto tempStr = tree.getTransformationTemplate();
|
||||
// TODO(fengliuai): replace formatv arguments with the exact specified args.
|
||||
SmallVector<std::string, 8> attrs(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
PrintFatalError(loc, "unsupported tAttr argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
}
|
||||
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
}
|
||||
return formatv(tempStr.c_str(), "rewriter", attrs[0], attrs[1], attrs[2],
|
||||
attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]);
|
||||
}
|
||||
|
||||
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
||||
int depth) {
|
||||
Operator &resultOp = tree.getDialectOp(opMap);
|
||||
|
@ -538,34 +581,25 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
|||
for (int e = tree.getNumArgs(); i != e; ++i) {
|
||||
// Start each attribute on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = tree.getArgName(i);
|
||||
// The argument in the op definition.
|
||||
auto opArgName = resultOp.getArgName(i);
|
||||
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
pattern.ensureArgBoundInSourcePattern(patArgName);
|
||||
std::string result = boundArgNameInRewrite(patArgName).str();
|
||||
os << formatv("/*{0}=*/{1}", opArgName, result);
|
||||
} else if (leaf.isAttrTransformer()) {
|
||||
pattern.ensureArgBoundInSourcePattern(patArgName);
|
||||
std::string result = boundArgNameInRewrite(patArgName).str();
|
||||
result = formatv(leaf.getTransformationTemplate().c_str(), result);
|
||||
os << formatv("/*{0}=*/{1}", opArgName, result);
|
||||
} else if (leaf.isConstantAttr()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
emitConstantAttr(leaf.getAsConstantAttr());
|
||||
// TODO(jpienaar): verify types
|
||||
if (auto subTree = tree.getArgAsNestedDag(i)) {
|
||||
os << formatv("/*{0}=*/{1}", opArgName, handleOpArgument(subTree));
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = tree.getArgName(i);
|
||||
if (leaf.isConstantAttr()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
} else {
|
||||
os << "/*" << opArgName << "=*/";
|
||||
}
|
||||
os << handleOpArgument(leaf, patArgName);
|
||||
}
|
||||
}
|
||||
os << "\n );\n";
|
||||
|
|
Loading…
Reference in New Issue