forked from OSchip/llvm-project
[mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher
Added an underlying matcher for generic constant ops. This included a rewriter of RewriterGen to make variable use more clear. Differential Revision: https://reviews.llvm.org/D89161
This commit is contained in:
parent
273c299d5d
commit
2bf423b021
|
@ -2351,6 +2351,8 @@ class NativeCodeCall<string expr> {
|
|||
string expression = expr;
|
||||
}
|
||||
|
||||
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Rewrite directives
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -252,6 +252,9 @@ public:
|
|||
static SymbolInfo getAttr(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Attr, index);
|
||||
}
|
||||
static SymbolInfo getAttr() {
|
||||
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
|
||||
}
|
||||
static SymbolInfo getOperand(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Operand, index);
|
||||
}
|
||||
|
@ -319,6 +322,10 @@ public:
|
|||
// is already bound.
|
||||
bool bindValue(StringRef symbol);
|
||||
|
||||
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
|
||||
// is already bound.
|
||||
bool bindAttr(StringRef symbol);
|
||||
|
||||
// Returns true if the given `symbol` is bound.
|
||||
bool contains(StringRef symbol) const;
|
||||
|
||||
|
@ -421,6 +428,9 @@ public:
|
|||
std::vector<IdentifierLine> getLocation() const;
|
||||
|
||||
private:
|
||||
// Helper function to verify variabld binding.
|
||||
void verifyBind(bool result, StringRef symbolName);
|
||||
|
||||
// Recursively collects all bound symbols inside the DAG tree rooted
|
||||
// at `tree` and updates the given `infoMap`.
|
||||
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
|
|
|
@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
|||
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
auto type =
|
||||
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
||||
return std::string(formatv("{0} {1};\n", type, name));
|
||||
if (op) {
|
||||
auto type =
|
||||
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
||||
return std::string(formatv("{0} {1};\n", type, name));
|
||||
}
|
||||
// TODO(suderman): Use a more exact type when available.
|
||||
return std::string(formatv("Attribute {0};\n", name));
|
||||
}
|
||||
case Kind::Operand: {
|
||||
// Use operand range for captured operands (to support potential variadic
|
||||
|
@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) {
|
|||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindAttr(StringRef symbol) {
|
||||
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr());
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::contains(StringRef symbol) const {
|
||||
return find(symbol) != symbolInfoMap.end();
|
||||
}
|
||||
|
@ -558,15 +567,15 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
|||
for (auto it : *listInit) {
|
||||
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||
if (!dagInit)
|
||||
PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
|
||||
"constraints should be DAG nodes");
|
||||
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
||||
"constraints should be DAG nodes");
|
||||
|
||||
std::vector<std::string> entities;
|
||||
entities.reserve(dagInit->arg_size());
|
||||
for (auto *argName : dagInit->getArgNames()) {
|
||||
if (!argName) {
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
&def,
|
||||
"operands to additional constraints can only be symbol references");
|
||||
}
|
||||
entities.push_back(std::string(argName->getValue()));
|
||||
|
@ -584,7 +593,7 @@ int Pattern::getBenefit() const {
|
|||
int initBenefit = getSourcePattern().getNumOps();
|
||||
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
PrintFatalError(&def,
|
||||
"The 'addBenefit' takes and only takes one integer value");
|
||||
}
|
||||
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||
|
@ -603,64 +612,120 @@ std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
void Pattern::verifyBind(bool result, StringRef symbolName) {
|
||||
if (!result) {
|
||||
auto err = formatv("symbol '{0}' bound more than once", symbolName);
|
||||
PrintFatalError(&def, err);
|
||||
}
|
||||
}
|
||||
|
||||
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
bool isSrcPattern) {
|
||||
auto treeName = tree.getSymbol();
|
||||
if (!tree.isOperation()) {
|
||||
auto numTreeArgs = tree.getNumArgs();
|
||||
|
||||
if (tree.isNativeCodeCall()) {
|
||||
if (!treeName.empty()) {
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
formatv("binding symbol '{0}' to non-operation unsupported right now",
|
||||
treeName));
|
||||
&def,
|
||||
formatv(
|
||||
"binding symbol '{0}' to native code call unsupported right now",
|
||||
treeName));
|
||||
}
|
||||
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isSrcPattern)
|
||||
continue;
|
||||
|
||||
// We can only bind symbols to arguments in source pattern. Those
|
||||
// symbols are referenced in result patterns.
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
|
||||
// `$_` is a special symbol meaning ignore the current argument.
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
if (tree.isNestedDagArg(i)) {
|
||||
auto err = formatv("cannot bind '{0}' for nested native call arg",
|
||||
treeArgName);
|
||||
PrintFatalError(&def, err);
|
||||
}
|
||||
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
|
||||
leaf.isConstantAttr() ||
|
||||
constraint.getKind() == Constraint::Kind::CK_Attr;
|
||||
|
||||
if (isAttr) {
|
||||
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
|
||||
continue;
|
||||
}
|
||||
|
||||
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (tree.isOperation()) {
|
||||
auto &op = getDialectOp(tree);
|
||||
auto numOpArgs = op.getNumArgs();
|
||||
|
||||
// The pattern might have the last argument specifying the location.
|
||||
bool hasLocDirective = false;
|
||||
if (numTreeArgs != 0) {
|
||||
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
|
||||
hasLocDirective = lastArg.isLocationDirective();
|
||||
}
|
||||
|
||||
if (numOpArgs != numTreeArgs - hasLocDirective) {
|
||||
auto err = formatv("op '{0}' argument number mismatch: "
|
||||
"{1} in pattern vs. {2} in definition",
|
||||
op.getOperationName(), numTreeArgs, numOpArgs);
|
||||
PrintFatalError(&def, err);
|
||||
}
|
||||
|
||||
// The name attached to the DAG node's operator is for representing the
|
||||
// results generated from this op. It should be remembered as bound results.
|
||||
if (!treeName.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "found symbol bound to op result: " << treeName << '\n');
|
||||
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
|
||||
}
|
||||
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isSrcPattern) {
|
||||
// We can only bind symbols to op arguments in source pattern. Those
|
||||
// symbols are referenced in result patterns.
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
// `$_` is a special symbol meaning ignore the current argument.
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||
<< treeArgName << '\n');
|
||||
verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto &op = getDialectOp(tree);
|
||||
auto numOpArgs = op.getNumArgs();
|
||||
auto numTreeArgs = tree.getNumArgs();
|
||||
|
||||
// The pattern might have the last argument specifying the location.
|
||||
bool hasLocDirective = false;
|
||||
if (numTreeArgs != 0) {
|
||||
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
|
||||
hasLocDirective = lastArg.isLocationDirective();
|
||||
}
|
||||
|
||||
if (numOpArgs != numTreeArgs - hasLocDirective) {
|
||||
auto err = formatv("op '{0}' argument number mismatch: "
|
||||
"{1} in pattern vs. {2} in definition",
|
||||
op.getOperationName(), numTreeArgs, numOpArgs);
|
||||
PrintFatalError(def.getLoc(), err);
|
||||
}
|
||||
|
||||
// The name attached to the DAG node's operator is for representing the
|
||||
// results generated from this op. It should be remembered as bound results.
|
||||
if (!treeName.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "found symbol bound to op result: " << treeName << '\n');
|
||||
if (!infoMap.bindOpResult(treeName, op))
|
||||
PrintFatalError(def.getLoc(),
|
||||
formatv("symbol '{0}' bound more than once", treeName));
|
||||
}
|
||||
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||
} else if (isSrcPattern) {
|
||||
// We can only bind symbols to op arguments in source pattern. Those
|
||||
// symbols are referenced in result patterns.
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
// `$_` is a special symbol meaning ignore the current argument.
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||
<< treeArgName << '\n');
|
||||
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
|
||||
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
|
||||
PrintFatalError(def.getLoc(), err);
|
||||
}
|
||||
}
|
||||
}
|
||||
PrintFatalError(
|
||||
&def, formatv("binding symbol '{0}' to non-operation/native code call "
|
||||
"unsupported right now",
|
||||
treeName));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
|
|||
return operand();
|
||||
}
|
||||
|
||||
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
|
||||
return getValue();
|
||||
}
|
||||
|
||||
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
||||
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
|
||||
for (Value input : this->operands()) {
|
||||
|
|
|
@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
|
||||
let arguments = (ins AnyAttr:$value);
|
||||
let results = (outs AnyType);
|
||||
let extraClassDeclaration = [{
|
||||
Attribute getValue() { return getAttr("value"); }
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
|
||||
def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
|
||||
|
||||
def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
|
||||
(OpS:$unused $input1, $input2)>;
|
||||
|
||||
// Op for testing trivial removal via folding of op with inner ops and no uses.
|
||||
def TestOpWithRegionFoldNoSideEffect : TEST_Op<
|
||||
"op_with_region_fold_no_side_effect", [NoSideEffect]> {
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "TestDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
|
|
@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) {
|
|||
return %0, %1 : i32, i32
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Constant Matching
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: testConstOp
|
||||
func @testConstOp() -> (i32) {
|
||||
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||
|
||||
// CHECK-NEXT: return [[C0]]
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConstOpUsed
|
||||
func @testConstOpUsed() -> (i32) {
|
||||
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||
|
||||
// CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
|
||||
%1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32
|
||||
|
||||
// CHECK-NEXT: return [[V0]]
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConstOpReplaced
|
||||
func @testConstOpReplaced() -> (i32) {
|
||||
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||
%1 = "test.constant"() {value = 2 : i32} : () -> i32
|
||||
|
||||
// CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
|
||||
%2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
|
||||
|
||||
// CHECK: [[V0]]
|
||||
return %2 : i32
|
||||
}
|
||||
// CHECK-LABEL: testConstOpMatchFailure
|
||||
func @testConstOpMatchFailure() -> (i64) {
|
||||
// CHECK-DAG: [[C0:%.+]] = constant 1
|
||||
%0 = "test.constant"() {value = 1 : i64} : () -> i64
|
||||
|
||||
// CHECK-DAG: [[C1:%.+]] = constant 2
|
||||
%1 = "test.constant"() {value = 2 : i64} : () -> i64
|
||||
|
||||
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
|
||||
%2 = "test.op_r"(%0, %1) : (i64, i64) -> i64
|
||||
|
||||
// CHECK: [[V0]]
|
||||
return %2 : i64
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Enum Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
|
||||
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// Check using the dialect name as the namespace
|
||||
def A_Dialect : Dialect {
|
||||
let name = "a";
|
||||
}
|
||||
|
||||
class A_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<A_Dialect, mnemonic, traits>;
|
||||
|
||||
def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
|
||||
def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
|
||||
|
||||
#ifdef ERROR1
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||
// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
|
||||
def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
|
||||
(OpB $val, $arg)>;
|
||||
#endif
|
||||
|
||||
#ifdef ERROR2
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||
// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
|
||||
def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
|
||||
(OpB $val, $arg)>;
|
||||
#endif
|
|
@ -63,7 +63,7 @@ public:
|
|||
|
||||
private:
|
||||
// Emits the code for matching ops.
|
||||
void emitMatchLogic(DagNode tree);
|
||||
void emitMatchLogic(DagNode tree, StringRef opName);
|
||||
|
||||
// Emits the code for rewriting ops.
|
||||
void emitRewriteLogic();
|
||||
|
@ -72,26 +72,34 @@ private:
|
|||
// Match utilities
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Emits C++ statements for matching the DAG structure.
|
||||
void emitMatch(DagNode tree, StringRef name, int depth);
|
||||
|
||||
// Emits C++ statements for matching using a native code call.
|
||||
void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
|
||||
|
||||
// Emits C++ statements for matching the op constrained by the given DAG
|
||||
// `tree`.
|
||||
void emitOpMatch(DagNode tree, int depth);
|
||||
// `tree` returning the op's variable name.
|
||||
void emitOpMatch(DagNode tree, StringRef opName, int depth);
|
||||
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an operand.
|
||||
void emitOperandMatch(DagNode tree, int argIndex, int depth);
|
||||
void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
|
||||
int depth);
|
||||
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an attribute.
|
||||
void emitAttributeMatch(DagNode tree, int argIndex, int depth);
|
||||
void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
|
||||
int depth);
|
||||
|
||||
// Emits C++ for checking a match with a corresponding match failure
|
||||
// diagnostic.
|
||||
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
|
||||
void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
|
||||
const llvm::formatv_object_base &failureFmt);
|
||||
|
||||
// Emits C++ for checking a match with a corresponding match failure
|
||||
// diagnostics.
|
||||
void emitMatchCheck(int depth, const std::string &matchStr,
|
||||
void emitMatchCheck(StringRef opName, const std::string &matchStr,
|
||||
const std::string &failureStr);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -113,7 +121,7 @@ private:
|
|||
|
||||
// Emits the C++ statement to replace the matched DAG with a value built via
|
||||
// calling native C++ code.
|
||||
std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
|
||||
std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
|
||||
|
||||
// Returns the symbol of the old value serving as the replacement.
|
||||
StringRef handleReplaceWithValue(DagNode tree);
|
||||
|
@ -140,12 +148,13 @@ private:
|
|||
|
||||
// Emits the concrete arguments used to call an op's builder.
|
||||
void supplyValuesForOpArgs(DagNode node,
|
||||
const ChildNodeIndexNameMap &childNodeNames);
|
||||
const ChildNodeIndexNameMap &childNodeNames,
|
||||
int depth);
|
||||
|
||||
// Emits the local variables for holding all values as a whole and all named
|
||||
// attributes as a whole to be used for creating an op.
|
||||
void createAggregateLocalVarsForOpArgs(
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames);
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
|
||||
|
||||
// Returns the C++ expression to construct a constant attribute of the given
|
||||
// `value` for the given attribute kind `attr`.
|
||||
|
@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
|||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||
void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
|
||||
if (tree.isNativeCodeCall()) {
|
||||
emitNativeCodeMatch(tree, name, depth);
|
||||
return;
|
||||
}
|
||||
|
||||
if (tree.isOperation()) {
|
||||
emitOpMatch(tree, name, depth);
|
||||
return;
|
||||
}
|
||||
|
||||
PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
int depth) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
|
||||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
// TODO(suderman): iterate through arguments, determine their types, output
|
||||
// names.
|
||||
SmallVector<std::string, 8> capture(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
PrintFatalError(loc,
|
||||
"unsupported NativeCodeCall matcher argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
}
|
||||
|
||||
raw_indented_ostream::DelimitedScope scope(os);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
std::string argName = formatv("arg{0}_{1}", depth, i);
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
os << "Value " << argName << ";\n";
|
||||
} else {
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
|
||||
os << "Attribute " << argName << ";\n";
|
||||
} else if (leaf.isOperandMatcher()) {
|
||||
os << "Operation " << argName << ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
capture[i] = std::move(argName);
|
||||
}
|
||||
|
||||
bool hasLocationDirective;
|
||||
std::string locToUse;
|
||||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||
|
||||
auto fmt = tree.getNativeCodeTemplate();
|
||||
auto nativeCodeCall = std::string(tgfmt(
|
||||
fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
|
||||
capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
|
||||
|
||||
os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto name = tree.getArgName(i);
|
||||
if (!name.empty() && name != "_") {
|
||||
os << formatv("{0} = {1};\n", name, capture[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
std::string argName = capture[i];
|
||||
|
||||
// Handle nested DAG construct first
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
PrintFatalError(
|
||||
loc, formatv("Matching nested tree in NativeCodecall not support for "
|
||||
"{0} as arg {1}",
|
||||
argName, i));
|
||||
}
|
||||
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
|
||||
auto self = formatv("{0}", argName);
|
||||
emitMatchCheck(
|
||||
opName,
|
||||
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
||||
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
|
||||
"constraint: "
|
||||
"'{2}'\"",
|
||||
i, tree.getNativeCodeTemplate(), constraint.getDescription()));
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
|
||||
<< op.getOperationName() << "' at depth " << depth
|
||||
<< '\n');
|
||||
|
||||
int indent = 4 + 2 * depth;
|
||||
os.indent(indent) << formatv(
|
||||
"auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); "
|
||||
"(void)castedOp{0};\n",
|
||||
depth, op.getQualCppClassName());
|
||||
std::string castedName = formatv("castedOp{0}", depth);
|
||||
os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
|
||||
"(void){0};\n",
|
||||
castedName, opName, op.getQualCppClassName());
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
// Skip if there is no defining operation (e.g., arguments to function).
|
||||
os << formatv("if (!castedOp{0})\n return failure();\n", depth);
|
||||
os << formatv("if (!{0}) return failure();\n", castedName);
|
||||
}
|
||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
||||
|
@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// If the operand's name is set, set to that variable.
|
||||
auto name = tree.getSymbol();
|
||||
if (!name.empty())
|
||||
os << formatv("{0} = castedOp{1};\n", name, depth);
|
||||
os << formatv("{0} = {1};\n", name, castedName);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
std::string argName = formatv("op{0}", depth + 1);
|
||||
|
||||
// Handle nested DAG construct first
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
|
@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
os << "{\n";
|
||||
|
||||
os.indent() << formatv(
|
||||
"auto *op{0} = "
|
||||
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
||||
depth + 1, depth, i);
|
||||
emitOpMatch(argTree, depth + 1);
|
||||
os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
|
||||
"auto *{0} = "
|
||||
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
||||
argName, castedName, i);
|
||||
emitMatch(argTree, argName, depth + 1);
|
||||
os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
|
||||
os.unindent() << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Next handle DAG leaf: operand or attribute
|
||||
if (opArg.is<NamedTypeConstraint *>()) {
|
||||
emitOperandMatch(tree, i, depth);
|
||||
emitOperandMatch(tree, castedName, i, depth);
|
||||
} else if (opArg.is<NamedAttribute *>()) {
|
||||
emitAttributeMatch(tree, i, depth);
|
||||
emitAttributeMatch(tree, opName, i, depth);
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when matching op");
|
||||
}
|
||||
|
@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
<< '\n');
|
||||
}
|
||||
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
|
||||
int argIndex, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
|
@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
|||
op.getOperationName(), argIndex);
|
||||
PrintFatalError(loc, error);
|
||||
}
|
||||
auto self =
|
||||
formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
|
||||
argIndex);
|
||||
auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
|
||||
opName, argIndex);
|
||||
emitMatchCheck(
|
||||
depth,
|
||||
opName,
|
||||
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
||||
formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
|
||||
"'{2}'\"",
|
||||
|
@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
|||
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
|
||||
|
||||
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
|
||||
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
|
||||
res->second.getVarName(name), depth, argIndex - numPrevAttrs);
|
||||
os << formatv("{0} = {1}.getODSOperands({2});\n",
|
||||
res->second.getVarName(name), opName,
|
||||
argIndex - numPrevAttrs);
|
||||
}
|
||||
}
|
||||
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
|
||||
int argIndex, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
|
||||
const auto &attr = namedAttr->attr;
|
||||
|
||||
os << "{\n";
|
||||
os.indent() << formatv(
|
||||
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
|
||||
"(void)tblgen_attr;\n",
|
||||
depth, attr.getStorageType(), namedAttr->name);
|
||||
os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
|
||||
"(void)tblgen_attr;\n",
|
||||
opName, attr.getStorageType(), namedAttr->name);
|
||||
|
||||
// TODO: This should use getter method to avoid duplication.
|
||||
if (attr.hasDefaultValue()) {
|
||||
|
@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
|||
// should just capture a mlir::Attribute() to signal the missing state.
|
||||
// That is precisely what getAttr() returns on missing attributes.
|
||||
} else {
|
||||
emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
|
||||
emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
|
||||
formatv("\"expected op '{0}' to have attribute '{1}' "
|
||||
"of type '{2}'\"",
|
||||
op.getOperationName(), namedAttr->name,
|
||||
|
@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
|||
// If a constraint is specified, we need to generate C++ statements to
|
||||
// check the constraint.
|
||||
emitMatchCheck(
|
||||
depth,
|
||||
opName,
|
||||
tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
|
||||
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
|
||||
"{2}\"",
|
||||
|
@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
|||
}
|
||||
|
||||
void PatternEmitter::emitMatchCheck(
|
||||
int depth, const FmtObjectBase &matchFmt,
|
||||
StringRef opName, const FmtObjectBase &matchFmt,
|
||||
const llvm::formatv_object_base &failureFmt) {
|
||||
emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
|
||||
emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
|
||||
void PatternEmitter::emitMatchCheck(StringRef opName,
|
||||
const std::string &matchStr,
|
||||
const std::string &failureStr) {
|
||||
|
||||
os << "if (!(" << matchStr << "))";
|
||||
os.scope("{\n", "\n}\n").os
|
||||
<< "return rewriter.notifyMatchFailure(op" << depth
|
||||
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
|
||||
<< ";\n});";
|
||||
os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
|
||||
<< ", [&](::mlir::Diagnostic &diag) {\n diag << "
|
||||
<< failureStr << ";\n});";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||
void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
|
||||
int depth = 0;
|
||||
emitOpMatch(tree, depth);
|
||||
emitMatch(tree, opName, depth);
|
||||
|
||||
for (auto &appliedConstraint : pattern.getConstraints()) {
|
||||
auto &constraint = appliedConstraint.constraint;
|
||||
|
@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
|||
auto self = formatv("({0}.getType())",
|
||||
symbolInfoMap.getValueAndRangeUse(entities.front()));
|
||||
emitMatchCheck(
|
||||
depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
|
||||
opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
|
||||
formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
|
||||
entities.front(), constraint.getDescription()));
|
||||
|
||||
|
@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
|||
self = symbolInfoMap.getValueAndRangeUse(self);
|
||||
for (; i < 4; ++i)
|
||||
names.push_back("<unused>");
|
||||
emitMatchCheck(depth,
|
||||
emitMatchCheck(opName,
|
||||
tgfmt(condition, &fmtCtx.withSelf(self), names[0],
|
||||
names[1], names[2], names[3]),
|
||||
formatv("\"entities '{0}' failed to satisfy constraint: "
|
||||
|
@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
|||
for (++startRange; startRange != endRange; ++startRange) {
|
||||
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
|
||||
emitMatchCheck(
|
||||
depth,
|
||||
opName,
|
||||
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
|
||||
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
|
||||
secondOperand));
|
||||
|
@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
|
||||
os << "// Match\n";
|
||||
os << "tblgen_ops[0] = op0;\n";
|
||||
emitMatchLogic(sourceTree);
|
||||
emitMatchLogic(sourceTree, "op0");
|
||||
|
||||
os << "\n// Rewrite\n";
|
||||
emitRewriteLogic();
|
||||
|
@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
|
|||
}
|
||||
|
||||
if (resultTree.isNativeCodeCall()) {
|
||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
|
||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
|
||||
symbolInfoMap.bindValue(symbol);
|
||||
return symbol;
|
||||
}
|
||||
|
@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
|||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
||||
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
|
||||
int depth) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
|
||||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
|||
// TODO: replace formatv arguments with the exact specified args.
|
||||
SmallVector<std::string, 8> attrs(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
PrintFatalError(loc,
|
||||
"unsupported NativeCodeCall replace argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
}
|
||||
bool hasLocationDirective;
|
||||
std::string locToUse;
|
||||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
if (tree.isNestedDagArg(i)) {
|
||||
attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
|
||||
} else {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
|
||||
<< " replacement: " << attrs[i] << "\n");
|
||||
}
|
||||
|
@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// create the ops.
|
||||
|
||||
// First prepare local variables for op arguments used in builder call.
|
||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
|
||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
|
||||
|
||||
// Then create the op.
|
||||
os.scope("", "\n}\n").os << formatv(
|
||||
|
@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
|
||||
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
|
||||
resultOp.getQualCppClassName(), locToUse);
|
||||
supplyValuesForOpArgs(tree, childNodeNames);
|
||||
supplyValuesForOpArgs(tree, childNodeNames, depth);
|
||||
os << "\n );\n}\n";
|
||||
return resultValue;
|
||||
}
|
||||
|
@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// here.
|
||||
|
||||
// First prepare local variables for op arguments used in builder call.
|
||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
|
||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
|
||||
|
||||
// Then prepare the result types. We need to specify the types for all
|
||||
// results.
|
||||
|
@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
|
|||
}
|
||||
|
||||
void PatternEmitter::supplyValuesForOpArgs(
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
|
||||
Operator &resultOp = node.getDialectOp(opMap);
|
||||
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
|
||||
argIndex != numOpArgs; ++argIndex) {
|
||||
|
@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
|||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv("/*{0}=*/{1}", opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree));
|
||||
handleReplaceWithNativeCodeCall(subTree, depth));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
|
@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
|||
}
|
||||
|
||||
void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
|
||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
|
||||
Operator &resultOp = node.getDialectOp(opMap);
|
||||
|
||||
auto scope = os.scope();
|
||||
|
@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
|||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv(addAttrCmd, opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree));
|
||||
handleReplaceWithNativeCodeCall(subTree, depth + 1));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
|
|
Loading…
Reference in New Issue