[mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher

Added an underlying matcher for generic constant ops. This
included a rewriter of RewriterGen to make variable use more
clear.

Differential Revision: https://reviews.llvm.org/D89161
This commit is contained in:
Rob Suderman 2020-10-09 13:32:01 -07:00
parent 273c299d5d
commit 2bf423b021
9 changed files with 410 additions and 120 deletions

View File

@ -2351,6 +2351,8 @@ class NativeCodeCall<string expr> {
string expression = expr; string expression = expr;
} }
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Rewrite directives // Rewrite directives
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -252,6 +252,9 @@ public:
static SymbolInfo getAttr(const Operator *op, int index) { static SymbolInfo getAttr(const Operator *op, int index) {
return SymbolInfo(op, Kind::Attr, index); return SymbolInfo(op, Kind::Attr, index);
} }
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
}
static SymbolInfo getOperand(const Operator *op, int index) { static SymbolInfo getOperand(const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand, index); return SymbolInfo(op, Kind::Operand, index);
} }
@ -319,6 +322,10 @@ public:
// is already bound. // is already bound.
bool bindValue(StringRef symbol); bool bindValue(StringRef symbol);
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
// is already bound.
bool bindAttr(StringRef symbol);
// Returns true if the given `symbol` is bound. // Returns true if the given `symbol` is bound.
bool contains(StringRef symbol) const; bool contains(StringRef symbol) const;
@ -421,6 +428,9 @@ public:
std::vector<IdentifierLine> getLocation() const; std::vector<IdentifierLine> getLocation() const;
private: private:
// Helper function to verify variabld binding.
void verifyBind(bool result, StringRef symbolName);
// Recursively collects all bound symbols inside the DAG tree rooted // Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `infoMap`. // at `tree` and updates the given `infoMap`.
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,

View File

@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) { switch (kind) {
case Kind::Attr: { case Kind::Attr: {
auto type = if (op) {
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); auto type =
return std::string(formatv("{0} {1};\n", type, name)); op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
return std::string(formatv("{0} {1};\n", type, name));
}
// TODO(suderman): Use a more exact type when available.
return std::string(formatv("Attribute {0};\n", name));
} }
case Kind::Operand: { case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic // Use operand range for captured operands (to support potential variadic
@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) {
return symbolInfoMap.count(inserted->first) == 1; return symbolInfoMap.count(inserted->first) == 1;
} }
bool SymbolInfoMap::bindAttr(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr());
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::contains(StringRef symbol) const { bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end(); return find(symbol) != symbolInfoMap.end();
} }
@ -558,15 +567,15 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
for (auto it : *listInit) { for (auto it : *listInit) {
auto *dagInit = dyn_cast<llvm::DagInit>(it); auto *dagInit = dyn_cast<llvm::DagInit>(it);
if (!dagInit) if (!dagInit)
PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " PrintFatalError(&def, "all elements in Pattern multi-entity "
"constraints should be DAG nodes"); "constraints should be DAG nodes");
std::vector<std::string> entities; std::vector<std::string> entities;
entities.reserve(dagInit->arg_size()); entities.reserve(dagInit->arg_size());
for (auto *argName : dagInit->getArgNames()) { for (auto *argName : dagInit->getArgNames()) {
if (!argName) { if (!argName) {
PrintFatalError( PrintFatalError(
def.getLoc(), &def,
"operands to additional constraints can only be symbol references"); "operands to additional constraints can only be symbol references");
} }
entities.push_back(std::string(argName->getValue())); entities.push_back(std::string(argName->getValue()));
@ -584,7 +593,7 @@ int Pattern::getBenefit() const {
int initBenefit = getSourcePattern().getNumOps(); int initBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
PrintFatalError(def.getLoc(), PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value"); "The 'addBenefit' takes and only takes one integer value");
} }
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
@ -603,64 +612,120 @@ std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
return result; return result;
} }
void Pattern::verifyBind(bool result, StringRef symbolName) {
if (!result) {
auto err = formatv("symbol '{0}' bound more than once", symbolName);
PrintFatalError(&def, err);
}
}
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) { bool isSrcPattern) {
auto treeName = tree.getSymbol(); auto treeName = tree.getSymbol();
if (!tree.isOperation()) { auto numTreeArgs = tree.getNumArgs();
if (tree.isNativeCodeCall()) {
if (!treeName.empty()) { if (!treeName.empty()) {
PrintFatalError( PrintFatalError(
def.getLoc(), &def,
formatv("binding symbol '{0}' to non-operation unsupported right now", formatv(
treeName)); "binding symbol '{0}' to native code call unsupported right now",
treeName));
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
continue;
}
if (!isSrcPattern)
continue;
// We can only bind symbols to arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
if (tree.isNestedDagArg(i)) {
auto err = formatv("cannot bind '{0}' for nested native call arg",
treeArgName);
PrintFatalError(&def, err);
}
DagLeaf leaf = tree.getArgAsLeaf(i);
auto constraint = leaf.getAsConstraint();
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
leaf.isConstantAttr() ||
constraint.getKind() == Constraint::Kind::CK_Attr;
if (isAttr) {
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
continue;
}
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
}
}
return;
}
if (tree.isOperation()) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
// The pattern might have the last argument specifying the location.
bool hasLocDirective = false;
if (numTreeArgs != 0) {
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
hasLocDirective = lastArg.isLocationDirective();
}
if (numOpArgs != numTreeArgs - hasLocDirective) {
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
PrintFatalError(&def, err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
if (!treeName.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "found symbol bound to op result: " << treeName << '\n');
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
continue;
}
if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
}
}
} }
return; return;
} }
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
auto numTreeArgs = tree.getNumArgs();
// The pattern might have the last argument specifying the location.
bool hasLocDirective = false;
if (numTreeArgs != 0) {
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
hasLocDirective = lastArg.isLocationDirective();
}
if (numOpArgs != numTreeArgs - hasLocDirective) {
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
PrintFatalError(def.getLoc(), err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
if (!treeName.empty()) { if (!treeName.empty()) {
LLVM_DEBUG(llvm::dbgs() PrintFatalError(
<< "found symbol bound to op result: " << treeName << '\n'); &def, formatv("binding symbol '{0}' to non-operation/native code call "
if (!infoMap.bindOpResult(treeName, op)) "unsupported right now",
PrintFatalError(def.getLoc(), treeName));
formatv("symbol '{0}' bound more than once", treeName));
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
} else if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
PrintFatalError(def.getLoc(), err);
}
}
}
} }
return;
} }

View File

@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
return operand(); return operand();
} }
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
return getValue();
}
LogicalResult TestOpWithVariadicResultsAndFolder::fold( LogicalResult TestOpWithVariadicResultsAndFolder::fold(
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
for (Value input : this->operands()) { for (Value input : this->operands()) {

View File

@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
let arguments = (ins AnyAttr:$value);
let results = (outs AnyType);
let extraClassDeclaration = [{
Attribute getValue() { return getAttr("value"); }
}];
let hasFolder = 1;
}
def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
(OpS:$unused $input1, $input2)>;
// Op for testing trivial removal via folding of op with inner ops and no uses. // Op for testing trivial removal via folding of op with inner ops and no uses.
def TestOpWithRegionFoldNoSideEffect : TEST_Op< def TestOpWithRegionFoldNoSideEffect : TEST_Op<
"op_with_region_fold_no_side_effect", [NoSideEffect]> { "op_with_region_fold_no_side_effect", [NoSideEffect]> {

View File

@ -9,6 +9,7 @@
#include "TestDialect.h" #include "TestDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) {
return %0, %1 : i32, i32 return %0, %1 : i32, i32
} }
//===----------------------------------------------------------------------===//
// Test Constant Matching
//===----------------------------------------------------------------------===//
// CHECK-LABEL: testConstOp
func @testConstOp() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = constant 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: return [[C0]]
return %0 : i32
}
// CHECK-LABEL: testConstOpUsed
func @testConstOpUsed() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = constant 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
%1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32
// CHECK-NEXT: return [[V0]]
return %1 : i32
}
// CHECK-LABEL: testConstOpReplaced
func @testConstOpReplaced() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = constant 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
%1 = "test.constant"() {value = 2 : i32} : () -> i32
// CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
%2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
// CHECK: [[V0]]
return %2 : i32
}
// CHECK-LABEL: testConstOpMatchFailure
func @testConstOpMatchFailure() -> (i64) {
// CHECK-DAG: [[C0:%.+]] = constant 1
%0 = "test.constant"() {value = 1 : i64} : () -> i64
// CHECK-DAG: [[C1:%.+]] = constant 2
%1 = "test.constant"() {value = 2 : i64} : () -> i64
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
%2 = "test.op_r"(%0, %1) : (i64, i64) -> i64
// CHECK: [[V0]]
return %2 : i64
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test Enum Attributes // Test Enum Attributes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,29 @@
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
include "mlir/IR/OpBase.td"
// Check using the dialect name as the namespace
def A_Dialect : Dialect {
let name = "a";
}
class A_Op<string mnemonic, list<OpTrait> traits = []> :
Op<A_Dialect, mnemonic, traits>;
def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
#ifdef ERROR1
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
(OpB $val, $arg)>;
#endif
#ifdef ERROR2
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
(OpB $val, $arg)>;
#endif

View File

@ -63,7 +63,7 @@ public:
private: private:
// Emits the code for matching ops. // Emits the code for matching ops.
void emitMatchLogic(DagNode tree); void emitMatchLogic(DagNode tree, StringRef opName);
// Emits the code for rewriting ops. // Emits the code for rewriting ops.
void emitRewriteLogic(); void emitRewriteLogic();
@ -72,26 +72,34 @@ private:
// Match utilities // Match utilities
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Emits C++ statements for matching the DAG structure.
void emitMatch(DagNode tree, StringRef name, int depth);
// Emits C++ statements for matching using a native code call.
void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
// Emits C++ statements for matching the op constrained by the given DAG // Emits C++ statements for matching the op constrained by the given DAG
// `tree`. // `tree` returning the op's variable name.
void emitOpMatch(DagNode tree, int depth); void emitOpMatch(DagNode tree, StringRef opName, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given // Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand. // DAG `tree` as an operand.
void emitOperandMatch(DagNode tree, int argIndex, int depth); void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given // Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute. // DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int argIndex, int depth); void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
int depth);
// Emits C++ for checking a match with a corresponding match failure // Emits C++ for checking a match with a corresponding match failure
// diagnostic. // diagnostic.
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt); const llvm::formatv_object_base &failureFmt);
// Emits C++ for checking a match with a corresponding match failure // Emits C++ for checking a match with a corresponding match failure
// diagnostics. // diagnostics.
void emitMatchCheck(int depth, const std::string &matchStr, void emitMatchCheck(StringRef opName, const std::string &matchStr,
const std::string &failureStr); const std::string &failureStr);
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
@ -113,7 +121,7 @@ private:
// Emits the C++ statement to replace the matched DAG with a value built via // Emits the C++ statement to replace the matched DAG with a value built via
// calling native C++ code. // calling native C++ code.
std::string handleReplaceWithNativeCodeCall(DagNode resultTree); std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
// Returns the symbol of the old value serving as the replacement. // Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree); StringRef handleReplaceWithValue(DagNode tree);
@ -140,12 +148,13 @@ private:
// Emits the concrete arguments used to call an op's builder. // Emits the concrete arguments used to call an op's builder.
void supplyValuesForOpArgs(DagNode node, void supplyValuesForOpArgs(DagNode node,
const ChildNodeIndexNameMap &childNodeNames); const ChildNodeIndexNameMap &childNodeNames,
int depth);
// Emits the local variables for holding all values as a whole and all named // Emits the local variables for holding all values as a whole and all named
// attributes as a whole to be used for creating an op. // attributes as a whole to be used for creating an op.
void createAggregateLocalVarsForOpArgs( void createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames); DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
// Returns the C++ expression to construct a constant attribute of the given // Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`. // `value` for the given attribute kind `attr`.
@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
} }
// Helper function to match patterns. // Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) { void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
if (tree.isNativeCodeCall()) {
emitNativeCodeMatch(tree, name, depth);
return;
}
if (tree.isOperation()) {
emitOpMatch(tree, name, depth);
return;
}
PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
}
// Helper function to match patterns.
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
// TODO(suderman): iterate through arguments, determine their types, output
// names.
SmallVector<std::string, 8> capture(8);
if (tree.getNumArgs() > 8) {
PrintFatalError(loc,
"unsupported NativeCodeCall matcher argument numbers: " +
Twine(tree.getNumArgs()));
}
raw_indented_ostream::DelimitedScope scope(os);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = formatv("arg{0}_{1}", depth, i);
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os << "Value " << argName << ";\n";
} else {
auto leaf = tree.getArgAsLeaf(i);
if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
os << "Attribute " << argName << ";\n";
} else if (leaf.isOperandMatcher()) {
os << "Operation " << argName << ";\n";
}
}
capture[i] = std::move(argName);
}
bool hasLocationDirective;
std::string locToUse;
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
auto fmt = tree.getNativeCodeTemplate();
auto nativeCodeCall = std::string(tgfmt(
fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto name = tree.getArgName(i);
if (!name.empty() && name != "_") {
os << formatv("{0} = {1};\n", name, capture[i]);
}
}
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = capture[i];
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
PrintFatalError(
loc, formatv("Matching nested tree in NativeCodecall not support for "
"{0} as arg {1}",
argName, i));
}
DagLeaf leaf = tree.getArgAsLeaf(i);
auto constraint = leaf.getAsConstraint();
auto self = formatv("{0}", argName);
emitMatchCheck(
opName,
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
"constraint: "
"'{2}'\"",
i, tree.getNativeCodeTemplate(), constraint.getDescription()));
}
LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
}
// Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
Operator &op = tree.getDialectOp(opMap); Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
<< op.getOperationName() << "' at depth " << depth << op.getOperationName() << "' at depth " << depth
<< '\n'); << '\n');
int indent = 4 + 2 * depth; std::string castedName = formatv("castedOp{0}", depth);
os.indent(indent) << formatv( os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
"auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " "(void){0};\n",
"(void)castedOp{0};\n", castedName, opName, op.getQualCppClassName());
depth, op.getQualCppClassName());
// Skip the operand matching at depth 0 as the pattern rewriter already does. // Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) { if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function). // Skip if there is no defining operation (e.g., arguments to function).
os << formatv("if (!castedOp{0})\n return failure();\n", depth); os << formatv("if (!{0}) return failure();\n", castedName);
} }
if (tree.getNumArgs() != op.getNumArgs()) { if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable. // If the operand's name is set, set to that variable.
auto name = tree.getSymbol(); auto name = tree.getSymbol();
if (!name.empty()) if (!name.empty())
os << formatv("{0} = castedOp{1};\n", name, depth); os << formatv("{0} = {1};\n", name, castedName);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i); auto opArg = op.getArg(i);
std::string argName = formatv("op{0}", depth + 1);
// Handle nested DAG construct first // Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) { if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
os << "{\n"; os << "{\n";
os.indent() << formatv( os.indent() << formatv(
"auto *op{0} = " "auto *{0} = "
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
depth + 1, depth, i); argName, castedName, i);
emitOpMatch(argTree, depth + 1); emitMatch(argTree, argName, depth + 1);
os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
os.unindent() << "}\n"; os.unindent() << "}\n";
continue; continue;
} }
// Next handle DAG leaf: operand or attribute // Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) { if (opArg.is<NamedTypeConstraint *>()) {
emitOperandMatch(tree, i, depth); emitOperandMatch(tree, castedName, i, depth);
} else if (opArg.is<NamedAttribute *>()) { } else if (opArg.is<NamedAttribute *>()) {
emitAttributeMatch(tree, i, depth); emitAttributeMatch(tree, opName, i, depth);
} else { } else {
PrintFatalError(loc, "unhandled case when matching op"); PrintFatalError(loc, "unhandled case when matching op");
} }
@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
<< '\n'); << '\n');
} }
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap); Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex); auto matcher = tree.getArgAsLeaf(argIndex);
@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
op.getOperationName(), argIndex); op.getOperationName(), argIndex);
PrintFatalError(loc, error); PrintFatalError(loc, error);
} }
auto self = auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, opName, argIndex);
argIndex);
emitMatchCheck( emitMatchCheck(
depth, opName,
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
"'{2}'\"", "'{2}'\"",
@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
[](const Argument &arg) { return arg.is<NamedAttribute *>(); }); [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", os << formatv("{0} = {1}.getODSOperands({2});\n",
res->second.getVarName(name), depth, argIndex - numPrevAttrs); res->second.getVarName(name), opName,
argIndex - numPrevAttrs);
} }
} }
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap); Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>(); auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr; const auto &attr = namedAttr->attr;
os << "{\n"; os << "{\n";
os.indent() << formatv( os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); " "(void)tblgen_attr;\n",
"(void)tblgen_attr;\n", opName, attr.getStorageType(), namedAttr->name);
depth, attr.getStorageType(), namedAttr->name);
// TODO: This should use getter method to avoid duplication. // TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) { if (attr.hasDefaultValue()) {
@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
// should just capture a mlir::Attribute() to signal the missing state. // should just capture a mlir::Attribute() to signal the missing state.
// That is precisely what getAttr() returns on missing attributes. // That is precisely what getAttr() returns on missing attributes.
} else { } else {
emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' " formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"", "of type '{2}'\"",
op.getOperationName(), namedAttr->name, op.getOperationName(), namedAttr->name,
@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
// If a constraint is specified, we need to generate C++ statements to // If a constraint is specified, we need to generate C++ statements to
// check the constraint. // check the constraint.
emitMatchCheck( emitMatchCheck(
depth, opName,
tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"{2}\"", "{2}\"",
@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
} }
void PatternEmitter::emitMatchCheck( void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt, StringRef opName, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) { const llvm::formatv_object_base &failureFmt) {
emitMatchCheck(depth, matchFmt.str(), failureFmt.str()); emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
} }
void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr, void PatternEmitter::emitMatchCheck(StringRef opName,
const std::string &matchStr,
const std::string &failureStr) { const std::string &failureStr) {
os << "if (!(" << matchStr << "))"; os << "if (!(" << matchStr << "))";
os.scope("{\n", "\n}\n").os os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
<< "return rewriter.notifyMatchFailure(op" << depth << ", [&](::mlir::Diagnostic &diag) {\n diag << "
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr << failureStr << ";\n});";
<< ";\n});";
} }
void PatternEmitter::emitMatchLogic(DagNode tree) { void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
int depth = 0; int depth = 0;
emitOpMatch(tree, depth); emitMatch(tree, opName, depth);
for (auto &appliedConstraint : pattern.getConstraints()) { for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint; auto &constraint = appliedConstraint.constraint;
@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
auto self = formatv("({0}.getType())", auto self = formatv("({0}.getType())",
symbolInfoMap.getValueAndRangeUse(entities.front())); symbolInfoMap.getValueAndRangeUse(entities.front()));
emitMatchCheck( emitMatchCheck(
depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
entities.front(), constraint.getDescription())); entities.front(), constraint.getDescription()));
@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
self = symbolInfoMap.getValueAndRangeUse(self); self = symbolInfoMap.getValueAndRangeUse(self);
for (; i < 4; ++i) for (; i < 4; ++i)
names.push_back("<unused>"); names.push_back("<unused>");
emitMatchCheck(depth, emitMatchCheck(opName,
tgfmt(condition, &fmtCtx.withSelf(self), names[0], tgfmt(condition, &fmtCtx.withSelf(self), names[0],
names[1], names[2], names[3]), names[1], names[2], names[3]),
formatv("\"entities '{0}' failed to satisfy constraint: " formatv("\"entities '{0}' failed to satisfy constraint: "
@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
for (++startRange; startRange != endRange; ++startRange) { for (++startRange; startRange != endRange; ++startRange) {
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
emitMatchCheck( emitMatchCheck(
depth, opName,
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
secondOperand)); secondOperand));
@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
os << "// Match\n"; os << "// Match\n";
os << "tblgen_ops[0] = op0;\n"; os << "tblgen_ops[0] = op0;\n";
emitMatchLogic(sourceTree); emitMatchLogic(sourceTree, "op0");
os << "\n// Rewrite\n"; os << "\n// Rewrite\n";
emitRewriteLogic(); emitRewriteLogic();
@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
} }
if (resultTree.isNativeCodeCall()) { if (resultTree.isNativeCodeCall()) {
auto symbol = handleReplaceWithNativeCodeCall(resultTree); auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
symbolInfoMap.bindValue(symbol); symbolInfoMap.bindValue(symbol);
return symbol; return symbol;
} }
@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
PrintFatalError(loc, "unhandled case when rewriting op"); PrintFatalError(loc, "unhandled case when rewriting op");
} }
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs())); LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n'); LLVM_DEBUG(llvm::dbgs() << '\n');
@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
// TODO: replace formatv arguments with the exact specified args. // TODO: replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8); SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) { if (tree.getNumArgs() > 8) {
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + PrintFatalError(loc,
Twine(tree.getNumArgs())); "unsupported NativeCodeCall replace argument numbers: " +
Twine(tree.getNumArgs()));
} }
bool hasLocationDirective; bool hasLocationDirective;
std::string locToUse; std::string locToUse;
std::tie(hasLocationDirective, locToUse) = getLocation(tree); std::tie(hasLocationDirective, locToUse) = getLocation(tree);
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); if (tree.isNestedDagArg(i)) {
attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
} else {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
}
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
<< " replacement: " << attrs[i] << "\n"); << " replacement: " << attrs[i] << "\n");
} }
@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// create the ops. // create the ops.
// First prepare local variables for op arguments used in builder call. // First prepare local variables for op arguments used in builder call.
createAggregateLocalVarsForOpArgs(tree, childNodeNames); createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then create the op. // Then create the op.
os.scope("", "\n}\n").os << formatv( os.scope("", "\n}\n").os << formatv(
@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
resultOp.getQualCppClassName(), locToUse); resultOp.getQualCppClassName(), locToUse);
supplyValuesForOpArgs(tree, childNodeNames); supplyValuesForOpArgs(tree, childNodeNames, depth);
os << "\n );\n}\n"; os << "\n );\n}\n";
return resultValue; return resultValue;
} }
@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here. // here.
// First prepare local variables for op arguments used in builder call. // First prepare local variables for op arguments used in builder call.
createAggregateLocalVarsForOpArgs(tree, childNodeNames); createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then prepare the result types. We need to specify the types for all // Then prepare the result types. We need to specify the types for all
// results. // results.
@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
} }
void PatternEmitter::supplyValuesForOpArgs( void PatternEmitter::supplyValuesForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames) { DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap); Operator &resultOp = node.getDialectOp(opMap);
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
argIndex != numOpArgs; ++argIndex) { argIndex != numOpArgs; ++argIndex) {
@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs(
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute"); "for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName, os << formatv("/*{0}=*/{1}", opArgName,
handleReplaceWithNativeCodeCall(subTree)); handleReplaceWithNativeCodeCall(subTree, depth));
} else { } else {
auto leaf = node.getArgAsLeaf(argIndex); auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern. // The argument in the result DAG pattern.
@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs(
} }
void PatternEmitter::createAggregateLocalVarsForOpArgs( void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames) { DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap); Operator &resultOp = node.getDialectOp(opMap);
auto scope = os.scope(); auto scope = os.scope();
@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute"); "for creating attribute");
os << formatv(addAttrCmd, opArgName, os << formatv(addAttrCmd, opArgName,
handleReplaceWithNativeCodeCall(subTree)); handleReplaceWithNativeCodeCall(subTree, depth + 1));
} else { } else {
auto leaf = node.getArgAsLeaf(argIndex); auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern. // The argument in the result DAG pattern.