Add tblgen::Pattern to model Patterns defined in TableGen

Similar to other tblgen:: abstractions, tblgen::Pattern hides the native TableGen
API and provides a nicer API that is more coherent with the TableGen definitions.

PiperOrigin-RevId: 231285143
This commit is contained in:
Lei Zhang 2019-01-28 14:04:40 -08:00 committed by jpienaar
parent 0fbf4ff232
commit eb753f4aec
5 changed files with 401 additions and 133 deletions

View File

@ -68,10 +68,10 @@ public:
// Op attribute accessors.
int getNumAttributes() const { return attributes.size(); }
// Returns the total number of native attributes.
int getNumNativeAttributes() const;
NamedAttribute &getAttribute(int index) { return attributes[index]; }
const NamedAttribute &getAttribute(int index) const {
return attributes[index];
}
const NamedAttribute &getAttribute(int index) const;
// Op operand iterators.
using operand_iterator = Operand *;
@ -87,6 +87,7 @@ public:
// Op argument (attribute or operand) accessors.
Argument getArg(int index);
StringRef getArgName(int index) const;
// Returns the total number of arguments.
int getNumArgs() const { return operands.size() + attributes.size(); }
// Query functions for the documentation of the operator.

View File

@ -0,0 +1,161 @@
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_PATTERN_H_
#define MLIR_TABLEGEN_PATTERN_H_
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/TableGen/Error.h"
namespace llvm {
class Record;
class Init;
class DagInit;
class StringRef;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Mapping from TableGen Record to Operator wrapper object
using RecordOperatorMap = llvm::DenseMap<const llvm::Record *, Operator>;
// Wrapper around DAG argument.
struct DagArg {
DagArg(Argument arg, llvm::Init *constraint)
: arg(arg), constraint(constraint) {}
// Returns true if this DAG argument concerns an operation attribute.
bool isAttr() const;
Argument arg;
llvm::Init *constraint;
};
class Pattern;
// Wrapper class providing helper methods for accessing TableGen DAG constructs
// used inside Patterns. This class is lightweight and designed to be used like
// values.
//
// A TableGen DAG construct is of the syntax
// `(operator, arg0, arg1, ...)`.
//
// When used inside Patterns, `operator` corresponds to some dialect op, or
// a known list of verbs that defines special transformation actions. This
// `arg*` can be a nested DAG construct. This class provides getters to
// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
// methods.
//
// A null DagNode contains a nullptr and converts to false implicitly.
class DagNode {
public:
explicit DagNode(const llvm::DagInit *node) : node(node) {}
// Implicit bool converter that returns true if this DagNode is not a null
// DagNode.
operator bool() const { return node != nullptr; }
// Returns the operator wrapper object corresponding to the dialect op matched
// by this DAG. The operator wrapper will be queried from the given `mapper`
// and created in it if not existing.
Operator &getDialectOp(RecordOperatorMap *mapper) const;
// Returns the number of operations recursively involved in the DAG tree
// rooted from this node.
unsigned getNumOps() const;
// Returns the number of immediate arguments to this DAG node.
unsigned getNumArgs() const;
// Returns true if the `index`-th argument is a nested DAG construct.
bool isNestedDagArg(unsigned index) const;
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
// null DagNode otherwise.
DagNode getArgAsNestedDag(unsigned index) const;
// Gets the `index`-th argument as a TableGen DefInit* if possible. Returns
// nullptr otherwise.
// TODO: This method is exposing raw TableGen object and should be changed.
llvm::DefInit *getArgAsDefInit(unsigned index) const;
// Returns the specified name of the `index`-th argument.
llvm::StringRef getArgName(unsigned index) const;
// Collects all recursively bound arguments involved in the DAG tree rooted
// from this node.
void collectBoundArguments(Pattern *pattern) const;
// Returns true if this DAG construct means to replace with an existing SSA
// value.
bool isReplaceWithValue() const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
};
// Wrapper class providing helper methods for accessing MLIR Pattern defined
// in TableGen. This class should closely reflect what is defined as class
// `Pattern` in TableGen. This class contains maps so it is not intended to be
// used as values.
class Pattern {
public:
explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
// Returns the source pattern to match.
DagNode getSourcePattern() const;
// Returns the number of results generated by applying this rewrite pattern.
unsigned getNumResults() const;
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Checks whether an argument with the given `name` is bound in source
// pattern. Prints fatal error if not; does nothing otherwise.
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
// Returns a reference to all the bound arguments in the source pattern.
llvm::StringMap<DagArg> &getSourcePatternBoundArgs();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
// Returns the operator wrapper object corresponding to the given `node`'s DAG
// operator.
Operator &getDialectOp(DagNode node);
private:
// The TableGen definition of this pattern.
const llvm::Record &def;
RecordOperatorMap *recordOpMap; // All operators
llvm::StringMap<DagArg> boundArguments; // All bound arguments
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_PATTERN_H_

View File

@ -53,6 +53,14 @@ std::string tblgen::Operator::qualifiedCppClassName() const {
return llvm::join(getSplitDefName(), "::");
}
int tblgen::Operator::getNumNativeAttributes() const {
return derivedAttrStart - nativeAttrStart;
}
const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
return attributes[index];
}
StringRef tblgen::Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgName(index)->getValue();

View File

@ -0,0 +1,133 @@
//===- Pattern.cpp - Pattern wrapper class ----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/Twine.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using mlir::tblgen::Operator;
bool tblgen::DagArg::isAttr() const {
return arg.is<tblgen::NamedAttribute *>();
}
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return mapper->try_emplace(opDef, opDef).first->second;
}
unsigned tblgen::DagNode::getNumOps() const {
unsigned count = isReplaceWithValue() ? 0 : 1;
for (unsigned i = 0, e = getNumArgs(); i != e; ++i) {
if (auto child = getArgAsNestedDag(i))
count += child.getNumOps();
}
return count;
}
unsigned tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
return isa<llvm::DagInit>(node->getArg(index));
}
tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
}
llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const {
return dyn_cast<llvm::DefInit>(node->getArg(index));
}
StringRef tblgen::DagNode::getArgName(unsigned index) const {
return node->getArgNameStr(index);
}
static void collectBoundArguments(const llvm::DagInit *tree,
tblgen::Pattern *pattern) {
auto &op = pattern->getDialectOp(tblgen::DagNode(tree));
// TODO(jpienaar): Expand to multiple matches.
for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto *arg = tree->getArg(i);
if (auto *argTree = dyn_cast<llvm::DagInit>(arg)) {
collectBoundArguments(argTree, pattern);
continue;
}
StringRef name = tree->getArgNameStr(i);
if (name.empty())
continue;
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg);
}
}
void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const {
::collectBoundArguments(node, pattern);
}
bool tblgen::DagNode::isReplaceWithValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "replaceWithValue";
}
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {
getSourcePattern().collectBoundArguments(this);
}
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("PatternToMatch"));
}
unsigned tblgen::Pattern::getNumResults() const {
auto *results = def.getValueAsListInit("ResultOps");
return results->size();
}
tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
auto *results = def.getValueAsListInit("ResultOps");
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
void tblgen::Pattern::ensureArgBoundInSourcePattern(
llvm::StringRef name) const {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(def.getLoc(),
Twine("referencing unbound variable '") + name + "'");
}
llvm::StringMap<tblgen::DagArg> &tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
return getSourcePattern().getDialectOp(recordOpMap);
}
tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
return node.getDialectOp(recordOpMap);
}

View File

@ -22,6 +22,7 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
@ -40,9 +41,12 @@ using namespace mlir;
using mlir::tblgen::Argument;
using mlir::tblgen::Attribute;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::Operand;
using mlir::tblgen::Operator;
using mlir::tblgen::Pattern;
using mlir::tblgen::RecordOperatorMap;
using mlir::tblgen::Type;
namespace {
@ -62,102 +66,65 @@ struct DagArg {
bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
namespace {
class Pattern {
class PatternEmitter {
public:
static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
raw_ostream &os);
private:
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}
// Emits the rewrite pattern named `rewriteName`.
// Emits the mlir::RewritePattern struct named `rewriteName`.
void emit(StringRef rewriteName);
// Emits the matcher.
void emitMatcher(DagInit *tree);
// Emits the match() method.
void emitMatchMethod(DagNode tree);
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits the C++ statement to replace the matched DAG with an existing value.
void emitReplaceWithExistingValue(DagInit *resultTree);
void emitReplaceWithExistingValue(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with a new op.
void emitReplaceOpWithNewOp(DagInit *resultTree);
void emitReplaceOpWithNewOp(DagNode resultTree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(Record *constAttr);
// Collects bound arguments.
void collectBoundArguments(DagInit *tree);
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`.
void emitOpMatch(DagNode tree, int depth);
// Checks whether an argument with the given `name` is bound in source
// pattern. Prints fatal error if not; does nothing otherwise.
void checkArgumentBound(StringRef name) const;
// Helper function to match patterns.
void matchOp(DagInit *tree, int depth);
// Returns the Operator stored for the given record.
Operator &getOperator(const llvm::Record *record);
// Map from bound argument name to DagArg.
StringMap<DagArg> boundArguments;
// Map from Record* to Operator.
DenseMap<const llvm::Record *, Operator> opMap;
// Number of the operations in the input pattern.
int numberOfOpsMatched = 0;
Record *pattern;
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
// PrintFatalError() on errors.
ArrayRef<llvm::SMLoc> loc;
// Op's TableGen Record to wrapper object
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
Pattern pattern;
raw_ostream &os;
};
} // end namespace
// Returns the Operator stored for the given record.
auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
return opMap.try_emplace(record, record).first->second;
}
void Pattern::emitAttributeValue(Record *constAttr) {
void PatternEmitter::emitAttributeValue(Record *constAttr) {
Attribute attr(constAttr->getValueAsDef("attr"));
auto value = constAttr->getValue("value");
if (!attr.isConstBuildable())
PrintFatalError(pattern->getLoc(),
"Attribute " + attr.getTableGenDefName() +
" does not have the 'constBuilderCall' field");
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
value->getValue()->getAsUnquotedString());
}
void Pattern::collectBoundArguments(DagInit *tree) {
++numberOfOpsMatched;
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
// TODO(jpienaar): Expand to multiple matches.
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
if (auto argTree = dyn_cast<DagInit>(arg)) {
collectBoundArguments(argTree);
continue;
}
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
boundArguments.try_emplace(name, op.getArg(i), arg);
}
}
void Pattern::checkArgumentBound(StringRef name) const {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
}
// Helper function to match patterns.
void Pattern::matchOp(DagInit *tree, int depth) {
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
Operator &op = tree.getDialectOp(opMap);
int indent = 4 + 2 * depth;
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
@ -167,27 +134,25 @@ void Pattern::matchOp(DagInit *tree, int depth) {
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
op.qualifiedCppClassName());
}
if (tree->getNumArgs() != op.getNumArgs())
PrintFatalError(pattern->getLoc(),
Twine("mismatch in number of arguments to op '") +
op.getOperationName() +
"' in pattern and op's definition");
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
if (tree.getNumArgs() != op.getNumArgs())
PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
op.getOperationName() +
"' in pattern and op's definition");
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
if (auto argTree = dyn_cast<DagInit>(arg)) {
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
depth + 1, depth, i);
matchOp(argTree, depth + 1);
emitOpMatch(argTree, depth + 1);
os.indent(indent) << "}\n";
continue;
}
// Verify arguments.
if (auto defInit = dyn_cast<DefInit>(arg)) {
if (auto defInit = tree.getArgAsDefInit(i)) {
// Verify operands.
if (auto *operand = opArg.dyn_cast<Operand *>()) {
// Skip verification where not needed due to definition of op.
@ -195,8 +160,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
goto StateCapture;
if (!defInit->getDef()->isSubClassOf("Type"))
PrintFatalError(pattern->getLoc(),
"type argument required for operand");
PrintFatalError(loc, "type argument required for operand");
auto constraint = tblgen::TypeConstraint(*defInit);
os.indent(indent)
@ -219,7 +183,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
}
StateCapture:
auto name = tree->getArgNameStr(i);
auto name = tree.getArgName(i);
if (name.empty())
continue;
if (opArg.is<Operand *>())
@ -234,7 +198,7 @@ void Pattern::matchOp(DagInit *tree, int depth) {
}
}
void Pattern::emitMatcher(DagInit *tree) {
void PatternEmitter::emitMatchMethod(DagNode tree) {
// Emit the heading.
os << R"(
PatternMatchResult match(OperationInst *op0) const override {
@ -242,28 +206,30 @@ void Pattern::emitMatcher(DagInit *tree) {
if (op0->getNumResults() != 1) return matchFailure();
auto state = std::make_unique<MatchedState>();)"
<< "\n";
matchOp(tree, 0);
emitOpMatch(tree, 0);
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
void Pattern::emit(StringRef rewriteName) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
// Collect bound arguments and compute number of ops matched.
void PatternEmitter::emit(StringRef rewriteName) {
// Get the DAG tree for the source pattern
DagNode tree = pattern.getSourcePattern();
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
// moment, revise.
collectBoundArguments(tree);
unsigned benefit = tree.getNumOps();
const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
// Emit RewritePattern for Pattern.
DefInit *root = cast<DefInit>(tree->getOperator());
auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
rewriteName, rootName->getAsString(), numberOfOpsMatched)
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
rewriteName, rootName, benefit)
<< "\n";
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";
for (auto &arg : boundArguments) {
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
<< ";\n";
@ -273,23 +239,22 @@ void Pattern::emit(StringRef rewriteName) {
}
os << " };\n";
emitMatcher(tree);
emitMatchMethod(tree);
emitRewriteMethod();
os << "};\n";
}
void Pattern::emitRewriteMethod() {
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
if (resultOps->size() != 1)
void PatternEmitter::emitRewriteMethod() {
if (pattern.getNumResults() != 1)
PrintFatalError("only single result rules supported");
DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
DagNode resultTree = pattern.getResultPattern(0);
// TODO(jpienaar): Expand to multiple results.
for (auto result : resultTree->getArgs()) {
if (isa<DagInit>(result))
PrintFatalError(pattern->getLoc(), "only single op result supported");
}
for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i)
if (resultTree.getArgAsNestedDag(i))
PrintFatalError(loc, "only single op result supported");
os << R"(
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
@ -297,8 +262,7 @@ void Pattern::emitRewriteMethod() {
auto& s = *static_cast<MatchedState *>(state.get());
)";
auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
if (dagOpDef->getName() == "replaceWithValue")
if (resultTree.isReplaceWithValue())
emitReplaceWithExistingValue(resultTree);
else
emitReplaceOpWithNewOp(resultTree);
@ -306,31 +270,29 @@ void Pattern::emitRewriteMethod() {
os << " }\n";
}
void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
if (resultTree->getNumArgs() != 1) {
PrintFatalError(pattern->getLoc(),
"exactly one argument needed in the result pattern");
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
if (resultTree.getNumArgs() != 1) {
PrintFatalError(loc, "exactly one argument needed in the result pattern");
}
auto name = resultTree->getArgNameStr(0);
checkArgumentBound(name);
auto name = resultTree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
}
void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
Operator &resultOp = getOperator(dagOperator->getDef());
auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
Operator &resultOp = resultTree.getDialectOp(opMap);
auto numOpArgs =
resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
os << formatv(R"(
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.cppClassName());
if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
PrintFatalError(pattern->getLoc(),
Twine("mismatch between arguments of resultant op (") +
Twine(resultOperands->getNumArgs()) +
") and arguments provided for rewrite (" +
Twine(resultTree->getNumArgs()) + Twine(')'));
if (numOpArgs != resultTree.getNumArgs()) {
PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
Twine(numOpArgs) +
") and arguments provided for rewrite (" +
Twine(resultTree.getNumArgs()) + Twine(')'));
}
// Create the builder call for the result.
@ -340,8 +302,8 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
// Start each operand on its own line.
(os << ",\n").indent(6);
auto name = resultTree->getArgNameStr(i);
checkArgumentBound(name);
auto name = resultTree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
if (operand.name)
os << "/*" << operand.name->getAsUnquotedString() << "=*/";
os << "s." << name;
@ -350,18 +312,18 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
}
// Add attributes.
for (int e = resultTree->getNumArgs(); i != e; ++i) {
for (int e = resultTree.getNumArgs(); i != e; ++i) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
// The argument in the result DAG pattern.
auto name = resultTree->getArgNameStr(i);
auto argName = resultTree.getArgName(i);
auto opName = resultOp.getArgName(i);
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *defInit = resultTree.getArgAsDefInit(i);
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value) {
checkArgumentBound(name);
auto result = "s." + name;
pattern.ensureArgBoundInSourcePattern(argName);
auto result = "s." + argName;
os << "/*" << opName << "=*/";
if (defInit) {
auto transform = defInit->getDef();
@ -380,31 +342,34 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())
PrintFatalError(pattern->getLoc(),
Twine("expected attribute ") + Twine(i));
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
if (!name.empty())
os << "/*" << name << "=*/";
if (!argName.empty())
os << "/*" << argName << "=*/";
emitAttributeValue(defInit->getDef());
// TODO(jpienaar): verify types
}
os << "\n );\n";
}
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
Pattern pattern(p, os);
pattern.emit(rewriteName);
void PatternEmitter::emit(StringRef rewriteName, Record *p,
RecordOperatorMap *mapper, raw_ostream &os) {
PatternEmitter(p, mapper, os).emit(rewriteName);
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
// We put the map here because it can be shared among multiple patterns.
RecordOperatorMap recordOpMap;
// Ensure unique patterns simply by appending unique suffix.
std::string baseRewriteName = "GeneratedConvert";
int rewritePatternCount = 0;
for (Record *p : patterns) {
Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
p, &recordOpMap, os);
}
// Emit function to add the generated matchers to the pattern list.