diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 9c005059753e..8e9fb57bbe76 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -984,42 +984,25 @@ class Pat preds = [], dag benefitAdded = (addBenefit 0)> : Pattern; -// Attribute transformation. This is the base class to specify a transformation -// of matched attributes. Used on the output attribute of a rewrite rule. +// Native code call wrapper. This allows invoking an arbitrary C++ expression +// to create an op operand/attribute or replace an op result. // // ## Placeholders // -// The following special placeholders are supported +// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`, +// the wrapped expression can take special placeholders listed below: // // * `$_builder` will be replaced by the current `mlir::PatternRewriter`. // * `$_self` will be replaced with the entity this transformer is attached to. // E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in // `transform:$attr` will be replaced by the value for `$att`. - -// Besides, if this is used as a DAG node, i.e., `(tAttr , ..., )`, -// then positional placeholders are supported and placholder `$N` will be -// replaced by ``. -class tAttr { - code attrTransform = transform; -} - -// Native code op creation method. This allows performing an arbitrary op -// creation/replacement by invoking a C++ function with the operands and -// attributes. The function specified needs to have the signature: // -// void f(Operation *op, ArrayRef operands, -// ArrayRef attrs, PatternRewriter &rewriter); -// -// The operands and attributes are passed to this function in the order of -// the DAG specified. It is the responsibility of this function to replace the -// matched op(s) using the rewriter. This is intended for the long tail op -// creation and replacement. -// TODO(antiagainst): Unify this and tAttr into a single creation mechanism. -class cOp { - // Function to invoke with the given arguments to construct a new op. The - // operands will be passed to the function first followed by the attributes - // (as in the function signature above and required by Op arguments). - string function = f; +// If used as a DAG node, i.e., `(NativeCodeCall<"..."> , ..., )`, +// then positional placeholders are also supported; placeholder `$N` in the +// wrapped C++ expression will be replaced by ``. + +class NativeCodeCall { + string expression = expr; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 22bc7b303fa2..e833e49c73cc 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -77,8 +77,8 @@ public: // specifies an attribute constraint. bool isAttrMatcher() const; - // Returns true if this DAG leaf is transforming an attribute. - bool isAttrTransformer() const; + // Returns true if this DAG leaf is wrapping native code call. + bool isNativeCodeCall() const; // Returns true if this DAG leaf is specifying a constant attribute. bool isConstantAttr() const; @@ -100,9 +100,9 @@ public: // leaf is an operand/attribute matcher and asserts otherwise. std::string getConditionTemplate() const; - // Returns the transformation template inside this DAG leaf. Assumes the - // leaf is an attribute transformation and asserts otherwise. - std::string getTransformationTemplate() const; + // Returns the native code call template inside this DAG leaf. + // Precondition: isNativeCodeCall() + llvm::StringRef getNativeCodeTemplate() const; private: // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and @@ -162,10 +162,6 @@ public: // Returns the specified name of the `index`-th argument. llvm::StringRef getArgName(unsigned index) const; - // Returns the native builder for the pattern. - // Precondition: isNativeCodeBuilder. - llvm::StringRef getNativeCodeBuilder() const; - // Returns true if this DAG construct means to replace with an existing SSA // value. bool isReplaceWithValue() const; @@ -173,16 +169,12 @@ public: // Returns true if this DAG node is the `verifyUnusedValue` directive. bool isVerifyUnusedValue() const; - // Returns true if this DAG construct is meant to invoke a native code - // constructor. - bool isNativeCodeBuilder() const; + // Returns true if this DAG node is wrapping native code call. + bool isNativeCodeCall() const; - // Returns true if this DAG construct is transforming attributes. - bool isAttrTransformer() const; - - // Returns the transformation template inside this DAG construct. - // Precondition: isAttrTransformer. - std::string getTransformationTemplate() const; + // Returns the native code call template inside this DAG node. + // Precondition: isNativeCodeCall() + llvm::StringRef getNativeCodeTemplate() const; private: const llvm::DagInit *node; // nullptr means null DagNode diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 92267d140357..420f9d2090aa 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -44,8 +44,8 @@ bool tblgen::DagLeaf::isAttrMatcher() const { return isSubClassOf("AttrConstraint"); } -bool tblgen::DagLeaf::isAttrTransformer() const { - return isSubClassOf("tAttr"); +bool tblgen::DagLeaf::isNativeCodeCall() const { + return isSubClassOf("NativeCodeCall"); } bool tblgen::DagLeaf::isConstantAttr() const { @@ -76,12 +76,9 @@ std::string tblgen::DagLeaf::getConditionTemplate() const { return getAsConstraint().getConditionTemplate(); } -std::string tblgen::DagLeaf::getTransformationTemplate() const { - assert(isAttrTransformer() && "the DAG leaf must be attribute transformer"); - return cast(def) - ->getDef() - ->getValueAsString("attrTransform") - .str(); +llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { + assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); + return cast(def)->getDef()->getValueAsString("expression"); } bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { @@ -90,19 +87,17 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { return false; } -bool tblgen::DagNode::isAttrTransformer() const { - auto op = node->getOperator(); - if (!op || !isa(op)) - return false; - return cast(op)->getDef()->isSubClassOf("tAttr"); +bool tblgen::DagNode::isNativeCodeCall() const { + if (auto *defInit = dyn_cast_or_null(node->getOperator())) + return defInit->getDef()->isSubClassOf("NativeCodeCall"); + return false; } -std::string tblgen::DagNode::getTransformationTemplate() const { - assert(isAttrTransformer() && "the DAG leaf must be attribute transformer"); +llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { + assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(node->getOperator()) ->getDef() - ->getValueAsString("attrTransform") - .str(); + ->getValueAsString("expression"); } llvm::StringRef tblgen::DagNode::getOpName() const { @@ -156,17 +151,6 @@ bool tblgen::DagNode::isVerifyUnusedValue() const { return dagOpDef->getName() == "verifyUnusedValue"; } -bool tblgen::DagNode::isNativeCodeBuilder() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); - return dagOpDef->isSubClassOf("cOp"); -} - -llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const { - assert(isNativeCodeBuilder()); - auto *dagOpDef = cast(node->getOperator())->getDef(); - return dagOpDef->getValueAsString("function"); -} - tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) : def(*def), recordOpMap(mapper) { collectBoundArguments(getSourcePattern()); diff --git a/mlir/test/mlir-tblgen/pattern-NativeCodeCall.td b/mlir/test/mlir-tblgen/pattern-NativeCodeCall.td new file mode 100644 index 000000000000..317284de48e4 --- /dev/null +++ b/mlir/test/mlir-tblgen/pattern-NativeCodeCall.td @@ -0,0 +1,35 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def CreateOperand : NativeCodeCall<"buildOperand($0, $1)">; +def CreateArrayAttr : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">; +def CreateOpResult : NativeCodeCall<"buildOp($0, $1)">; + +def NS_AOp : Op<"a_op", []> { + let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr); + let results = (outs I32:$output); +} + +def NS_BOp : Op<"b_op", []> { + let arguments = (ins I32:$input, I32Attr:$attr); + let results = (outs I32:$output); +} + +def TestCreateOpResult : Pat< + (NS_BOp $input, $attr), + (CreateOpResult $input, $attr)>; + +// CHECK-LABEL: TestCreateOpResult + +// CHECK: rewriter.replaceOp(op, {buildOp(s.input, s.attr)}); + +def TestCreateOperandAndAttr : Pat< + (NS_AOp $input1, $input2, $attr), + (NS_BOp (CreateOperand $input1, $input2), (CreateArrayAttr $attr, $attr))>; + +// CHECK-LABEL: TestCreateOperandAndAttr + +// CHECK: rewriter.create +// CHECK-NEXT: buildOperand(s.input1, s.input2), +// CHECK-NEXT: rewriter.getArrayAttr({s.attr, s.attr}) diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td index 0d1a53eecc1b..805a0321079c 100644 --- a/mlir/test/mlir-tblgen/pattern-bound-symbol.td +++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td @@ -23,10 +23,11 @@ def OpD : Op<"op_d", []> { } def hasOneUse: ConstrainthasOneUse()">, "has one use">; +def getResult0 : NativeCodeCall<"$_self->getResult(0)">; def : Pattern<(OpA:$res_a $operand, $attr), [(OpC:$res_c (OpB:$res_b $operand)), - (OpD $res_b, $res_c, $res_a, $attr)], + (OpD $res_b, $res_c, getResult0:$res_a, $attr)], [(hasOneUse $res_a)]>; // CHECK-LABEL: GeneratedConvert0 @@ -59,5 +60,5 @@ def : Pattern<(OpA:$res_a $operand, $attr), // CHECK: auto vOpD0 = rewriter.create( // CHECK: /*input1=*/res_b, // CHECK: /*input2=*/res_c, -// CHECK: /*input3=*/s.res_a, +// CHECK: /*input3=*/s.res_a->getResult(0), // CHECK: /*attr=*/s.attr diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td deleted file mode 100644 index 02a12567617f..000000000000 --- a/mlir/test/mlir-tblgen/pattern-tAttr.td +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s - -include "mlir/IR/OpBase.td" - -// Create a Type and Attribute. -def T : BuildableType<"buildT()">; -def T_Attr : TypedAttrBase, "attribute of T type">; -def T_Const_Attr : ConstantAttr; -def T_Compose_Attr : tAttr<"$_builder.getArrayAttr({$0, $1})">; - -// Define ops to rewrite. -def Y_Op : Op<"y.op"> { - let arguments = (ins T_Attr:$attrName); - let results = (outs I32:$result); -} -def Z_Op : Op<"z.op"> { - let arguments = (ins T_Attr:$attrName1, T_Attr:$attrName2); - let results = (outs I32:$result); -} - -// Define rewrite pattern. -def : Pat<(Y_Op $attr1), (Y_Op (T_Compose_Attr $attr1, T_Const_Attr))>; -// CHECK-LABEL: struct GeneratedConvert0 -// CHECK: void rewrite( -// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT(), attrValue)}) - -def : Pat<(Z_Op $attr1, $attr2), (Y_Op (T_Compose_Attr $attr1, $attr2))>; -// CHECK-LABEL: struct GeneratedConvert1 -// CHECK: void rewrite( -// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2}) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 797d3367efa0..bbc961f635a4 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -166,9 +166,9 @@ private: std::string handleRewritePattern(DagNode resultTree, int resultIndex, int depth); - // Emits the C++ statement to replace the matched DAG with a native C++ built - // value. - std::string emitReplaceWithNativeBuilder(DagNode resultTree); + // Emits the C++ statement to replace the matched DAG with a value built via + // calling native C++ code. + std::string emitReplaceWithNativeCodeCall(DagNode resultTree); // Returns the C++ expression referencing the old value serving as the // replacement. @@ -193,9 +193,6 @@ private: // `patArgName` is used to bound the argument to the source pattern. std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName); - // Returns the C++ expression to build an argument from the given DAG `tree`. - std::string handleOpArgument(DagNode tree); - // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol // is already bound. void addSymbol(DagNode node); @@ -515,8 +512,8 @@ std::string PatternEmitter::getUniqueValueName(const Operator *op) { std::string PatternEmitter::handleRewritePattern(DagNode resultTree, int resultIndex, int depth) { - if (resultTree.isNativeCodeBuilder()) - return emitReplaceWithNativeBuilder(resultTree); + if (resultTree.isNativeCodeCall()) + return emitReplaceWithNativeCodeCall(resultTree); if (resultTree.isVerifyUnusedValue()) { if (depth > 0) { @@ -584,22 +581,18 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, if (leaf.isUnspecified() || leaf.isOperandMatcher()) { return result; } - if (leaf.isAttrTransformer()) { - return tgfmt(leaf.getTransformationTemplate(), - &rewriteCtx.withSelf(result)); + if (leaf.isNativeCodeCall()) { + return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result)); } PrintFatalError(loc, "unhandled case when rewriting op"); } -std::string PatternEmitter::handleOpArgument(DagNode tree) { - if (!tree.isAttrTransformer()) { - PrintFatalError(loc, "only tAttr is supported in nested dag attribute"); - } - auto fmt = tree.getTransformationTemplate(); +std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { + auto fmt = tree.getNativeCodeTemplate(); // TODO(fengliuai): replace formatv arguments with the exact specified args. SmallVector attrs(8); if (tree.getNumArgs() > 8) { - PrintFatalError(loc, "unsupported tAttr argument numbers: " + + PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + Twine(tree.getNumArgs())); } for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) { @@ -692,7 +685,9 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, // Create the builder call for the result. // Add operands. int i = 0; - for (auto operand : resultOp.getOperands()) { + for (int e = resultOp.getNumOperands(); i < e; ++i) { + const auto &operand = resultOp.getOperand(i); + // Start each operand on its own line. (os << ",\n").indent(6); @@ -702,11 +697,15 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, if (tree.isNestedDagArg(i)) { os << childNodeNames[i]; } else { - os << resolveSymbol(tree.getArgName(i)); + DagLeaf leaf = tree.getArgAsLeaf(i); + auto symbol = resolveSymbol(tree.getArgName(i)); + if (leaf.isNativeCodeCall()) { + os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol)); + } else { + os << symbol; + } } - // TODO(jpienaar): verify types - ++i; } // Add attributes. @@ -716,7 +715,11 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, // The argument in the op definition. auto opArgName = resultOp.getArgName(i); if (auto subTree = tree.getArgAsNestedDag(i)) { - os << formatv("/*{0}=*/{1}", opArgName, handleOpArgument(subTree)); + if (!subTree.isNativeCodeCall()) + PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " + "for creating attribute"); + os << formatv("/*{0}=*/{1}", opArgName, + emitReplaceWithNativeCodeCall(subTree)); } else { auto leaf = tree.getArgAsLeaf(i); // The argument in the result DAG pattern. @@ -739,36 +742,6 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, return resultValue; } -std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { - // The variable's name for holding the result of this native builder call - std::string value = formatv("v{0}", nextValueId++).str(); - - os.indent(4) << "auto " << value << " = " << resultTree.getNativeCodeBuilder() - << "(op, {"; - const auto &boundedValues = pattern.getSourcePatternBoundArgs(); - bool first = true; - bool printingAttr = false; - for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) { - auto name = resultTree.getArgName(i); - pattern.ensureBoundInSourcePattern(name); - const auto &val = boundedValues.find(name); - if (val->second.dyn_cast() && !printingAttr) { - os << "}, {"; - first = true; - printingAttr = true; - } - if (!first) - os << ","; - os << getBoundSymbol(name); - first = false; - } - if (!printingAttr) - os << "},{"; - os << "}, rewriter);\n"; - - return value; -} - void PatternEmitter::emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, raw_ostream &os) { PatternEmitter(p, mapper, os).emit(rewriteName);