forked from OSchip/llvm-project
Add tblgen::Pattern to model Patterns defined in TableGen
Similar to other tblgen:: abstractions, tblgen::Pattern hides the native TableGen API and provides a nicer API that is more coherent with the TableGen definitions. PiperOrigin-RevId: 231285143
This commit is contained in:
parent
0fbf4ff232
commit
eb753f4aec
|
@ -68,10 +68,10 @@ public:
|
|||
|
||||
// Op attribute accessors.
|
||||
int getNumAttributes() const { return attributes.size(); }
|
||||
// Returns the total number of native attributes.
|
||||
int getNumNativeAttributes() const;
|
||||
NamedAttribute &getAttribute(int index) { return attributes[index]; }
|
||||
const NamedAttribute &getAttribute(int index) const {
|
||||
return attributes[index];
|
||||
}
|
||||
const NamedAttribute &getAttribute(int index) const;
|
||||
|
||||
// Op operand iterators.
|
||||
using operand_iterator = Operand *;
|
||||
|
@ -87,6 +87,7 @@ public:
|
|||
// Op argument (attribute or operand) accessors.
|
||||
Argument getArg(int index);
|
||||
StringRef getArgName(int index) const;
|
||||
// Returns the total number of arguments.
|
||||
int getNumArgs() const { return operands.size() + attributes.size(); }
|
||||
|
||||
// Query functions for the documentation of the operator.
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
|
||||
// Pattern.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_PATTERN_H_
|
||||
#define MLIR_TABLEGEN_PATTERN_H_
|
||||
|
||||
#include "mlir/TableGen/Argument.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
class Init;
|
||||
class DagInit;
|
||||
class StringRef;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// Mapping from TableGen Record to Operator wrapper object
|
||||
using RecordOperatorMap = llvm::DenseMap<const llvm::Record *, Operator>;
|
||||
|
||||
// Wrapper around DAG argument.
|
||||
struct DagArg {
|
||||
DagArg(Argument arg, llvm::Init *constraint)
|
||||
: arg(arg), constraint(constraint) {}
|
||||
|
||||
// Returns true if this DAG argument concerns an operation attribute.
|
||||
bool isAttr() const;
|
||||
|
||||
Argument arg;
|
||||
llvm::Init *constraint;
|
||||
};
|
||||
|
||||
class Pattern;
|
||||
|
||||
// Wrapper class providing helper methods for accessing TableGen DAG constructs
|
||||
// used inside Patterns. This class is lightweight and designed to be used like
|
||||
// values.
|
||||
//
|
||||
// A TableGen DAG construct is of the syntax
|
||||
// `(operator, arg0, arg1, ...)`.
|
||||
//
|
||||
// When used inside Patterns, `operator` corresponds to some dialect op, or
|
||||
// a known list of verbs that defines special transformation actions. This
|
||||
// `arg*` can be a nested DAG construct. This class provides getters to
|
||||
// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
|
||||
// methods.
|
||||
//
|
||||
// A null DagNode contains a nullptr and converts to false implicitly.
|
||||
class DagNode {
|
||||
public:
|
||||
explicit DagNode(const llvm::DagInit *node) : node(node) {}
|
||||
|
||||
// Implicit bool converter that returns true if this DagNode is not a null
|
||||
// DagNode.
|
||||
operator bool() const { return node != nullptr; }
|
||||
|
||||
// Returns the operator wrapper object corresponding to the dialect op matched
|
||||
// by this DAG. The operator wrapper will be queried from the given `mapper`
|
||||
// and created in it if not existing.
|
||||
Operator &getDialectOp(RecordOperatorMap *mapper) const;
|
||||
|
||||
// Returns the number of operations recursively involved in the DAG tree
|
||||
// rooted from this node.
|
||||
unsigned getNumOps() const;
|
||||
|
||||
// Returns the number of immediate arguments to this DAG node.
|
||||
unsigned getNumArgs() const;
|
||||
|
||||
// Returns true if the `index`-th argument is a nested DAG construct.
|
||||
bool isNestedDagArg(unsigned index) const;
|
||||
|
||||
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
|
||||
// null DagNode otherwise.
|
||||
DagNode getArgAsNestedDag(unsigned index) const;
|
||||
// Gets the `index`-th argument as a TableGen DefInit* if possible. Returns
|
||||
// nullptr otherwise.
|
||||
// TODO: This method is exposing raw TableGen object and should be changed.
|
||||
llvm::DefInit *getArgAsDefInit(unsigned index) const;
|
||||
|
||||
// Returns the specified name of the `index`-th argument.
|
||||
llvm::StringRef getArgName(unsigned index) const;
|
||||
|
||||
// Collects all recursively bound arguments involved in the DAG tree rooted
|
||||
// from this node.
|
||||
void collectBoundArguments(Pattern *pattern) const;
|
||||
|
||||
// Returns true if this DAG construct means to replace with an existing SSA
|
||||
// value.
|
||||
bool isReplaceWithValue() const;
|
||||
|
||||
private:
|
||||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Pattern defined
|
||||
// in TableGen. This class should closely reflect what is defined as class
|
||||
// `Pattern` in TableGen. This class contains maps so it is not intended to be
|
||||
// used as values.
|
||||
class Pattern {
|
||||
public:
|
||||
explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
|
||||
|
||||
// Returns the source pattern to match.
|
||||
DagNode getSourcePattern() const;
|
||||
|
||||
// Returns the number of results generated by applying this rewrite pattern.
|
||||
unsigned getNumResults() const;
|
||||
|
||||
// Returns the DAG tree root node of the `index`-th result pattern.
|
||||
DagNode getResultPattern(unsigned index) const;
|
||||
|
||||
// Checks whether an argument with the given `name` is bound in source
|
||||
// pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
|
||||
|
||||
// Returns a reference to all the bound arguments in the source pattern.
|
||||
llvm::StringMap<DagArg> &getSourcePatternBoundArgs();
|
||||
|
||||
// Returns the op that the root node of the source pattern matches.
|
||||
const Operator &getSourceRootOp();
|
||||
|
||||
// Returns the operator wrapper object corresponding to the given `node`'s DAG
|
||||
// operator.
|
||||
Operator &getDialectOp(DagNode node);
|
||||
|
||||
private:
|
||||
// The TableGen definition of this pattern.
|
||||
const llvm::Record &def;
|
||||
|
||||
RecordOperatorMap *recordOpMap; // All operators
|
||||
llvm::StringMap<DagArg> boundArguments; // All bound arguments
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_PATTERN_H_
|
|
@ -53,6 +53,14 @@ std::string tblgen::Operator::qualifiedCppClassName() const {
|
|||
return llvm::join(getSplitDefName(), "::");
|
||||
}
|
||||
|
||||
int tblgen::Operator::getNumNativeAttributes() const {
|
||||
return derivedAttrStart - nativeAttrStart;
|
||||
}
|
||||
|
||||
const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
|
||||
return attributes[index];
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getArgName(int index) const {
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
return argumentValues->getArgName(index)->getValue();
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
//===- Pattern.cpp - Pattern wrapper class ----------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
|
||||
// Pattern.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Pattern.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
bool tblgen::DagArg::isAttr() const {
|
||||
return arg.is<tblgen::NamedAttribute *>();
|
||||
}
|
||||
|
||||
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
|
||||
llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
return mapper->try_emplace(opDef, opDef).first->second;
|
||||
}
|
||||
|
||||
unsigned tblgen::DagNode::getNumOps() const {
|
||||
unsigned count = isReplaceWithValue() ? 0 : 1;
|
||||
for (unsigned i = 0, e = getNumArgs(); i != e; ++i) {
|
||||
if (auto child = getArgAsNestedDag(i))
|
||||
count += child.getNumOps();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
unsigned tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
|
||||
|
||||
bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
|
||||
return isa<llvm::DagInit>(node->getArg(index));
|
||||
}
|
||||
|
||||
tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
|
||||
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
|
||||
}
|
||||
|
||||
llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const {
|
||||
return dyn_cast<llvm::DefInit>(node->getArg(index));
|
||||
}
|
||||
|
||||
StringRef tblgen::DagNode::getArgName(unsigned index) const {
|
||||
return node->getArgNameStr(index);
|
||||
}
|
||||
|
||||
static void collectBoundArguments(const llvm::DagInit *tree,
|
||||
tblgen::Pattern *pattern) {
|
||||
auto &op = pattern->getDialectOp(tblgen::DagNode(tree));
|
||||
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto *arg = tree->getArg(i);
|
||||
|
||||
if (auto *argTree = dyn_cast<llvm::DagInit>(arg)) {
|
||||
collectBoundArguments(argTree, pattern);
|
||||
continue;
|
||||
}
|
||||
|
||||
StringRef name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
|
||||
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg);
|
||||
}
|
||||
}
|
||||
|
||||
void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const {
|
||||
::collectBoundArguments(node, pattern);
|
||||
}
|
||||
|
||||
bool tblgen::DagNode::isReplaceWithValue() const {
|
||||
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||
return dagOpDef->getName() == "replaceWithValue";
|
||||
}
|
||||
|
||||
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||
: def(*def), recordOpMap(mapper) {
|
||||
getSourcePattern().collectBoundArguments(this);
|
||||
}
|
||||
|
||||
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
|
||||
return tblgen::DagNode(def.getValueAsDag("PatternToMatch"));
|
||||
}
|
||||
|
||||
unsigned tblgen::Pattern::getNumResults() const {
|
||||
auto *results = def.getValueAsListInit("ResultOps");
|
||||
return results->size();
|
||||
}
|
||||
|
||||
tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
|
||||
auto *results = def.getValueAsListInit("ResultOps");
|
||||
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
void tblgen::Pattern::ensureArgBoundInSourcePattern(
|
||||
llvm::StringRef name) const {
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(def.getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
}
|
||||
|
||||
llvm::StringMap<tblgen::DagArg> &tblgen::Pattern::getSourcePatternBoundArgs() {
|
||||
return boundArguments;
|
||||
}
|
||||
|
||||
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
|
||||
return getSourcePattern().getDialectOp(recordOpMap);
|
||||
}
|
||||
|
||||
tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
|
||||
return node.getDialectOp(recordOpMap);
|
||||
}
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Pattern.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "mlir/TableGen/Type.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -40,9 +41,12 @@ using namespace mlir;
|
|||
|
||||
using mlir::tblgen::Argument;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::DagNode;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::Operand;
|
||||
using mlir::tblgen::Operator;
|
||||
using mlir::tblgen::Pattern;
|
||||
using mlir::tblgen::RecordOperatorMap;
|
||||
using mlir::tblgen::Type;
|
||||
|
||||
namespace {
|
||||
|
@ -62,102 +66,65 @@ struct DagArg {
|
|||
bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
|
||||
|
||||
namespace {
|
||||
class Pattern {
|
||||
class PatternEmitter {
|
||||
public:
|
||||
static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
|
||||
static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
|
||||
raw_ostream &os);
|
||||
|
||||
private:
|
||||
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
|
||||
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
|
||||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}
|
||||
|
||||
// Emits the rewrite pattern named `rewriteName`.
|
||||
// Emits the mlir::RewritePattern struct named `rewriteName`.
|
||||
void emit(StringRef rewriteName);
|
||||
|
||||
// Emits the matcher.
|
||||
void emitMatcher(DagInit *tree);
|
||||
// Emits the match() method.
|
||||
void emitMatchMethod(DagNode tree);
|
||||
|
||||
// Emits the rewrite() method.
|
||||
void emitRewriteMethod();
|
||||
|
||||
// Emits the C++ statement to replace the matched DAG with an existing value.
|
||||
void emitReplaceWithExistingValue(DagInit *resultTree);
|
||||
void emitReplaceWithExistingValue(DagNode resultTree);
|
||||
// Emits the C++ statement to replace the matched DAG with a new op.
|
||||
void emitReplaceOpWithNewOp(DagInit *resultTree);
|
||||
void emitReplaceOpWithNewOp(DagNode resultTree);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(Record *constAttr);
|
||||
|
||||
// Collects bound arguments.
|
||||
void collectBoundArguments(DagInit *tree);
|
||||
// Emits C++ statements for matching the op constrained by the given DAG
|
||||
// `tree`.
|
||||
void emitOpMatch(DagNode tree, int depth);
|
||||
|
||||
// Checks whether an argument with the given `name` is bound in source
|
||||
// pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void checkArgumentBound(StringRef name) const;
|
||||
|
||||
// Helper function to match patterns.
|
||||
void matchOp(DagInit *tree, int depth);
|
||||
|
||||
// Returns the Operator stored for the given record.
|
||||
Operator &getOperator(const llvm::Record *record);
|
||||
|
||||
// Map from bound argument name to DagArg.
|
||||
StringMap<DagArg> boundArguments;
|
||||
|
||||
// Map from Record* to Operator.
|
||||
DenseMap<const llvm::Record *, Operator> opMap;
|
||||
|
||||
// Number of the operations in the input pattern.
|
||||
int numberOfOpsMatched = 0;
|
||||
|
||||
Record *pattern;
|
||||
private:
|
||||
// Pattern instantiation location followed by the location of multiclass
|
||||
// prototypes used. This is intended to be used as a whole to
|
||||
// PrintFatalError() on errors.
|
||||
ArrayRef<llvm::SMLoc> loc;
|
||||
// Op's TableGen Record to wrapper object
|
||||
RecordOperatorMap *opMap;
|
||||
// Handy wrapper for pattern being emitted
|
||||
Pattern pattern;
|
||||
raw_ostream &os;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
// Returns the Operator stored for the given record.
|
||||
auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
|
||||
return opMap.try_emplace(record, record).first->second;
|
||||
}
|
||||
|
||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||
void PatternEmitter::emitAttributeValue(Record *constAttr) {
|
||||
Attribute attr(constAttr->getValueAsDef("attr"));
|
||||
auto value = constAttr->getValue("value");
|
||||
|
||||
if (!attr.isConstBuildable())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"Attribute " + attr.getTableGenDefName() +
|
||||
" does not have the 'constBuilderCall' field");
|
||||
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
|
||||
" does not have the 'constBuilderCall' field");
|
||||
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
value->getValue()->getAsUnquotedString());
|
||||
}
|
||||
|
||||
void Pattern::collectBoundArguments(DagInit *tree) {
|
||||
++numberOfOpsMatched;
|
||||
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
collectBoundArguments(argTree);
|
||||
continue;
|
||||
}
|
||||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
boundArguments.try_emplace(name, op.getArg(i), arg);
|
||||
}
|
||||
}
|
||||
|
||||
void Pattern::checkArgumentBound(StringRef name) const {
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void Pattern::matchOp(DagInit *tree, int depth) {
|
||||
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
|
||||
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
int indent = 4 + 2 * depth;
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
|
@ -167,27 +134,25 @@ void Pattern::matchOp(DagInit *tree, int depth) {
|
|||
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
|
||||
op.qualifiedCppClassName());
|
||||
}
|
||||
if (tree->getNumArgs() != op.getNumArgs())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("mismatch in number of arguments to op '") +
|
||||
op.getOperationName() +
|
||||
"' in pattern and op's definition");
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
if (tree.getNumArgs() != op.getNumArgs())
|
||||
PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
|
||||
op.getOperationName() +
|
||||
"' in pattern and op's definition");
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
os.indent(indent) << "{\n";
|
||||
os.indent(indent + 2) << formatv(
|
||||
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
|
||||
depth + 1, depth, i);
|
||||
matchOp(argTree, depth + 1);
|
||||
emitOpMatch(argTree, depth + 1);
|
||||
os.indent(indent) << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Verify arguments.
|
||||
if (auto defInit = dyn_cast<DefInit>(arg)) {
|
||||
if (auto defInit = tree.getArgAsDefInit(i)) {
|
||||
// Verify operands.
|
||||
if (auto *operand = opArg.dyn_cast<Operand *>()) {
|
||||
// Skip verification where not needed due to definition of op.
|
||||
|
@ -195,8 +160,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
|
|||
goto StateCapture;
|
||||
|
||||
if (!defInit->getDef()->isSubClassOf("Type"))
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"type argument required for operand");
|
||||
PrintFatalError(loc, "type argument required for operand");
|
||||
|
||||
auto constraint = tblgen::TypeConstraint(*defInit);
|
||||
os.indent(indent)
|
||||
|
@ -219,7 +183,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
|
|||
}
|
||||
|
||||
StateCapture:
|
||||
auto name = tree->getArgNameStr(i);
|
||||
auto name = tree.getArgName(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
if (opArg.is<Operand *>())
|
||||
|
@ -234,7 +198,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
|
|||
}
|
||||
}
|
||||
|
||||
void Pattern::emitMatcher(DagInit *tree) {
|
||||
void PatternEmitter::emitMatchMethod(DagNode tree) {
|
||||
// Emit the heading.
|
||||
os << R"(
|
||||
PatternMatchResult match(OperationInst *op0) const override {
|
||||
|
@ -242,28 +206,30 @@ void Pattern::emitMatcher(DagInit *tree) {
|
|||
if (op0->getNumResults() != 1) return matchFailure();
|
||||
auto state = std::make_unique<MatchedState>();)"
|
||||
<< "\n";
|
||||
matchOp(tree, 0);
|
||||
emitOpMatch(tree, 0);
|
||||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName) {
|
||||
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
|
||||
// Collect bound arguments and compute number of ops matched.
|
||||
void PatternEmitter::emit(StringRef rewriteName) {
|
||||
// Get the DAG tree for the source pattern
|
||||
DagNode tree = pattern.getSourcePattern();
|
||||
|
||||
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
|
||||
// moment, revise.
|
||||
collectBoundArguments(tree);
|
||||
unsigned benefit = tree.getNumOps();
|
||||
|
||||
const Operator &rootOp = pattern.getSourceRootOp();
|
||||
auto rootName = rootOp.getOperationName();
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
DefInit *root = cast<DefInit>(tree->getOperator());
|
||||
auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
|
||||
os << formatv(R"(struct {0} : public RewritePattern {
|
||||
{0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
|
||||
rewriteName, rootName->getAsString(), numberOfOpsMatched)
|
||||
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
|
||||
rewriteName, rootName, benefit)
|
||||
<< "\n";
|
||||
|
||||
// Emit matched state.
|
||||
os << " struct MatchedState : public PatternState {\n";
|
||||
for (auto &arg : boundArguments) {
|
||||
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
|
||||
if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
|
||||
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
|
||||
<< ";\n";
|
||||
|
@ -273,23 +239,22 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
}
|
||||
os << " };\n";
|
||||
|
||||
emitMatcher(tree);
|
||||
emitMatchMethod(tree);
|
||||
emitRewriteMethod();
|
||||
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
void Pattern::emitRewriteMethod() {
|
||||
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
|
||||
if (resultOps->size() != 1)
|
||||
void PatternEmitter::emitRewriteMethod() {
|
||||
if (pattern.getNumResults() != 1)
|
||||
PrintFatalError("only single result rules supported");
|
||||
DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
|
||||
|
||||
DagNode resultTree = pattern.getResultPattern(0);
|
||||
|
||||
// TODO(jpienaar): Expand to multiple results.
|
||||
for (auto result : resultTree->getArgs()) {
|
||||
if (isa<DagInit>(result))
|
||||
PrintFatalError(pattern->getLoc(), "only single op result supported");
|
||||
}
|
||||
for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i)
|
||||
if (resultTree.getArgAsNestedDag(i))
|
||||
PrintFatalError(loc, "only single op result supported");
|
||||
|
||||
os << R"(
|
||||
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
|
||||
|
@ -297,8 +262,7 @@ void Pattern::emitRewriteMethod() {
|
|||
auto& s = *static_cast<MatchedState *>(state.get());
|
||||
)";
|
||||
|
||||
auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
|
||||
if (dagOpDef->getName() == "replaceWithValue")
|
||||
if (resultTree.isReplaceWithValue())
|
||||
emitReplaceWithExistingValue(resultTree);
|
||||
else
|
||||
emitReplaceOpWithNewOp(resultTree);
|
||||
|
@ -306,31 +270,29 @@ void Pattern::emitRewriteMethod() {
|
|||
os << " }\n";
|
||||
}
|
||||
|
||||
void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
|
||||
if (resultTree->getNumArgs() != 1) {
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"exactly one argument needed in the result pattern");
|
||||
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
|
||||
if (resultTree.getNumArgs() != 1) {
|
||||
PrintFatalError(loc, "exactly one argument needed in the result pattern");
|
||||
}
|
||||
|
||||
auto name = resultTree->getArgNameStr(0);
|
||||
checkArgumentBound(name);
|
||||
auto name = resultTree.getArgName(0);
|
||||
pattern.ensureArgBoundInSourcePattern(name);
|
||||
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
|
||||
}
|
||||
|
||||
void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
|
||||
DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
|
||||
Operator &resultOp = getOperator(dagOperator->getDef());
|
||||
auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
|
||||
void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
|
||||
Operator &resultOp = resultTree.getDialectOp(opMap);
|
||||
auto numOpArgs =
|
||||
resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
|
||||
|
||||
os << formatv(R"(
|
||||
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
||||
resultOp.cppClassName());
|
||||
if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("mismatch between arguments of resultant op (") +
|
||||
Twine(resultOperands->getNumArgs()) +
|
||||
") and arguments provided for rewrite (" +
|
||||
Twine(resultTree->getNumArgs()) + Twine(')'));
|
||||
if (numOpArgs != resultTree.getNumArgs()) {
|
||||
PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
|
||||
Twine(numOpArgs) +
|
||||
") and arguments provided for rewrite (" +
|
||||
Twine(resultTree.getNumArgs()) + Twine(')'));
|
||||
}
|
||||
|
||||
// Create the builder call for the result.
|
||||
|
@ -340,8 +302,8 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
|
|||
// Start each operand on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
|
||||
auto name = resultTree->getArgNameStr(i);
|
||||
checkArgumentBound(name);
|
||||
auto name = resultTree.getArgName(i);
|
||||
pattern.ensureArgBoundInSourcePattern(name);
|
||||
if (operand.name)
|
||||
os << "/*" << operand.name->getAsUnquotedString() << "=*/";
|
||||
os << "s." << name;
|
||||
|
@ -350,18 +312,18 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
|
|||
}
|
||||
|
||||
// Add attributes.
|
||||
for (int e = resultTree->getNumArgs(); i != e; ++i) {
|
||||
for (int e = resultTree.getNumArgs(); i != e; ++i) {
|
||||
// Start each attribute on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
|
||||
// The argument in the result DAG pattern.
|
||||
auto name = resultTree->getArgNameStr(i);
|
||||
auto argName = resultTree.getArgName(i);
|
||||
auto opName = resultOp.getArgName(i);
|
||||
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
||||
auto *defInit = resultTree.getArgAsDefInit(i);
|
||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||
if (!value) {
|
||||
checkArgumentBound(name);
|
||||
auto result = "s." + name;
|
||||
pattern.ensureArgBoundInSourcePattern(argName);
|
||||
auto result = "s." + argName;
|
||||
os << "/*" << opName << "=*/";
|
||||
if (defInit) {
|
||||
auto transform = defInit->getDef();
|
||||
|
@ -380,31 +342,34 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
|
|||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("expected attribute ") + Twine(i));
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
|
||||
if (!name.empty())
|
||||
os << "/*" << name << "=*/";
|
||||
if (!argName.empty())
|
||||
os << "/*" << argName << "=*/";
|
||||
emitAttributeValue(defInit->getDef());
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
os << "\n );\n";
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
|
||||
Pattern pattern(p, os);
|
||||
pattern.emit(rewriteName);
|
||||
void PatternEmitter::emit(StringRef rewriteName, Record *p,
|
||||
RecordOperatorMap *mapper, raw_ostream &os) {
|
||||
PatternEmitter(p, mapper, os).emit(rewriteName);
|
||||
}
|
||||
|
||||
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Rewriters", os);
|
||||
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
|
||||
|
||||
// We put the map here because it can be shared among multiple patterns.
|
||||
RecordOperatorMap recordOpMap;
|
||||
|
||||
// Ensure unique patterns simply by appending unique suffix.
|
||||
std::string baseRewriteName = "GeneratedConvert";
|
||||
int rewritePatternCount = 0;
|
||||
for (Record *p : patterns) {
|
||||
Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
|
||||
PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
|
||||
p, &recordOpMap, os);
|
||||
}
|
||||
|
||||
// Emit function to add the generated matchers to the pattern list.
|
||||
|
|
Loading…
Reference in New Issue