From 2bf423b0218c9583e3a372950a34facbf93e63d3 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 9 Oct 2020 13:32:01 -0700 Subject: [PATCH] [mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher Added an underlying matcher for generic constant ops. This included a rewriter of RewriterGen to make variable use more clear. Differential Revision: https://reviews.llvm.org/D89161 --- mlir/include/mlir/IR/OpBase.td | 2 + mlir/include/mlir/TableGen/Pattern.h | 10 + mlir/lib/TableGen/Pattern.cpp | 177 ++++++++++----- mlir/test/lib/Dialect/Test/TestDialect.cpp | 4 + mlir/test/lib/Dialect/Test/TestOps.td | 16 ++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 + mlir/test/mlir-tblgen/pattern.mlir | 52 +++++ mlir/test/mlir-tblgen/rewriter-errors.td | 29 +++ mlir/tools/mlir-tblgen/RewriterGen.cpp | 239 ++++++++++++++------ 9 files changed, 410 insertions(+), 120 deletions(-) create mode 100644 mlir/test/mlir-tblgen/rewriter-errors.td diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 82dc6a456f29..72b3b1ab41f5 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2351,6 +2351,8 @@ class NativeCodeCall { string expression = expr; } +def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">; + //===----------------------------------------------------------------------===// // Rewrite directives //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 4fc2ae762a66..98c5d9b18f5d 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -252,6 +252,9 @@ public: static SymbolInfo getAttr(const Operator *op, int index) { return SymbolInfo(op, Kind::Attr, index); } + static SymbolInfo getAttr() { + return SymbolInfo(nullptr, Kind::Attr, llvm::None); + } static SymbolInfo getOperand(const Operator *op, int index) { return SymbolInfo(op, Kind::Operand, index); } @@ -319,6 +322,10 @@ public: // is already bound. bool bindValue(StringRef symbol); + // Registers the given `symbol` as bound to an attr. Returns false if `symbol` + // is already bound. + bool bindAttr(StringRef symbol); + // Returns true if the given `symbol` is bound. bool contains(StringRef symbol) const; @@ -421,6 +428,9 @@ public: std::vector getLocation() const; private: + // Helper function to verify variabld binding. + void verifyBind(bool result, StringRef symbolName); + // Recursively collects all bound symbols inside the DAG tree rooted // at `tree` and updates the given `infoMap`. void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 448f70359bd0..7044677fad36 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); switch (kind) { case Kind::Attr: { - auto type = - op->getArg(*argIndex).get()->attr.getStorageType(); - return std::string(formatv("{0} {1};\n", type, name)); + if (op) { + auto type = + op->getArg(*argIndex).get()->attr.getStorageType(); + return std::string(formatv("{0} {1};\n", type, name)); + } + // TODO(suderman): Use a more exact type when available. + return std::string(formatv("Attribute {0};\n", name)); } case Kind::Operand: { // Use operand range for captured operands (to support potential variadic @@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) { return symbolInfoMap.count(inserted->first) == 1; } +bool SymbolInfoMap::bindAttr(StringRef symbol) { + auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr()); + return symbolInfoMap.count(inserted->first) == 1; +} + bool SymbolInfoMap::contains(StringRef symbol) const { return find(symbol) != symbolInfoMap.end(); } @@ -558,15 +567,15 @@ std::vector Pattern::getConstraints() const { for (auto it : *listInit) { auto *dagInit = dyn_cast(it); if (!dagInit) - PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " - "constraints should be DAG nodes"); + PrintFatalError(&def, "all elements in Pattern multi-entity " + "constraints should be DAG nodes"); std::vector entities; entities.reserve(dagInit->arg_size()); for (auto *argName : dagInit->getArgNames()) { if (!argName) { PrintFatalError( - def.getLoc(), + &def, "operands to additional constraints can only be symbol references"); } entities.push_back(std::string(argName->getValue())); @@ -584,7 +593,7 @@ int Pattern::getBenefit() const { int initBenefit = getSourcePattern().getNumOps(); llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) { - PrintFatalError(def.getLoc(), + PrintFatalError(&def, "The 'addBenefit' takes and only takes one integer value"); } return initBenefit + dyn_cast(delta->getArg(0))->getValue(); @@ -603,64 +612,120 @@ std::vector Pattern::getLocation() const { return result; } +void Pattern::verifyBind(bool result, StringRef symbolName) { + if (!result) { + auto err = formatv("symbol '{0}' bound more than once", symbolName); + PrintFatalError(&def, err); + } +} + void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern) { auto treeName = tree.getSymbol(); - if (!tree.isOperation()) { + auto numTreeArgs = tree.getNumArgs(); + + if (tree.isNativeCodeCall()) { if (!treeName.empty()) { PrintFatalError( - def.getLoc(), - formatv("binding symbol '{0}' to non-operation unsupported right now", - treeName)); + &def, + formatv( + "binding symbol '{0}' to native code call unsupported right now", + treeName)); + } + + for (int i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. + collectBoundSymbols(treeArg, infoMap, isSrcPattern); + continue; + } + + if (!isSrcPattern) + continue; + + // We can only bind symbols to arguments in source pattern. Those + // symbols are referenced in result patterns. + auto treeArgName = tree.getArgName(i); + + // `$_` is a special symbol meaning ignore the current argument. + if (!treeArgName.empty() && treeArgName != "_") { + if (tree.isNestedDagArg(i)) { + auto err = formatv("cannot bind '{0}' for nested native call arg", + treeArgName); + PrintFatalError(&def, err); + } + + DagLeaf leaf = tree.getArgAsLeaf(i); + auto constraint = leaf.getAsConstraint(); + bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || + leaf.isConstantAttr() || + constraint.getKind() == Constraint::Kind::CK_Attr; + + if (isAttr) { + verifyBind(infoMap.bindAttr(treeArgName), treeArgName); + continue; + } + + verifyBind(infoMap.bindValue(treeArgName), treeArgName); + } + } + + return; + } + + if (tree.isOperation()) { + auto &op = getDialectOp(tree); + auto numOpArgs = op.getNumArgs(); + + // The pattern might have the last argument specifying the location. + bool hasLocDirective = false; + if (numTreeArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + hasLocDirective = lastArg.isLocationDirective(); + } + + if (numOpArgs != numTreeArgs - hasLocDirective) { + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs); + PrintFatalError(&def, err); + } + + // The name attached to the DAG node's operator is for representing the + // results generated from this op. It should be remembered as bound results. + if (!treeName.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "found symbol bound to op result: " << treeName << '\n'); + verifyBind(infoMap.bindOpResult(treeName, op), treeName); + } + + for (int i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. + collectBoundSymbols(treeArg, infoMap, isSrcPattern); + continue; + } + + if (isSrcPattern) { + // We can only bind symbols to op arguments in source pattern. Those + // symbols are referenced in result patterns. + auto treeArgName = tree.getArgName(i); + // `$_` is a special symbol meaning ignore the current argument. + if (!treeArgName.empty() && treeArgName != "_") { + LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " + << treeArgName << '\n'); + verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName); + } + } } return; } - auto &op = getDialectOp(tree); - auto numOpArgs = op.getNumArgs(); - auto numTreeArgs = tree.getNumArgs(); - - // The pattern might have the last argument specifying the location. - bool hasLocDirective = false; - if (numTreeArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) - hasLocDirective = lastArg.isLocationDirective(); - } - - if (numOpArgs != numTreeArgs - hasLocDirective) { - auto err = formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); - PrintFatalError(def.getLoc(), err); - } - - // The name attached to the DAG node's operator is for representing the - // results generated from this op. It should be remembered as bound results. if (!treeName.empty()) { - LLVM_DEBUG(llvm::dbgs() - << "found symbol bound to op result: " << treeName << '\n'); - if (!infoMap.bindOpResult(treeName, op)) - PrintFatalError(def.getLoc(), - formatv("symbol '{0}' bound more than once", treeName)); - } - - for (int i = 0; i != numTreeArgs; ++i) { - if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, infoMap, isSrcPattern); - } else if (isSrcPattern) { - // We can only bind symbols to op arguments in source pattern. Those - // symbols are referenced in result patterns. - auto treeArgName = tree.getArgName(i); - // `$_` is a special symbol meaning ignore the current argument. - if (!treeArgName.empty() && treeArgName != "_") { - LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " - << treeArgName << '\n'); - if (!infoMap.bindOpArgument(treeArgName, op, i)) { - auto err = formatv("symbol '{0}' bound more than once", treeArgName); - PrintFatalError(def.getLoc(), err); - } - } - } + PrintFatalError( + &def, formatv("binding symbol '{0}' to non-operation/native code call " + "unsupported right now", + treeName)); } + return; } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 3bfb82495ce1..d34e997644a5 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { return operand(); } +OpFoldResult TestOpConstant::fold(ArrayRef operands) { + return getValue(); +} + LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { for (Value input : this->operands()) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index aef39a9e19fe..fcc677361dcc 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> { let hasCanonicalizer = 1; } +def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> { + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType); + let extraClassDeclaration = [{ + Attribute getValue() { return getAttr("value"); } + }]; + + let hasFolder = 1; +} + +def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>; +def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>; + +def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)), + (OpS:$unused $input1, $input2)>; + // Op for testing trivial removal via folding of op with inner ops and no uses. def TestOpWithRegionFoldNoSideEffect : TEST_Op< "op_with_region_fold_no_side_effect", [NoSideEffect]> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 32d618d9008e..282d31065549 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -9,6 +9,7 @@ #include "TestDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 5986be6240f9..616e116cb170 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) { return %0, %1 : i32, i32 } +//===----------------------------------------------------------------------===// +// Test Constant Matching +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: testConstOp +func @testConstOp() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + + // CHECK-NEXT: return [[C0]] + return %0 : i32 +} + +// CHECK-LABEL: testConstOpUsed +func @testConstOpUsed() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + + // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]]) + %1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32 + + // CHECK-NEXT: return [[V0]] + return %1 : i32 +} + +// CHECK-LABEL: testConstOpReplaced +func @testConstOpReplaced() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + %1 = "test.constant"() {value = 2 : i32} : () -> i32 + + // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32} + %2 = "test.op_r"(%0, %1) : (i32, i32) -> i32 + + // CHECK: [[V0]] + return %2 : i32 +} +// CHECK-LABEL: testConstOpMatchFailure +func @testConstOpMatchFailure() -> (i64) { + // CHECK-DAG: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i64} : () -> i64 + + // CHECK-DAG: [[C1:%.+]] = constant 2 + %1 = "test.constant"() {value = 2 : i64} : () -> i64 + + // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]]) + %2 = "test.op_r"(%0, %1) : (i64, i64) -> i64 + + // CHECK: [[V0]] + return %2 : i64 +} + //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td new file mode 100644 index 000000000000..eeb049482b88 --- /dev/null +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -0,0 +1,29 @@ +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s + +include "mlir/IR/OpBase.td" + +// Check using the dialect name as the namespace +def A_Dialect : Dialect { + let name = "a"; +} + +class A_Op traits = []> : + Op; + +def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>; +def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>; + +#ifdef ERROR1 +def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; +// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now +def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif + +#ifdef ERROR2 +def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; +// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for +def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 7bff3e3b40b6..5521eea38252 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -63,7 +63,7 @@ public: private: // Emits the code for matching ops. - void emitMatchLogic(DagNode tree); + void emitMatchLogic(DagNode tree, StringRef opName); // Emits the code for rewriting ops. void emitRewriteLogic(); @@ -72,26 +72,34 @@ private: // Match utilities //===--------------------------------------------------------------------===// + // Emits C++ statements for matching the DAG structure. + void emitMatch(DagNode tree, StringRef name, int depth); + + // Emits C++ statements for matching using a native code call. + void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); + // Emits C++ statements for matching the op constrained by the given DAG - // `tree`. - void emitOpMatch(DagNode tree, int depth); + // `tree` returning the op's variable name. + void emitOpMatch(DagNode tree, StringRef opName, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an operand. - void emitOperandMatch(DagNode tree, int argIndex, int depth); + void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, + int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. - void emitAttributeMatch(DagNode tree, int argIndex, int depth); + void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, + int depth); // Emits C++ for checking a match with a corresponding match failure // diagnostic. - void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, + void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt); // Emits C++ for checking a match with a corresponding match failure // diagnostics. - void emitMatchCheck(int depth, const std::string &matchStr, + void emitMatchCheck(StringRef opName, const std::string &matchStr, const std::string &failureStr); //===--------------------------------------------------------------------===// @@ -113,7 +121,7 @@ private: // Emits the C++ statement to replace the matched DAG with a value built via // calling native C++ code. - std::string handleReplaceWithNativeCodeCall(DagNode resultTree); + std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); // Returns the symbol of the old value serving as the replacement. StringRef handleReplaceWithValue(DagNode tree); @@ -140,12 +148,13 @@ private: // Emits the concrete arguments used to call an op's builder. void supplyValuesForOpArgs(DagNode node, - const ChildNodeIndexNameMap &childNodeNames); + const ChildNodeIndexNameMap &childNodeNames, + int depth); // Emits the local variables for holding all values as a whole and all named // attributes as a whole to be used for creating an op. void createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames); + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); // Returns the C++ expression to construct a constant attribute of the given // `value` for the given attribute kind `attr`. @@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr, } // Helper function to match patterns. -void PatternEmitter::emitOpMatch(DagNode tree, int depth) { +void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { + if (tree.isNativeCodeCall()) { + emitNativeCodeMatch(tree, name, depth); + return; + } + + if (tree.isOperation()) { + emitOpMatch(tree, name, depth); + return; + } + + PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); +} + +// Helper function to match patterns. +void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, + int depth) { + LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); + LLVM_DEBUG(tree.print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + + // TODO(suderman): iterate through arguments, determine their types, output + // names. + SmallVector capture(8); + if (tree.getNumArgs() > 8) { + PrintFatalError(loc, + "unsupported NativeCodeCall matcher argument numbers: " + + Twine(tree.getNumArgs())); + } + + raw_indented_ostream::DelimitedScope scope(os); + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + std::string argName = formatv("arg{0}_{1}", depth, i); + if (DagNode argTree = tree.getArgAsNestedDag(i)) { + os << "Value " << argName << ";\n"; + } else { + auto leaf = tree.getArgAsLeaf(i); + if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { + os << "Attribute " << argName << ";\n"; + } else if (leaf.isOperandMatcher()) { + os << "Operation " << argName << ";\n"; + } + } + + capture[i] = std::move(argName); + } + + bool hasLocationDirective; + std::string locToUse; + std::tie(hasLocationDirective, locToUse) = getLocation(tree); + + auto fmt = tree.getNativeCodeTemplate(); + auto nativeCodeCall = std::string(tgfmt( + fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1], + capture[2], capture[3], capture[4], capture[5], capture[6], capture[7])); + + os << "if (failed(" << nativeCodeCall << ")) return failure();\n"; + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + auto name = tree.getArgName(i); + if (!name.empty() && name != "_") { + os << formatv("{0} = {1};\n", name, capture[i]); + } + } + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + std::string argName = capture[i]; + + // Handle nested DAG construct first + if (DagNode argTree = tree.getArgAsNestedDag(i)) { + PrintFatalError( + loc, formatv("Matching nested tree in NativeCodecall not support for " + "{0} as arg {1}", + argName, i)); + } + + DagLeaf leaf = tree.getArgAsLeaf(i); + auto constraint = leaf.getAsConstraint(); + + auto self = formatv("{0}", argName); + emitMatchCheck( + opName, + tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), + formatv("\"operand {0} of native code call '{1}' failed to satisfy " + "constraint: " + "'{2}'\"", + i, tree.getNativeCodeTemplate(), constraint.getDescription())); + } + + LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); +} + +// Helper function to match patterns. +void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { Operator &op = tree.getDialectOp(opMap); LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" << op.getOperationName() << "' at depth " << depth << '\n'); - int indent = 4 + 2 * depth; - os.indent(indent) << formatv( - "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " - "(void)castedOp{0};\n", - depth, op.getQualCppClassName()); + std::string castedName = formatv("castedOp{0}", depth); + os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " + "(void){0};\n", + castedName, opName, op.getQualCppClassName()); // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { // Skip if there is no defining operation (e.g., arguments to function). - os << formatv("if (!castedOp{0})\n return failure();\n", depth); + os << formatv("if (!{0}) return failure();\n", castedName); } if (tree.getNumArgs() != op.getNumArgs()) { PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " @@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // If the operand's name is set, set to that variable. auto name = tree.getSymbol(); if (!name.empty()) - os << formatv("{0} = castedOp{1};\n", name, depth); + os << formatv("{0} = {1};\n", name, castedName); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); + std::string argName = formatv("op{0}", depth + 1); // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { @@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { os << "{\n"; os.indent() << formatv( - "auto *op{0} = " - "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", - depth + 1, depth, i); - emitOpMatch(argTree, depth + 1); - os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); + "auto *{0} = " + "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", + argName, castedName, i); + emitMatch(argTree, argName, depth + 1); + os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); os.unindent() << "}\n"; continue; } // Next handle DAG leaf: operand or attribute if (opArg.is()) { - emitOperandMatch(tree, i, depth); + emitOperandMatch(tree, castedName, i, depth); } else if (opArg.is()) { - emitAttributeMatch(tree, i, depth); + emitAttributeMatch(tree, opName, i, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } @@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { << '\n'); } -void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { +void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, + int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get(); auto matcher = tree.getArgAsLeaf(argIndex); @@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { op.getOperationName(), argIndex); PrintFatalError(loc, error); } - auto self = - formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, - argIndex); + auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", + opName, argIndex); emitMatchCheck( - depth, + opName, tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " "'{2}'\"", @@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { [](const Argument &arg) { return arg.is(); }); auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); - os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", - res->second.getVarName(name), depth, argIndex - numPrevAttrs); + os << formatv("{0} = {1}.getODSOperands({2});\n", + res->second.getVarName(name), opName, + argIndex - numPrevAttrs); } } -void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { +void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, + int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = op.getArg(argIndex).get(); const auto &attr = namedAttr->attr; os << "{\n"; - os.indent() << formatv( - "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); " - "(void)tblgen_attr;\n", - depth, attr.getStorageType(), namedAttr->name); + os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" + "(void)tblgen_attr;\n", + opName, attr.getStorageType(), namedAttr->name); // TODO: This should use getter method to avoid duplication. if (attr.hasDefaultValue()) { @@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { // should just capture a mlir::Attribute() to signal the missing state. // That is precisely what getAttr() returns on missing attributes. } else { - emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), + emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), formatv("\"expected op '{0}' to have attribute '{1}' " "of type '{2}'\"", op.getOperationName(), namedAttr->name, @@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { // If a constraint is specified, we need to generate C++ statements to // check the constraint. emitMatchCheck( - depth, + opName, tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " "{2}\"", @@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { } void PatternEmitter::emitMatchCheck( - int depth, const FmtObjectBase &matchFmt, + StringRef opName, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt) { - emitMatchCheck(depth, matchFmt.str(), failureFmt.str()); + emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); } -void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr, +void PatternEmitter::emitMatchCheck(StringRef opName, + const std::string &matchStr, const std::string &failureStr) { + os << "if (!(" << matchStr << "))"; - os.scope("{\n", "\n}\n").os - << "return rewriter.notifyMatchFailure(op" << depth - << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr - << ";\n});"; + os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName + << ", [&](::mlir::Diagnostic &diag) {\n diag << " + << failureStr << ";\n});"; } -void PatternEmitter::emitMatchLogic(DagNode tree) { +void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); int depth = 0; - emitOpMatch(tree, depth); + emitMatch(tree, opName, depth); for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; @@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) { auto self = formatv("({0}.getType())", symbolInfoMap.getValueAndRangeUse(entities.front())); emitMatchCheck( - depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), + opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", entities.front(), constraint.getDescription())); @@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) { self = symbolInfoMap.getValueAndRangeUse(self); for (; i < 4; ++i) names.push_back(""); - emitMatchCheck(depth, + emitMatchCheck(opName, tgfmt(condition, &fmtCtx.withSelf(self), names[0], names[1], names[2], names[3]), formatv("\"entities '{0}' failed to satisfy constraint: " @@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) { for (++startRange; startRange != endRange; ++startRange) { auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); emitMatchCheck( - depth, + opName, formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, secondOperand)); @@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) { os << "// Match\n"; os << "tblgen_ops[0] = op0;\n"; - emitMatchLogic(sourceTree); + emitMatchLogic(sourceTree, "op0"); os << "\n// Rewrite\n"; emitRewriteLogic(); @@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree, } if (resultTree.isNativeCodeCall()) { - auto symbol = handleReplaceWithNativeCodeCall(resultTree); + auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); symbolInfoMap.bindValue(symbol); return symbol; } @@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, PrintFatalError(loc, "unhandled case when rewriting op"); } -std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { +std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, + int depth) { LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); LLVM_DEBUG(tree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); @@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { // TODO: replace formatv arguments with the exact specified args. SmallVector attrs(8); if (tree.getNumArgs() > 8) { - PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + - Twine(tree.getNumArgs())); + PrintFatalError(loc, + "unsupported NativeCodeCall replace argument numbers: " + + Twine(tree.getNumArgs())); } bool hasLocationDirective; std::string locToUse; std::tie(hasLocationDirective, locToUse) = getLocation(tree); for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { - attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); + if (tree.isNestedDagArg(i)) { + attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1); + } else { + attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); + } LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i << " replacement: " << attrs[i] << "\n"); } @@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // create the ops. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames); + createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then create the op. os.scope("", "\n}\n").os << formatv( @@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, resultOp.getQualCppClassName(), locToUse); - supplyValuesForOpArgs(tree, childNodeNames); + supplyValuesForOpArgs(tree, childNodeNames, depth); os << "\n );\n}\n"; return resultValue; } @@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // here. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames); + createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then prepare the result types. We need to specify the types for all // results. @@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( } void PatternEmitter::supplyValuesForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); argIndex != numOpArgs; ++argIndex) { @@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs( PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv("/*{0}=*/{1}", opArgName, - handleReplaceWithNativeCodeCall(subTree)); + handleReplaceWithNativeCodeCall(subTree, depth)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. @@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs( } void PatternEmitter::createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); auto scope = os.scope(); @@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv(addAttrCmd, opArgName, - handleReplaceWithNativeCodeCall(subTree)); + handleReplaceWithNativeCodeCall(subTree, depth + 1)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern.