forked from OSchip/llvm-project
[mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher
Added an underlying matcher for generic constant ops. This included a rewriter of RewriterGen to make variable use more clear. Differential Revision: https://reviews.llvm.org/D89161
This commit is contained in:
parent
273c299d5d
commit
2bf423b021
|
@ -2351,6 +2351,8 @@ class NativeCodeCall<string expr> {
|
||||||
string expression = expr;
|
string expression = expr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Rewrite directives
|
// Rewrite directives
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -252,6 +252,9 @@ public:
|
||||||
static SymbolInfo getAttr(const Operator *op, int index) {
|
static SymbolInfo getAttr(const Operator *op, int index) {
|
||||||
return SymbolInfo(op, Kind::Attr, index);
|
return SymbolInfo(op, Kind::Attr, index);
|
||||||
}
|
}
|
||||||
|
static SymbolInfo getAttr() {
|
||||||
|
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
|
||||||
|
}
|
||||||
static SymbolInfo getOperand(const Operator *op, int index) {
|
static SymbolInfo getOperand(const Operator *op, int index) {
|
||||||
return SymbolInfo(op, Kind::Operand, index);
|
return SymbolInfo(op, Kind::Operand, index);
|
||||||
}
|
}
|
||||||
|
@ -319,6 +322,10 @@ public:
|
||||||
// is already bound.
|
// is already bound.
|
||||||
bool bindValue(StringRef symbol);
|
bool bindValue(StringRef symbol);
|
||||||
|
|
||||||
|
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
|
||||||
|
// is already bound.
|
||||||
|
bool bindAttr(StringRef symbol);
|
||||||
|
|
||||||
// Returns true if the given `symbol` is bound.
|
// Returns true if the given `symbol` is bound.
|
||||||
bool contains(StringRef symbol) const;
|
bool contains(StringRef symbol) const;
|
||||||
|
|
||||||
|
@ -421,6 +428,9 @@ public:
|
||||||
std::vector<IdentifierLine> getLocation() const;
|
std::vector<IdentifierLine> getLocation() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Helper function to verify variabld binding.
|
||||||
|
void verifyBind(bool result, StringRef symbolName);
|
||||||
|
|
||||||
// Recursively collects all bound symbols inside the DAG tree rooted
|
// Recursively collects all bound symbols inside the DAG tree rooted
|
||||||
// at `tree` and updates the given `infoMap`.
|
// at `tree` and updates the given `infoMap`.
|
||||||
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||||
|
|
|
@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
case Kind::Attr: {
|
case Kind::Attr: {
|
||||||
auto type =
|
if (op) {
|
||||||
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
auto type =
|
||||||
return std::string(formatv("{0} {1};\n", type, name));
|
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
||||||
|
return std::string(formatv("{0} {1};\n", type, name));
|
||||||
|
}
|
||||||
|
// TODO(suderman): Use a more exact type when available.
|
||||||
|
return std::string(formatv("Attribute {0};\n", name));
|
||||||
}
|
}
|
||||||
case Kind::Operand: {
|
case Kind::Operand: {
|
||||||
// Use operand range for captured operands (to support potential variadic
|
// Use operand range for captured operands (to support potential variadic
|
||||||
|
@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) {
|
||||||
return symbolInfoMap.count(inserted->first) == 1;
|
return symbolInfoMap.count(inserted->first) == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::bindAttr(StringRef symbol) {
|
||||||
|
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr());
|
||||||
|
return symbolInfoMap.count(inserted->first) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
bool SymbolInfoMap::contains(StringRef symbol) const {
|
bool SymbolInfoMap::contains(StringRef symbol) const {
|
||||||
return find(symbol) != symbolInfoMap.end();
|
return find(symbol) != symbolInfoMap.end();
|
||||||
}
|
}
|
||||||
|
@ -558,15 +567,15 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
||||||
for (auto it : *listInit) {
|
for (auto it : *listInit) {
|
||||||
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||||
if (!dagInit)
|
if (!dagInit)
|
||||||
PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
|
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
||||||
"constraints should be DAG nodes");
|
"constraints should be DAG nodes");
|
||||||
|
|
||||||
std::vector<std::string> entities;
|
std::vector<std::string> entities;
|
||||||
entities.reserve(dagInit->arg_size());
|
entities.reserve(dagInit->arg_size());
|
||||||
for (auto *argName : dagInit->getArgNames()) {
|
for (auto *argName : dagInit->getArgNames()) {
|
||||||
if (!argName) {
|
if (!argName) {
|
||||||
PrintFatalError(
|
PrintFatalError(
|
||||||
def.getLoc(),
|
&def,
|
||||||
"operands to additional constraints can only be symbol references");
|
"operands to additional constraints can only be symbol references");
|
||||||
}
|
}
|
||||||
entities.push_back(std::string(argName->getValue()));
|
entities.push_back(std::string(argName->getValue()));
|
||||||
|
@ -584,7 +593,7 @@ int Pattern::getBenefit() const {
|
||||||
int initBenefit = getSourcePattern().getNumOps();
|
int initBenefit = getSourcePattern().getNumOps();
|
||||||
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||||
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
|
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
|
||||||
PrintFatalError(def.getLoc(),
|
PrintFatalError(&def,
|
||||||
"The 'addBenefit' takes and only takes one integer value");
|
"The 'addBenefit' takes and only takes one integer value");
|
||||||
}
|
}
|
||||||
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||||
|
@ -603,64 +612,120 @@ std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Pattern::verifyBind(bool result, StringRef symbolName) {
|
||||||
|
if (!result) {
|
||||||
|
auto err = formatv("symbol '{0}' bound more than once", symbolName);
|
||||||
|
PrintFatalError(&def, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||||
bool isSrcPattern) {
|
bool isSrcPattern) {
|
||||||
auto treeName = tree.getSymbol();
|
auto treeName = tree.getSymbol();
|
||||||
if (!tree.isOperation()) {
|
auto numTreeArgs = tree.getNumArgs();
|
||||||
|
|
||||||
|
if (tree.isNativeCodeCall()) {
|
||||||
if (!treeName.empty()) {
|
if (!treeName.empty()) {
|
||||||
PrintFatalError(
|
PrintFatalError(
|
||||||
def.getLoc(),
|
&def,
|
||||||
formatv("binding symbol '{0}' to non-operation unsupported right now",
|
formatv(
|
||||||
treeName));
|
"binding symbol '{0}' to native code call unsupported right now",
|
||||||
|
treeName));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i != numTreeArgs; ++i) {
|
||||||
|
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||||
|
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||||
|
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isSrcPattern)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// We can only bind symbols to arguments in source pattern. Those
|
||||||
|
// symbols are referenced in result patterns.
|
||||||
|
auto treeArgName = tree.getArgName(i);
|
||||||
|
|
||||||
|
// `$_` is a special symbol meaning ignore the current argument.
|
||||||
|
if (!treeArgName.empty() && treeArgName != "_") {
|
||||||
|
if (tree.isNestedDagArg(i)) {
|
||||||
|
auto err = formatv("cannot bind '{0}' for nested native call arg",
|
||||||
|
treeArgName);
|
||||||
|
PrintFatalError(&def, err);
|
||||||
|
}
|
||||||
|
|
||||||
|
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||||
|
auto constraint = leaf.getAsConstraint();
|
||||||
|
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
|
||||||
|
leaf.isConstantAttr() ||
|
||||||
|
constraint.getKind() == Constraint::Kind::CK_Attr;
|
||||||
|
|
||||||
|
if (isAttr) {
|
||||||
|
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tree.isOperation()) {
|
||||||
|
auto &op = getDialectOp(tree);
|
||||||
|
auto numOpArgs = op.getNumArgs();
|
||||||
|
|
||||||
|
// The pattern might have the last argument specifying the location.
|
||||||
|
bool hasLocDirective = false;
|
||||||
|
if (numTreeArgs != 0) {
|
||||||
|
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
|
||||||
|
hasLocDirective = lastArg.isLocationDirective();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (numOpArgs != numTreeArgs - hasLocDirective) {
|
||||||
|
auto err = formatv("op '{0}' argument number mismatch: "
|
||||||
|
"{1} in pattern vs. {2} in definition",
|
||||||
|
op.getOperationName(), numTreeArgs, numOpArgs);
|
||||||
|
PrintFatalError(&def, err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name attached to the DAG node's operator is for representing the
|
||||||
|
// results generated from this op. It should be remembered as bound results.
|
||||||
|
if (!treeName.empty()) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
|
<< "found symbol bound to op result: " << treeName << '\n');
|
||||||
|
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i != numTreeArgs; ++i) {
|
||||||
|
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||||
|
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||||
|
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSrcPattern) {
|
||||||
|
// We can only bind symbols to op arguments in source pattern. Those
|
||||||
|
// symbols are referenced in result patterns.
|
||||||
|
auto treeArgName = tree.getArgName(i);
|
||||||
|
// `$_` is a special symbol meaning ignore the current argument.
|
||||||
|
if (!treeArgName.empty() && treeArgName != "_") {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||||
|
<< treeArgName << '\n');
|
||||||
|
verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto &op = getDialectOp(tree);
|
|
||||||
auto numOpArgs = op.getNumArgs();
|
|
||||||
auto numTreeArgs = tree.getNumArgs();
|
|
||||||
|
|
||||||
// The pattern might have the last argument specifying the location.
|
|
||||||
bool hasLocDirective = false;
|
|
||||||
if (numTreeArgs != 0) {
|
|
||||||
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
|
|
||||||
hasLocDirective = lastArg.isLocationDirective();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (numOpArgs != numTreeArgs - hasLocDirective) {
|
|
||||||
auto err = formatv("op '{0}' argument number mismatch: "
|
|
||||||
"{1} in pattern vs. {2} in definition",
|
|
||||||
op.getOperationName(), numTreeArgs, numOpArgs);
|
|
||||||
PrintFatalError(def.getLoc(), err);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The name attached to the DAG node's operator is for representing the
|
|
||||||
// results generated from this op. It should be remembered as bound results.
|
|
||||||
if (!treeName.empty()) {
|
if (!treeName.empty()) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
PrintFatalError(
|
||||||
<< "found symbol bound to op result: " << treeName << '\n');
|
&def, formatv("binding symbol '{0}' to non-operation/native code call "
|
||||||
if (!infoMap.bindOpResult(treeName, op))
|
"unsupported right now",
|
||||||
PrintFatalError(def.getLoc(),
|
treeName));
|
||||||
formatv("symbol '{0}' bound more than once", treeName));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i != numTreeArgs; ++i) {
|
|
||||||
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
|
||||||
// This DAG node argument is a DAG node itself. Go inside recursively.
|
|
||||||
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
|
||||||
} else if (isSrcPattern) {
|
|
||||||
// We can only bind symbols to op arguments in source pattern. Those
|
|
||||||
// symbols are referenced in result patterns.
|
|
||||||
auto treeArgName = tree.getArgName(i);
|
|
||||||
// `$_` is a special symbol meaning ignore the current argument.
|
|
||||||
if (!treeArgName.empty() && treeArgName != "_") {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
|
||||||
<< treeArgName << '\n');
|
|
||||||
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
|
|
||||||
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
|
|
||||||
PrintFatalError(def.getLoc(), err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
|
||||||
return operand();
|
return operand();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
|
||||||
|
return getValue();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
||||||
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
|
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
|
||||||
for (Value input : this->operands()) {
|
for (Value input : this->operands()) {
|
||||||
|
|
|
@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
|
||||||
|
let arguments = (ins AnyAttr:$value);
|
||||||
|
let results = (outs AnyType);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
Attribute getValue() { return getAttr("value"); }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
|
||||||
|
def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
|
||||||
|
|
||||||
|
def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
|
||||||
|
(OpS:$unused $input1, $input2)>;
|
||||||
|
|
||||||
// Op for testing trivial removal via folding of op with inner ops and no uses.
|
// Op for testing trivial removal via folding of op with inner ops and no uses.
|
||||||
def TestOpWithRegionFoldNoSideEffect : TEST_Op<
|
def TestOpWithRegionFoldNoSideEffect : TEST_Op<
|
||||||
"op_with_region_fold_no_side_effect", [NoSideEffect]> {
|
"op_with_region_fold_no_side_effect", [NoSideEffect]> {
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "TestDialect.h"
|
#include "TestDialect.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
|
@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) {
|
||||||
return %0, %1 : i32, i32
|
return %0, %1 : i32, i32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Test Constant Matching
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: testConstOp
|
||||||
|
func @testConstOp() -> (i32) {
|
||||||
|
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||||
|
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||||
|
|
||||||
|
// CHECK-NEXT: return [[C0]]
|
||||||
|
return %0 : i32
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testConstOpUsed
|
||||||
|
func @testConstOpUsed() -> (i32) {
|
||||||
|
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||||
|
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
|
||||||
|
%1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32
|
||||||
|
|
||||||
|
// CHECK-NEXT: return [[V0]]
|
||||||
|
return %1 : i32
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testConstOpReplaced
|
||||||
|
func @testConstOpReplaced() -> (i32) {
|
||||||
|
// CHECK-NEXT: [[C0:%.+]] = constant 1
|
||||||
|
%0 = "test.constant"() {value = 1 : i32} : () -> i32
|
||||||
|
%1 = "test.constant"() {value = 2 : i32} : () -> i32
|
||||||
|
|
||||||
|
// CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
|
||||||
|
%2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
|
||||||
|
|
||||||
|
// CHECK: [[V0]]
|
||||||
|
return %2 : i32
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: testConstOpMatchFailure
|
||||||
|
func @testConstOpMatchFailure() -> (i64) {
|
||||||
|
// CHECK-DAG: [[C0:%.+]] = constant 1
|
||||||
|
%0 = "test.constant"() {value = 1 : i64} : () -> i64
|
||||||
|
|
||||||
|
// CHECK-DAG: [[C1:%.+]] = constant 2
|
||||||
|
%1 = "test.constant"() {value = 2 : i64} : () -> i64
|
||||||
|
|
||||||
|
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
|
||||||
|
%2 = "test.op_r"(%0, %1) : (i64, i64) -> i64
|
||||||
|
|
||||||
|
// CHECK: [[V0]]
|
||||||
|
return %2 : i64
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Test Enum Attributes
|
// Test Enum Attributes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
|
||||||
|
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
// Check using the dialect name as the namespace
|
||||||
|
def A_Dialect : Dialect {
|
||||||
|
let name = "a";
|
||||||
|
}
|
||||||
|
|
||||||
|
class A_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
Op<A_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
|
||||||
|
def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
|
||||||
|
|
||||||
|
#ifdef ERROR1
|
||||||
|
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||||
|
// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
|
||||||
|
def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
|
||||||
|
(OpB $val, $arg)>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ERROR2
|
||||||
|
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||||
|
// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
|
||||||
|
def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
|
||||||
|
(OpB $val, $arg)>;
|
||||||
|
#endif
|
|
@ -63,7 +63,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Emits the code for matching ops.
|
// Emits the code for matching ops.
|
||||||
void emitMatchLogic(DagNode tree);
|
void emitMatchLogic(DagNode tree, StringRef opName);
|
||||||
|
|
||||||
// Emits the code for rewriting ops.
|
// Emits the code for rewriting ops.
|
||||||
void emitRewriteLogic();
|
void emitRewriteLogic();
|
||||||
|
@ -72,26 +72,34 @@ private:
|
||||||
// Match utilities
|
// Match utilities
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Emits C++ statements for matching the DAG structure.
|
||||||
|
void emitMatch(DagNode tree, StringRef name, int depth);
|
||||||
|
|
||||||
|
// Emits C++ statements for matching using a native code call.
|
||||||
|
void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
|
||||||
|
|
||||||
// Emits C++ statements for matching the op constrained by the given DAG
|
// Emits C++ statements for matching the op constrained by the given DAG
|
||||||
// `tree`.
|
// `tree` returning the op's variable name.
|
||||||
void emitOpMatch(DagNode tree, int depth);
|
void emitOpMatch(DagNode tree, StringRef opName, int depth);
|
||||||
|
|
||||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||||
// DAG `tree` as an operand.
|
// DAG `tree` as an operand.
|
||||||
void emitOperandMatch(DagNode tree, int argIndex, int depth);
|
void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
|
||||||
|
int depth);
|
||||||
|
|
||||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||||
// DAG `tree` as an attribute.
|
// DAG `tree` as an attribute.
|
||||||
void emitAttributeMatch(DagNode tree, int argIndex, int depth);
|
void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
|
||||||
|
int depth);
|
||||||
|
|
||||||
// Emits C++ for checking a match with a corresponding match failure
|
// Emits C++ for checking a match with a corresponding match failure
|
||||||
// diagnostic.
|
// diagnostic.
|
||||||
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
|
void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
|
||||||
const llvm::formatv_object_base &failureFmt);
|
const llvm::formatv_object_base &failureFmt);
|
||||||
|
|
||||||
// Emits C++ for checking a match with a corresponding match failure
|
// Emits C++ for checking a match with a corresponding match failure
|
||||||
// diagnostics.
|
// diagnostics.
|
||||||
void emitMatchCheck(int depth, const std::string &matchStr,
|
void emitMatchCheck(StringRef opName, const std::string &matchStr,
|
||||||
const std::string &failureStr);
|
const std::string &failureStr);
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -113,7 +121,7 @@ private:
|
||||||
|
|
||||||
// Emits the C++ statement to replace the matched DAG with a value built via
|
// Emits the C++ statement to replace the matched DAG with a value built via
|
||||||
// calling native C++ code.
|
// calling native C++ code.
|
||||||
std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
|
std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
|
||||||
|
|
||||||
// Returns the symbol of the old value serving as the replacement.
|
// Returns the symbol of the old value serving as the replacement.
|
||||||
StringRef handleReplaceWithValue(DagNode tree);
|
StringRef handleReplaceWithValue(DagNode tree);
|
||||||
|
@ -140,12 +148,13 @@ private:
|
||||||
|
|
||||||
// Emits the concrete arguments used to call an op's builder.
|
// Emits the concrete arguments used to call an op's builder.
|
||||||
void supplyValuesForOpArgs(DagNode node,
|
void supplyValuesForOpArgs(DagNode node,
|
||||||
const ChildNodeIndexNameMap &childNodeNames);
|
const ChildNodeIndexNameMap &childNodeNames,
|
||||||
|
int depth);
|
||||||
|
|
||||||
// Emits the local variables for holding all values as a whole and all named
|
// Emits the local variables for holding all values as a whole and all named
|
||||||
// attributes as a whole to be used for creating an op.
|
// attributes as a whole to be used for creating an op.
|
||||||
void createAggregateLocalVarsForOpArgs(
|
void createAggregateLocalVarsForOpArgs(
|
||||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames);
|
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
|
||||||
|
|
||||||
// Returns the C++ expression to construct a constant attribute of the given
|
// Returns the C++ expression to construct a constant attribute of the given
|
||||||
// `value` for the given attribute kind `attr`.
|
// `value` for the given attribute kind `attr`.
|
||||||
|
@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to match patterns.
|
// Helper function to match patterns.
|
||||||
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
|
||||||
|
if (tree.isNativeCodeCall()) {
|
||||||
|
emitNativeCodeMatch(tree, name, depth);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tree.isOperation()) {
|
||||||
|
emitOpMatch(tree, name, depth);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to match patterns.
|
||||||
|
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||||
|
int depth) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
|
||||||
|
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||||
|
|
||||||
|
// TODO(suderman): iterate through arguments, determine their types, output
|
||||||
|
// names.
|
||||||
|
SmallVector<std::string, 8> capture(8);
|
||||||
|
if (tree.getNumArgs() > 8) {
|
||||||
|
PrintFatalError(loc,
|
||||||
|
"unsupported NativeCodeCall matcher argument numbers: " +
|
||||||
|
Twine(tree.getNumArgs()));
|
||||||
|
}
|
||||||
|
|
||||||
|
raw_indented_ostream::DelimitedScope scope(os);
|
||||||
|
|
||||||
|
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||||
|
std::string argName = formatv("arg{0}_{1}", depth, i);
|
||||||
|
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||||
|
os << "Value " << argName << ";\n";
|
||||||
|
} else {
|
||||||
|
auto leaf = tree.getArgAsLeaf(i);
|
||||||
|
if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
|
||||||
|
os << "Attribute " << argName << ";\n";
|
||||||
|
} else if (leaf.isOperandMatcher()) {
|
||||||
|
os << "Operation " << argName << ";\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
capture[i] = std::move(argName);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasLocationDirective;
|
||||||
|
std::string locToUse;
|
||||||
|
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||||
|
|
||||||
|
auto fmt = tree.getNativeCodeTemplate();
|
||||||
|
auto nativeCodeCall = std::string(tgfmt(
|
||||||
|
fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
|
||||||
|
capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
|
||||||
|
|
||||||
|
os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
|
||||||
|
|
||||||
|
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||||
|
auto name = tree.getArgName(i);
|
||||||
|
if (!name.empty() && name != "_") {
|
||||||
|
os << formatv("{0} = {1};\n", name, capture[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||||
|
std::string argName = capture[i];
|
||||||
|
|
||||||
|
// Handle nested DAG construct first
|
||||||
|
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||||
|
PrintFatalError(
|
||||||
|
loc, formatv("Matching nested tree in NativeCodecall not support for "
|
||||||
|
"{0} as arg {1}",
|
||||||
|
argName, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||||
|
auto constraint = leaf.getAsConstraint();
|
||||||
|
|
||||||
|
auto self = formatv("{0}", argName);
|
||||||
|
emitMatchCheck(
|
||||||
|
opName,
|
||||||
|
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
||||||
|
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
|
||||||
|
"constraint: "
|
||||||
|
"'{2}'\"",
|
||||||
|
i, tree.getNativeCodeTemplate(), constraint.getDescription()));
|
||||||
|
}
|
||||||
|
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to match patterns.
|
||||||
|
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
||||||
Operator &op = tree.getDialectOp(opMap);
|
Operator &op = tree.getDialectOp(opMap);
|
||||||
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
|
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
|
||||||
<< op.getOperationName() << "' at depth " << depth
|
<< op.getOperationName() << "' at depth " << depth
|
||||||
<< '\n');
|
<< '\n');
|
||||||
|
|
||||||
int indent = 4 + 2 * depth;
|
std::string castedName = formatv("castedOp{0}", depth);
|
||||||
os.indent(indent) << formatv(
|
os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
|
||||||
"auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); "
|
"(void){0};\n",
|
||||||
"(void)castedOp{0};\n",
|
castedName, opName, op.getQualCppClassName());
|
||||||
depth, op.getQualCppClassName());
|
|
||||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||||
if (depth != 0) {
|
if (depth != 0) {
|
||||||
// Skip if there is no defining operation (e.g., arguments to function).
|
// Skip if there is no defining operation (e.g., arguments to function).
|
||||||
os << formatv("if (!castedOp{0})\n return failure();\n", depth);
|
os << formatv("if (!{0}) return failure();\n", castedName);
|
||||||
}
|
}
|
||||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||||
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
||||||
|
@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||||
// If the operand's name is set, set to that variable.
|
// If the operand's name is set, set to that variable.
|
||||||
auto name = tree.getSymbol();
|
auto name = tree.getSymbol();
|
||||||
if (!name.empty())
|
if (!name.empty())
|
||||||
os << formatv("{0} = castedOp{1};\n", name, depth);
|
os << formatv("{0} = {1};\n", name, castedName);
|
||||||
|
|
||||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||||
auto opArg = op.getArg(i);
|
auto opArg = op.getArg(i);
|
||||||
|
std::string argName = formatv("op{0}", depth + 1);
|
||||||
|
|
||||||
// Handle nested DAG construct first
|
// Handle nested DAG construct first
|
||||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||||
|
@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||||
os << "{\n";
|
os << "{\n";
|
||||||
|
|
||||||
os.indent() << formatv(
|
os.indent() << formatv(
|
||||||
"auto *op{0} = "
|
"auto *{0} = "
|
||||||
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
||||||
depth + 1, depth, i);
|
argName, castedName, i);
|
||||||
emitOpMatch(argTree, depth + 1);
|
emitMatch(argTree, argName, depth + 1);
|
||||||
os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
|
os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
|
||||||
os.unindent() << "}\n";
|
os.unindent() << "}\n";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next handle DAG leaf: operand or attribute
|
// Next handle DAG leaf: operand or attribute
|
||||||
if (opArg.is<NamedTypeConstraint *>()) {
|
if (opArg.is<NamedTypeConstraint *>()) {
|
||||||
emitOperandMatch(tree, i, depth);
|
emitOperandMatch(tree, castedName, i, depth);
|
||||||
} else if (opArg.is<NamedAttribute *>()) {
|
} else if (opArg.is<NamedAttribute *>()) {
|
||||||
emitAttributeMatch(tree, i, depth);
|
emitAttributeMatch(tree, opName, i, depth);
|
||||||
} else {
|
} else {
|
||||||
PrintFatalError(loc, "unhandled case when matching op");
|
PrintFatalError(loc, "unhandled case when matching op");
|
||||||
}
|
}
|
||||||
|
@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||||
<< '\n');
|
<< '\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
|
||||||
|
int argIndex, int depth) {
|
||||||
Operator &op = tree.getDialectOp(opMap);
|
Operator &op = tree.getDialectOp(opMap);
|
||||||
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
|
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
|
||||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||||
|
@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
||||||
op.getOperationName(), argIndex);
|
op.getOperationName(), argIndex);
|
||||||
PrintFatalError(loc, error);
|
PrintFatalError(loc, error);
|
||||||
}
|
}
|
||||||
auto self =
|
auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
|
||||||
formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
|
opName, argIndex);
|
||||||
argIndex);
|
|
||||||
emitMatchCheck(
|
emitMatchCheck(
|
||||||
depth,
|
opName,
|
||||||
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
||||||
formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
|
formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
|
||||||
"'{2}'\"",
|
"'{2}'\"",
|
||||||
|
@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
||||||
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
|
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
|
||||||
|
|
||||||
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
|
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
|
||||||
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
|
os << formatv("{0} = {1}.getODSOperands({2});\n",
|
||||||
res->second.getVarName(name), depth, argIndex - numPrevAttrs);
|
res->second.getVarName(name), opName,
|
||||||
|
argIndex - numPrevAttrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
|
||||||
|
int argIndex, int depth) {
|
||||||
Operator &op = tree.getDialectOp(opMap);
|
Operator &op = tree.getDialectOp(opMap);
|
||||||
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
|
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
|
||||||
const auto &attr = namedAttr->attr;
|
const auto &attr = namedAttr->attr;
|
||||||
|
|
||||||
os << "{\n";
|
os << "{\n";
|
||||||
os.indent() << formatv(
|
os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
|
||||||
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
|
"(void)tblgen_attr;\n",
|
||||||
"(void)tblgen_attr;\n",
|
opName, attr.getStorageType(), namedAttr->name);
|
||||||
depth, attr.getStorageType(), namedAttr->name);
|
|
||||||
|
|
||||||
// TODO: This should use getter method to avoid duplication.
|
// TODO: This should use getter method to avoid duplication.
|
||||||
if (attr.hasDefaultValue()) {
|
if (attr.hasDefaultValue()) {
|
||||||
|
@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
||||||
// should just capture a mlir::Attribute() to signal the missing state.
|
// should just capture a mlir::Attribute() to signal the missing state.
|
||||||
// That is precisely what getAttr() returns on missing attributes.
|
// That is precisely what getAttr() returns on missing attributes.
|
||||||
} else {
|
} else {
|
||||||
emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
|
emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
|
||||||
formatv("\"expected op '{0}' to have attribute '{1}' "
|
formatv("\"expected op '{0}' to have attribute '{1}' "
|
||||||
"of type '{2}'\"",
|
"of type '{2}'\"",
|
||||||
op.getOperationName(), namedAttr->name,
|
op.getOperationName(), namedAttr->name,
|
||||||
|
@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
||||||
// If a constraint is specified, we need to generate C++ statements to
|
// If a constraint is specified, we need to generate C++ statements to
|
||||||
// check the constraint.
|
// check the constraint.
|
||||||
emitMatchCheck(
|
emitMatchCheck(
|
||||||
depth,
|
opName,
|
||||||
tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
|
tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
|
||||||
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
|
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
|
||||||
"{2}\"",
|
"{2}\"",
|
||||||
|
@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::emitMatchCheck(
|
void PatternEmitter::emitMatchCheck(
|
||||||
int depth, const FmtObjectBase &matchFmt,
|
StringRef opName, const FmtObjectBase &matchFmt,
|
||||||
const llvm::formatv_object_base &failureFmt) {
|
const llvm::formatv_object_base &failureFmt) {
|
||||||
emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
|
emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
|
void PatternEmitter::emitMatchCheck(StringRef opName,
|
||||||
|
const std::string &matchStr,
|
||||||
const std::string &failureStr) {
|
const std::string &failureStr) {
|
||||||
|
|
||||||
os << "if (!(" << matchStr << "))";
|
os << "if (!(" << matchStr << "))";
|
||||||
os.scope("{\n", "\n}\n").os
|
os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
|
||||||
<< "return rewriter.notifyMatchFailure(op" << depth
|
<< ", [&](::mlir::Diagnostic &diag) {\n diag << "
|
||||||
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
|
<< failureStr << ";\n});";
|
||||||
<< ";\n});";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::emitMatchLogic(DagNode tree) {
|
void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
|
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
|
||||||
int depth = 0;
|
int depth = 0;
|
||||||
emitOpMatch(tree, depth);
|
emitMatch(tree, opName, depth);
|
||||||
|
|
||||||
for (auto &appliedConstraint : pattern.getConstraints()) {
|
for (auto &appliedConstraint : pattern.getConstraints()) {
|
||||||
auto &constraint = appliedConstraint.constraint;
|
auto &constraint = appliedConstraint.constraint;
|
||||||
|
@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||||
auto self = formatv("({0}.getType())",
|
auto self = formatv("({0}.getType())",
|
||||||
symbolInfoMap.getValueAndRangeUse(entities.front()));
|
symbolInfoMap.getValueAndRangeUse(entities.front()));
|
||||||
emitMatchCheck(
|
emitMatchCheck(
|
||||||
depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
|
opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
|
||||||
formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
|
formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
|
||||||
entities.front(), constraint.getDescription()));
|
entities.front(), constraint.getDescription()));
|
||||||
|
|
||||||
|
@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||||
self = symbolInfoMap.getValueAndRangeUse(self);
|
self = symbolInfoMap.getValueAndRangeUse(self);
|
||||||
for (; i < 4; ++i)
|
for (; i < 4; ++i)
|
||||||
names.push_back("<unused>");
|
names.push_back("<unused>");
|
||||||
emitMatchCheck(depth,
|
emitMatchCheck(opName,
|
||||||
tgfmt(condition, &fmtCtx.withSelf(self), names[0],
|
tgfmt(condition, &fmtCtx.withSelf(self), names[0],
|
||||||
names[1], names[2], names[3]),
|
names[1], names[2], names[3]),
|
||||||
formatv("\"entities '{0}' failed to satisfy constraint: "
|
formatv("\"entities '{0}' failed to satisfy constraint: "
|
||||||
|
@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||||
for (++startRange; startRange != endRange; ++startRange) {
|
for (++startRange; startRange != endRange; ++startRange) {
|
||||||
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
|
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
|
||||||
emitMatchCheck(
|
emitMatchCheck(
|
||||||
depth,
|
opName,
|
||||||
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
|
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
|
||||||
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
|
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
|
||||||
secondOperand));
|
secondOperand));
|
||||||
|
@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
||||||
|
|
||||||
os << "// Match\n";
|
os << "// Match\n";
|
||||||
os << "tblgen_ops[0] = op0;\n";
|
os << "tblgen_ops[0] = op0;\n";
|
||||||
emitMatchLogic(sourceTree);
|
emitMatchLogic(sourceTree, "op0");
|
||||||
|
|
||||||
os << "\n// Rewrite\n";
|
os << "\n// Rewrite\n";
|
||||||
emitRewriteLogic();
|
emitRewriteLogic();
|
||||||
|
@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resultTree.isNativeCodeCall()) {
|
if (resultTree.isNativeCodeCall()) {
|
||||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
|
auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
|
||||||
symbolInfoMap.bindValue(symbol);
|
symbolInfoMap.bindValue(symbol);
|
||||||
return symbol;
|
return symbol;
|
||||||
}
|
}
|
||||||
|
@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
||||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
|
||||||
|
int depth) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
|
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
|
||||||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||||
|
@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
||||||
// TODO: replace formatv arguments with the exact specified args.
|
// TODO: replace formatv arguments with the exact specified args.
|
||||||
SmallVector<std::string, 8> attrs(8);
|
SmallVector<std::string, 8> attrs(8);
|
||||||
if (tree.getNumArgs() > 8) {
|
if (tree.getNumArgs() > 8) {
|
||||||
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
|
PrintFatalError(loc,
|
||||||
Twine(tree.getNumArgs()));
|
"unsupported NativeCodeCall replace argument numbers: " +
|
||||||
|
Twine(tree.getNumArgs()));
|
||||||
}
|
}
|
||||||
bool hasLocationDirective;
|
bool hasLocationDirective;
|
||||||
std::string locToUse;
|
std::string locToUse;
|
||||||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||||
|
|
||||||
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
|
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
|
||||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
if (tree.isNestedDagArg(i)) {
|
||||||
|
attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
|
||||||
|
} else {
|
||||||
|
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||||
|
}
|
||||||
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
|
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
|
||||||
<< " replacement: " << attrs[i] << "\n");
|
<< " replacement: " << attrs[i] << "\n");
|
||||||
}
|
}
|
||||||
|
@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||||
// create the ops.
|
// create the ops.
|
||||||
|
|
||||||
// First prepare local variables for op arguments used in builder call.
|
// First prepare local variables for op arguments used in builder call.
|
||||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
|
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
|
||||||
|
|
||||||
// Then create the op.
|
// Then create the op.
|
||||||
os.scope("", "\n}\n").os << formatv(
|
os.scope("", "\n}\n").os << formatv(
|
||||||
|
@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||||
|
|
||||||
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
|
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
|
||||||
resultOp.getQualCppClassName(), locToUse);
|
resultOp.getQualCppClassName(), locToUse);
|
||||||
supplyValuesForOpArgs(tree, childNodeNames);
|
supplyValuesForOpArgs(tree, childNodeNames, depth);
|
||||||
os << "\n );\n}\n";
|
os << "\n );\n}\n";
|
||||||
return resultValue;
|
return resultValue;
|
||||||
}
|
}
|
||||||
|
@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||||
// here.
|
// here.
|
||||||
|
|
||||||
// First prepare local variables for op arguments used in builder call.
|
// First prepare local variables for op arguments used in builder call.
|
||||||
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
|
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
|
||||||
|
|
||||||
// Then prepare the result types. We need to specify the types for all
|
// Then prepare the result types. We need to specify the types for all
|
||||||
// results.
|
// results.
|
||||||
|
@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::supplyValuesForOpArgs(
|
void PatternEmitter::supplyValuesForOpArgs(
|
||||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
|
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
|
||||||
Operator &resultOp = node.getDialectOp(opMap);
|
Operator &resultOp = node.getDialectOp(opMap);
|
||||||
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
|
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
|
||||||
argIndex != numOpArgs; ++argIndex) {
|
argIndex != numOpArgs; ++argIndex) {
|
||||||
|
@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
||||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||||
"for creating attribute");
|
"for creating attribute");
|
||||||
os << formatv("/*{0}=*/{1}", opArgName,
|
os << formatv("/*{0}=*/{1}", opArgName,
|
||||||
handleReplaceWithNativeCodeCall(subTree));
|
handleReplaceWithNativeCodeCall(subTree, depth));
|
||||||
} else {
|
} else {
|
||||||
auto leaf = node.getArgAsLeaf(argIndex);
|
auto leaf = node.getArgAsLeaf(argIndex);
|
||||||
// The argument in the result DAG pattern.
|
// The argument in the result DAG pattern.
|
||||||
|
@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||||
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
|
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
|
||||||
Operator &resultOp = node.getDialectOp(opMap);
|
Operator &resultOp = node.getDialectOp(opMap);
|
||||||
|
|
||||||
auto scope = os.scope();
|
auto scope = os.scope();
|
||||||
|
@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||||
"for creating attribute");
|
"for creating attribute");
|
||||||
os << formatv(addAttrCmd, opArgName,
|
os << formatv(addAttrCmd, opArgName,
|
||||||
handleReplaceWithNativeCodeCall(subTree));
|
handleReplaceWithNativeCodeCall(subTree, depth + 1));
|
||||||
} else {
|
} else {
|
||||||
auto leaf = node.getArgAsLeaf(argIndex);
|
auto leaf = node.getArgAsLeaf(argIndex);
|
||||||
// The argument in the result DAG pattern.
|
// The argument in the result DAG pattern.
|
||||||
|
|
Loading…
Reference in New Issue