2019-02-18 23:21:12 +08:00
|
|
|
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
|
2018-12-12 19:09:11 +08:00
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
//
|
2018-12-27 20:56:03 +08:00
|
|
|
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
|
2018-12-12 19:09:11 +08:00
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-02-14 06:30:40 +08:00
|
|
|
#include "mlir/Support/STLExtras.h"
|
2019-01-17 01:23:14 +08:00
|
|
|
#include "mlir/TableGen/Attribute.h"
|
2019-04-12 21:05:49 +08:00
|
|
|
#include "mlir/TableGen/Format.h"
|
2018-12-27 20:56:03 +08:00
|
|
|
#include "mlir/TableGen/GenInfo.h"
|
2018-12-29 04:02:08 +08:00
|
|
|
#include "mlir/TableGen/Operator.h"
|
2019-01-29 06:04:40 +08:00
|
|
|
#include "mlir/TableGen/Pattern.h"
|
2019-01-06 00:11:29 +08:00
|
|
|
#include "mlir/TableGen/Predicate.h"
|
2019-01-09 09:19:22 +08:00
|
|
|
#include "mlir/TableGen/Type.h"
|
2018-12-12 19:09:11 +08:00
|
|
|
#include "llvm/ADT/StringExtras.h"
|
2018-12-29 04:02:08 +08:00
|
|
|
#include "llvm/ADT/StringSet.h"
|
2018-12-12 19:09:11 +08:00
|
|
|
#include "llvm/Support/CommandLine.h"
|
|
|
|
#include "llvm/Support/PrettyStackTrace.h"
|
|
|
|
#include "llvm/Support/Signals.h"
|
|
|
|
#include "llvm/TableGen/Error.h"
|
|
|
|
#include "llvm/TableGen/Main.h"
|
|
|
|
#include "llvm/TableGen/Record.h"
|
|
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
|
|
|
|
using namespace llvm;
|
2018-12-29 04:02:08 +08:00
|
|
|
using namespace mlir;
|
[TableGen] Consolidate constraint related concepts
Previously we have multiple mechanisms to specify op definition and match constraints:
TypeConstraint, AttributeConstraint, Type, Attr, mAttr, mAttrAnyOf, mPat. These variants
are not added because there are so many distinct cases we need to model; essentially,
they are all carrying a predicate. It's just an artifact of implementation.
It's quite confusing for users to grasp these variants and choose among them. Instead,
as the OpBase TableGen file, we need to strive to provide a unified mechanism. Each
dialect has the flexibility to define its own aliases if wanted.
This CL removes mAttr, mAttrAnyOf, mPat. A new base class, Constraint, is added. Now
TypeConstraint and AttrConstraint derive from Constraint. Type and Attr further derive
from TypeConstraint and AttrConstraint, respectively.
Comments are revised and examples are added to make it clear how to use constraints.
PiperOrigin-RevId: 240125076
2019-03-25 21:09:26 +08:00
|
|
|
using namespace mlir::tblgen;
|
2018-12-29 23:55:08 +08:00
|
|
|
|
2019-04-23 04:40:30 +08:00
|
|
|
// Returns the bound symbol for the given op argument or op named `symbol`.
|
2019-04-04 20:44:58 +08:00
|
|
|
//
|
2019-04-23 04:40:30 +08:00
|
|
|
// Arguments and ops bound in the source pattern are grouped into a
|
|
|
|
// transient `PatternState` struct. This struct can be accessed in both
|
|
|
|
// `match()` and `rewrite()` via the local variable named as `s`.
|
|
|
|
static Twine getBoundSymbol(const StringRef &symbol) {
|
2019-04-04 20:44:58 +08:00
|
|
|
return Twine("s.") + symbol;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PatternSymbolResolver
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// A class for resolving symbols bound in patterns.
|
|
|
|
//
|
2019-04-23 05:56:54 +08:00
|
|
|
// Symbols can be bound to op arguments and ops in the source pattern and ops
|
|
|
|
// in result patterns. For example, in
|
2019-04-04 20:44:58 +08:00
|
|
|
//
|
|
|
|
// ```
|
|
|
|
// def : Pattern<(SrcOp:$op1 $arg0, %arg1),
|
|
|
|
// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
|
|
|
|
// ```
|
|
|
|
//
|
|
|
|
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
|
|
|
|
// `$op2` is bound to `ResOp1`.
|
|
|
|
//
|
|
|
|
// This class keeps track of such symbols and translates them into their bound
|
|
|
|
// values.
|
|
|
|
//
|
|
|
|
// Note that we also generate local variables for unnamed DAG nodes, like
|
2019-04-23 05:56:54 +08:00
|
|
|
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
|
2019-04-04 20:44:58 +08:00
|
|
|
// generated local variable will be implicitly named. Those implicit names are
|
|
|
|
// not tracked in this class.
|
|
|
|
class PatternSymbolResolver {
|
|
|
|
public:
|
|
|
|
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
2019-04-23 04:40:30 +08:00
|
|
|
const StringSet<> &srcOperations);
|
2019-04-04 20:44:58 +08:00
|
|
|
|
|
|
|
// Marks the given `symbol` as bound. Returns false if the `symbol` is
|
|
|
|
// already bound.
|
|
|
|
bool add(StringRef symbol);
|
|
|
|
|
|
|
|
// Queries the substitution for the given `symbol`.
|
|
|
|
std::string query(StringRef symbol) const;
|
|
|
|
|
|
|
|
private:
|
|
|
|
// Symbols bound to arguments in source pattern.
|
|
|
|
const StringMap<Argument> &sourceArguments;
|
|
|
|
// Symbols bound to ops (for their results) in source pattern.
|
|
|
|
const StringSet<> &sourceOps;
|
|
|
|
// Symbols bound to ops (for their results) in result patterns.
|
|
|
|
StringSet<> resultOps;
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
// Stores references to the source-pattern symbol tables; both tables must
// outlive this resolver since only references are kept.
PatternSymbolResolver::PatternSymbolResolver(
    const StringMap<Argument> &srcArgs, const StringSet<> &srcOperations)
    : sourceArguments(srcArgs),
      sourceOps(srcOperations) {}
|
2019-04-04 20:44:58 +08:00
|
|
|
|
|
|
|
// Records `symbol` as bound in a result pattern. The `.second` member of the
// insertion result tells whether the key was newly added, so a repeated
// symbol yields false.
bool PatternSymbolResolver::add(StringRef symbol) {
  const auto insertion = resultOps.insert(symbol);
  return insertion.second;
}
|
|
|
|
|
|
|
|
// Looks `symbol` up, giving result-pattern ops precedence over anything bound
// in the source pattern. Returns an empty string for unknown symbols.
std::string PatternSymbolResolver::query(StringRef symbol) const {
  // Ops created by result patterns are assigned to local variables named
  // after their symbol, so the symbol itself is the substitution.
  if (resultOps.count(symbol))
    return symbol.str();

  // Source-pattern arguments and ops live in the matched state `s`.
  const bool boundInSource =
      sourceArguments.count(symbol) || sourceOps.count(symbol);
  if (boundInSource)
    return getBoundSymbol(symbol).str();

  return std::string();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PatternEmitter
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-12-29 04:02:08 +08:00
|
|
|
namespace {
|
2019-01-29 06:04:40 +08:00
|
|
|
class PatternEmitter {
|
2018-12-29 04:02:08 +08:00
|
|
|
public:
|
2019-01-29 06:04:40 +08:00
|
|
|
static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
|
|
|
|
raw_ostream &os);
|
2018-12-29 23:55:08 +08:00
|
|
|
|
|
|
|
private:
|
2019-02-09 22:36:23 +08:00
|
|
|
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
|
2018-12-29 23:55:08 +08:00
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Emits the mlir::RewritePattern struct named `rewriteName`.
|
2018-12-29 23:55:08 +08:00
|
|
|
void emit(StringRef rewriteName);
|
2018-12-29 04:02:08 +08:00
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Emits the match() method.
|
|
|
|
void emitMatchMethod(DagNode tree);
|
2018-12-29 04:02:08 +08:00
|
|
|
|
2019-01-26 02:09:15 +08:00
|
|
|
// Emits the rewrite() method.
|
|
|
|
void emitRewriteMethod();
|
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Emits C++ statements for matching the op constrained by the given DAG
|
|
|
|
// `tree`.
|
|
|
|
void emitOpMatch(DagNode tree, int depth);
|
2018-12-29 23:55:08 +08:00
|
|
|
|
2019-02-02 07:40:22 +08:00
|
|
|
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
|
|
|
// `tree` as an operand.
|
|
|
|
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
|
|
|
|
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
|
|
|
// `tree` as an attribute.
|
|
|
|
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
|
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
// Returns a unique name for an value of the given `op`.
|
|
|
|
std::string getUniqueValueName(const Operator *op);
|
|
|
|
|
|
|
|
// Entry point for handling a rewrite pattern rooted at `resultTree` and
|
|
|
|
// dispatches to concrete handlers. The given tree is the `resultIndex`-th
|
|
|
|
// argument of the enclosing DAG.
|
|
|
|
std::string handleRewritePattern(DagNode resultTree, int resultIndex,
|
2019-03-09 05:57:09 +08:00
|
|
|
int depth);
|
2019-02-09 22:36:23 +08:00
|
|
|
|
2019-04-23 05:13:45 +08:00
|
|
|
// Emits the C++ statement to replace the matched DAG with a value built via
|
|
|
|
// calling native C++ code.
|
|
|
|
std::string emitReplaceWithNativeCodeCall(DagNode resultTree);
|
2019-02-09 22:36:23 +08:00
|
|
|
|
|
|
|
// Returns the C++ expression referencing the old value serving as the
|
|
|
|
// replacement.
|
|
|
|
std::string handleReplaceWithValue(DagNode tree);
|
|
|
|
|
2019-03-09 05:56:53 +08:00
|
|
|
// Handles the `verifyUnusedValue` directive: emitting C++ statements to check
|
|
|
|
// the `index`-th result of the source op is not used.
|
|
|
|
void handleVerifyUnusedValue(DagNode tree, int index);
|
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
// Emits the C++ statement to build a new op out of the given DAG `tree` and
|
2019-03-09 05:57:09 +08:00
|
|
|
// returns the variable name that this op is assigned to. If the root op in
|
|
|
|
// DAG `tree` has a specified name, the created op will be assigned to a
|
|
|
|
// variable of the given name. Otherwise, a unique name will be used as the
|
|
|
|
// result value name.
|
|
|
|
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
|
2019-02-09 22:36:23 +08:00
|
|
|
|
2019-04-01 23:58:53 +08:00
|
|
|
// Returns the C++ expression to construct a constant attribute of the given
|
|
|
|
// `value` for the given attribute kind `attr`.
|
|
|
|
std::string handleConstantAttr(Attribute attr, StringRef value);
|
2019-03-13 04:55:50 +08:00
|
|
|
|
|
|
|
// Returns the C++ expression to build an argument from the given DAG `leaf`.
|
|
|
|
// `patArgName` is used to bound the argument to the source pattern.
|
|
|
|
std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
|
|
|
|
|
2019-04-04 20:44:58 +08:00
|
|
|
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
|
|
|
|
// is already bound.
|
|
|
|
void addSymbol(DagNode node);
|
|
|
|
|
|
|
|
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
|
|
|
|
std::string resolveSymbol(StringRef symbol);
|
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
private:
|
|
|
|
// Pattern instantiation location followed by the location of multiclass
|
|
|
|
// prototypes used. This is intended to be used as a whole to
|
|
|
|
// PrintFatalError() on errors.
|
|
|
|
ArrayRef<llvm::SMLoc> loc;
|
|
|
|
// Op's TableGen Record to wrapper object
|
|
|
|
RecordOperatorMap *opMap;
|
|
|
|
// Handy wrapper for pattern being emitted
|
[TableGen] Consolidate constraint related concepts
Previously we have multiple mechanisms to specify op definition and match constraints:
TypeConstraint, AttributeConstraint, Type, Attr, mAttr, mAttrAnyOf, mPat. These variants
are not added because there are so many distinct cases we need to model; essentially,
they are all carrying a predicate. It's just an artifact of implementation.
It's quite confusing for users to grasp these variants and choose among them. Instead,
as the OpBase TableGen file, we need to strike to provide an unified mechanism. Each
dialect has the flexibility to define its own aliases if wanted.
This CL removes mAttr, mAttrAnyOf, mPat. A new base class, Constraint, is added. Now
TypeConstraint and AttrConstraint derive from Constraint. Type and Attr further derive
from TypeConstraint and AttrConstraint, respectively.
Comments are revised and examples are added to make it clear how to use constraints.
PiperOrigin-RevId: 240125076
2019-03-25 21:09:26 +08:00
|
|
|
Pattern pattern;
|
2019-04-04 20:44:58 +08:00
|
|
|
PatternSymbolResolver symbolResolver;
|
2019-02-09 22:36:23 +08:00
|
|
|
// The next unused ID for newly created values
|
|
|
|
unsigned nextValueId;
|
2018-12-29 23:55:08 +08:00
|
|
|
raw_ostream &os;
|
2019-04-12 21:05:49 +08:00
|
|
|
|
|
|
|
// Format contexts containing placeholder substitutations for match().
|
|
|
|
FmtContext matchCtx;
|
|
|
|
// Format contexts containing placeholder substitutations for rewrite().
|
|
|
|
FmtContext rewriteCtx;
|
2018-12-29 04:02:08 +08:00
|
|
|
};
|
2019-04-04 20:44:58 +08:00
|
|
|
} // end anonymous namespace
|
2018-12-29 04:02:08 +08:00
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
// Wraps record `pat` and prepares the format contexts for match() and
// rewrite() code generation.
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      // `pattern` is initialized before `symbolResolver` (declaration order);
      // the resolver keeps references to the maps obtained from it.
      symbolResolver(pattern.getSourcePatternBoundArgs(),
                     pattern.getSourcePatternBoundOps()),
      nextValueId(0), os(os) {
  // In rewrite(), attributes are built through the PatternRewriter argument.
  rewriteCtx.withBuilder("rewriter");
  // In match(), a Builder is made from the matched op's context `ctx`.
  matchCtx.withBuilder("mlir::Builder(ctx)");
}
|
2019-02-09 22:36:23 +08:00
|
|
|
|
2019-04-01 23:58:53 +08:00
|
|
|
// Instantiates the constant-builder template of attribute kind `attr` with
// `value`, using `rewriter` as the builder. Aborts TableGen if the attribute
// kind defines no constant-builder template.
std::string PatternEmitter::handleConstantAttr(Attribute attr,
                                               StringRef value) {
  const bool canBuildConstant = attr.isConstBuildable();
  if (!canBuildConstant)
    PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
                             " does not have the 'constBuilderCall' field");

  // TODO(jpienaar): Verify the constants here
  auto &builderCtx = rewriteCtx.withBuilder("rewriter");
  return tgfmt(attr.getConstBuilderTemplate(), &builderCtx, value);
}
|
|
|
|
|
2018-12-29 23:55:08 +08:00
|
|
|
// Helper function to match patterns.
|
2019-01-29 06:04:40 +08:00
|
|
|
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|
|
|
Operator &op = tree.getDialectOp(opMap);
|
2018-12-29 23:55:08 +08:00
|
|
|
int indent = 4 + 2 * depth;
|
|
|
|
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
|
|
|
if (depth != 0) {
|
2019-03-28 23:24:38 +08:00
|
|
|
// Skip if there is no defining operation (e.g., arguments to function).
|
2018-12-29 23:55:08 +08:00
|
|
|
os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
|
|
|
|
os.indent(indent) << formatv(
|
2019-01-03 08:11:42 +08:00
|
|
|
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
|
2019-01-30 01:27:04 +08:00
|
|
|
op.getQualCppClassName());
|
2018-12-29 23:55:08 +08:00
|
|
|
}
|
2019-02-02 07:40:22 +08:00
|
|
|
if (tree.getNumArgs() != op.getNumArgs()) {
|
|
|
|
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
|
|
|
"pattern vs. {2} in definition",
|
|
|
|
op.getOperationName(), tree.getNumArgs(),
|
|
|
|
op.getNumArgs()));
|
|
|
|
}
|
|
|
|
|
2019-02-14 06:30:40 +08:00
|
|
|
// If the operand's name is set, set to that variable.
|
|
|
|
auto name = tree.getOpName();
|
|
|
|
if (!name.empty())
|
2019-04-23 04:40:30 +08:00
|
|
|
os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
|
2019-02-14 06:30:40 +08:00
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
2019-01-08 01:52:26 +08:00
|
|
|
auto opArg = op.getArg(i);
|
|
|
|
|
2019-02-02 07:40:22 +08:00
|
|
|
// Handle nested DAG construct first
|
2019-01-29 06:04:40 +08:00
|
|
|
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
2018-12-29 23:55:08 +08:00
|
|
|
os.indent(indent) << "{\n";
|
2019-03-27 08:05:09 +08:00
|
|
|
os.indent(indent + 2)
|
|
|
|
<< formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
|
|
|
|
depth + 1, depth, i);
|
2019-01-29 06:04:40 +08:00
|
|
|
emitOpMatch(argTree, depth + 1);
|
2018-12-29 23:55:08 +08:00
|
|
|
os.indent(indent) << "}\n";
|
|
|
|
continue;
|
|
|
|
}
|
2019-01-06 00:11:29 +08:00
|
|
|
|
2019-02-02 07:40:22 +08:00
|
|
|
// Next handle DAG leaf: operand or attribute
|
2019-05-11 04:49:22 +08:00
|
|
|
if (opArg.is<NamedTypeConstraint *>()) {
|
2019-02-02 07:40:22 +08:00
|
|
|
emitOperandMatch(tree, i, depth, indent);
|
2019-05-11 04:49:22 +08:00
|
|
|
} else if (opArg.is<NamedAttribute *>()) {
|
2019-02-02 07:40:22 +08:00
|
|
|
emitAttributeMatch(tree, i, depth, indent);
|
|
|
|
} else {
|
|
|
|
PrintFatalError(loc, "unhandled case when matching op");
|
2019-01-06 00:11:29 +08:00
|
|
|
}
|
2019-02-02 07:40:22 +08:00
|
|
|
}
|
|
|
|
}
|
2019-01-06 00:11:29 +08:00
|
|
|
|
2019-02-02 07:40:22 +08:00
|
|
|
// Emits the checks for the `index`-th argument of `tree`, which the op
// definition declares as an operand. If the pattern names the argument, the
// operand is also captured into the matched state.
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
                                      int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
  auto matcher = tree.getArgAsLeaf(index);

  // A leaf that specifies a constraint may require a C++ check.
  if (!matcher.isUnspecified()) {
    if (!matcher.isOperandMatcher())
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                       op.getOperationName(), index + 1));

    // A constraint identical to the op definition's own needs no re-check;
    // only a different one is emitted.
    if (operand->constraint != matcher.getAsConstraint()) {
      auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
      os.indent(indent) << "if (!("
                        << tgfmt(matcher.getConditionTemplate(),
                                 &matchCtx.withSelf(self))
                        << ")) return matchFailure();\n";
    }
  }

  // Capture the operand if the pattern names it.
  auto name = tree.getArgName(index);
  if (!name.empty())
    os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n",
                                 getBoundSymbol(name), depth, index);
}
|
|
|
|
|
|
|
|
// Emits a scoped C++ block for the `index`-th argument of `tree`, which the op
// definition declares as an attribute. The block:
//   * loads the attribute into a local `attr`,
//   * falls back to the declared default value, or fails the match if a
//     required attribute is absent,
//   * checks any constraint the pattern places on it, and
//   * captures it into the matched state if the pattern names it.
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
                                        int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
  const auto &attr = namedAttr->attr;

  os.indent(indent) << "{\n";
  const int bodyIndent = indent + 2;

  os.indent(bodyIndent) << formatv(
      "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
      attr.getStorageType(), namedAttr->name);

  // Decide what happens when the op does not carry the attribute.
  // TODO(antiagainst): This should use getter method to avoid duplication.
  if (attr.hasDefaultValueInitializer()) {
    // Substitute the default value declared in the op definition.
    os.indent(bodyIndent) << "if (!attr) attr = "
                          << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
                                   attr.getDefaultValueInitializer())
                          << ";\n";
  } else if (!attr.isOptional()) {
    // A required attribute is missing: the op cannot match.
    os.indent(bodyIndent) << "if (!attr) return matchFailure();\n";
  }
  // An optional attribute without a default needs no code. A missing
  // attribute simply leaves `attr` as a null mlir::Attribute, which signals
  // the missing state.

  auto matcher = tree.getArgAsLeaf(index);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher())
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), index + 1));

    // Emit the check for the constraint the pattern places on the attribute.
    os.indent(bodyIndent) << "if (!("
                          << tgfmt(matcher.getConditionTemplate(),
                                   &matchCtx.withSelf("attr"))
                          << ")) return matchFailure();\n";
  }

  // Capture the attribute if the pattern names it.
  auto name = tree.getArgName(index);
  if (!name.empty())
    os.indent(bodyIndent) << getBoundSymbol(name) << " = attr;\n";

  os.indent(indent) << "}\n";
}
|
2018-12-29 04:02:08 +08:00
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Emits the match() method of the generated RewritePattern. The method runs
// these checks in order, returning matchFailure() as soon as one fails:
//   1. results that must be unused (`verifyUnusedValue` directives),
//   2. the source DAG `tree`, starting from `op0`,
//   3. the pattern's extra constraints on bound entities.
// On success, it returns the populated MatchedState.
void PatternEmitter::emitMatchMethod(DagNode tree) {
  // Emit the heading.
  os << R"(
PatternMatchResult match(Operation *op0) const override {
auto ctx = op0->getContext(); (void)ctx;
auto state = llvm::make_unique<MatchedState>();
auto &s = *state;
)";

  // The rewrite pattern may specify that certain outputs should be unused in
  // the source IR. Check it here.
  for (int i = 0, e = pattern.getNumResults(); i < e; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    if (resultTree.isVerifyUnusedValue())
      handleVerifyUnusedValue(resultTree, i);
  }

  emitOpMatch(tree, 0);

  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    auto cmd = "if (!{0}) return matchFailure();\n";

    if (isa<TypeConstraint>(constraint)) {
      // The type constraint is applied to the first result type of the op
      // bound to the first entity. Calling front() on an empty list is
      // undefined behavior, so report a proper error instead.
      if (entities.empty())
        PrintFatalError(loc, "TypeConstraint in Pattern constraints requires "
                             "at least one entity");
      auto self = formatv("(*{0}->result_type_begin())",
                          resolveSymbol(entities.front()));
      // TODO(jpienaar): Verify op only has one result.
      os.indent(4) << formatv(cmd,
                              tgfmt(condition, &matchCtx.withSelf(self.str())));
    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // Other constraints receive the resolved entities as tgfmt's positional
      // arguments.
      // TODO(fengliuai): replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(resolveSymbol(entities[i]));
      // Pad the remaining slots so exactly four arguments are always passed.
      for (; i < 4; ++i)
        names.push_back("<unused>");
      os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0],
                                         names[1], names[2], names[3]));
    }
  }

  os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
|
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Emits the complete RewritePattern subclass named `rewriteName`: its
// constructor (registered for the source root op with the pattern's benefit),
// the MatchedState struct holding everything the source pattern binds, and
// the match() and rewrite() methods.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern
  DagNode tree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Replacing ops that have variadic results is not supported yet.
  if (rootOp.getNumVariadicResults() != 0)
    PrintFatalError(
        loc, "replacing op with variadic results not supported right now");

  // Struct heading and constructor.
  os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
                rewriteName, rootName, pattern.getBenefit())
     << "\n";

  // MatchedState: attributes are stored by their storage type, operands as
  // `Value *`, and bound ops as `Operation *`.
  os << " struct MatchedState : public PatternState {\n";
  for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
    auto fieldName = arg.first();
    auto *namedAttr = arg.second.dyn_cast<NamedAttribute *>();
    os.indent(4);
    if (namedAttr)
      os << namedAttr->attr.getStorageType() << " " << fieldName << ";\n";
    else
      os << "Value *" << fieldName << ";\n";
  }
  for (const auto &result : pattern.getSourcePatternBoundOps())
    os.indent(4) << "Operation *" << result.getKey() << ";\n";
  os << " };\n";

  emitMatchMethod(tree);
  emitRewriteMethod();

  os << "};\n";
}
|
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
void PatternEmitter::emitRewriteMethod() {
|
2019-04-04 03:29:14 +08:00
|
|
|
const Operator &rootOp = pattern.getSourceRootOp();
|
|
|
|
int numExpectedResults = rootOp.getNumResults();
|
2019-05-04 10:48:57 +08:00
|
|
|
int numProvidedResults = pattern.getNumResults();
|
2019-04-04 03:29:14 +08:00
|
|
|
|
|
|
|
if (numProvidedResults < numExpectedResults)
|
|
|
|
PrintFatalError(
|
|
|
|
loc, "no enough result patterns to replace root op in source pattern");
|
2018-12-29 04:02:08 +08:00
|
|
|
|
2019-01-26 02:09:15 +08:00
|
|
|
os << R"(
|
2019-03-28 23:24:38 +08:00
|
|
|
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
2018-12-29 04:02:08 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto& s = *static_cast<MatchedState *>(state.get());
|
2019-02-09 22:36:23 +08:00
|
|
|
auto loc = op->getLoc(); (void)loc;
|
2019-01-26 02:09:15 +08:00
|
|
|
)";
|
|
|
|
|
2019-03-09 05:56:53 +08:00
|
|
|
// Collect the replacement value for each result
|
|
|
|
llvm::SmallVector<std::string, 2> resultValues;
|
2019-05-04 10:48:57 +08:00
|
|
|
for (int i = 0; i < numProvidedResults; ++i) {
|
2019-03-09 05:56:53 +08:00
|
|
|
DagNode resultTree = pattern.getResultPattern(i);
|
|
|
|
resultValues.push_back(handleRewritePattern(resultTree, i, 0));
|
2019-04-04 20:44:58 +08:00
|
|
|
// Keep track of bound symbols at the top-level DAG nodes
|
|
|
|
addSymbol(resultTree);
|
2019-03-09 05:56:53 +08:00
|
|
|
}
|
2019-02-09 22:36:23 +08:00
|
|
|
|
2019-03-09 05:56:53 +08:00
|
|
|
// Emit the final replaceOp() statement
|
|
|
|
os.indent(4) << "rewriter.replaceOp(op, {";
|
|
|
|
interleave(
|
2019-04-04 03:29:14 +08:00
|
|
|
// We only use the last numExpectedResults ones to replace the root op.
|
|
|
|
ArrayRef<std::string>(resultValues).take_back(numExpectedResults),
|
|
|
|
[&](const std::string &name) { os << name; }, [&]() { os << ", "; });
|
2019-02-09 22:36:23 +08:00
|
|
|
os << "});\n }\n";
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
|
|
|
|
return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Emits the C++ code for one result DAG node. Returns the C++ expression that
// names the value the node produces.
//
// Dispatch order:
//   NativeCodeCall    -> emitReplaceWithNativeCodeCall()
//   verifyUnusedValue -> "nullptr" (its check already lives in match())
//   replaceWithValue  -> handleReplaceWithValue()
//   anything else     -> emitOpCreate()
//
// `resultIndex` is forwarded to emitOpCreate(), which uses it to pick
// `op->getResult(resultIndex)` when a result type must be spelled out.
// `depth` is 0 for a top-level result pattern and grows for nested nodes.
std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
                                                 int resultIndex, int depth) {
  if (resultTree.isNativeCodeCall())
    return emitReplaceWithNativeCodeCall(resultTree);

  if (!resultTree.isVerifyUnusedValue()) {
    if (resultTree.isReplaceWithValue())
      return handleReplaceWithValue(resultTree);
    return emitOpCreate(resultTree, resultIndex, depth);
  }

  // verifyUnusedValue directive.
  const bool isNested = depth > 0;
  if (isNested)
    // TODO: Revisit this when we have use cases of matching an intermediate
    // multi-result op with no uses of its certain results.
    PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
                         "verify top-level result");

  const bool hasBoundSymbol = !resultTree.getOpName().empty();
  if (hasBoundSymbol)
    PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");

  // The C++ statements that check this result value is unused were already
  // emitted in the match() method. Returning nullptr is therefore safe: the C++
  // RewritePattern harness uses it to replace nothing.
  return "nullptr";
}
|
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
// Handles a `replaceWithValue` directive in a result pattern.
//
// The directive must take exactly one argument, and that argument must name a
// symbol bound in the source pattern. The C++ expression for that bound symbol
// is returned as the replacement value. A symbol cannot be bound to the
// directive itself, because it forwards an existing value instead of
// creating an op.
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
  assert(tree.isReplaceWithValue());

  if (tree.getNumArgs() != 1) {
    PrintFatalError(
        loc, "replaceWithValue directive must take exactly one argument");
  }

  // The old message named verifyUnusedValue here (copy-paste slip), which
  // pointed users at the wrong directive.
  if (!tree.getOpName().empty()) {
    PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
  }

  auto name = tree.getArgName(0);
  pattern.ensureBoundInSourcePattern(name);

  return getBoundSymbol(name).str();
}
|
|
|
|
|
2019-03-09 05:56:53 +08:00
|
|
|
// Writes a guard statement for a `verifyUnusedValue` directive. The guard
// fails the match unless result `index` of `op0` has no uses. `op0` is
// presumably the matched root op as named by the generated match() code
// (verify against emitMatchMethod).
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
  assert(tree.isVerifyUnusedValue());

  os.indent(4) << formatv(
      "if (!op0->getResult({0})->use_empty()) return matchFailure();\n",
      index);
}
|
|
|
|
|
2019-03-13 04:55:50 +08:00
|
|
|
// Produces the C++ expression for one leaf argument of a result op.
//
// - Constant attributes and enum cases are materialized via
//   handleConstantAttr().
// - Any other leaf must refer to `argName`, a symbol bound in the source
//   pattern. Plain and operand-matcher leaves yield that symbol directly.
// - NativeCodeCall leaves substitute the symbol for `$_self` in their
//   template.
// - Any other leaf kind is a fatal error.
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                             llvm::StringRef argName) {
  if (leaf.isConstantAttr()) {
    auto constant = leaf.getAsConstantAttr();
    return handleConstantAttr(constant.getAttribute(),
                              constant.getConstantValue());
  }
  if (leaf.isEnumAttrCase()) {
    auto caseDef = leaf.getAsEnumAttrCase();
    return handleConstantAttr(caseDef, caseDef.getSymbol());
  }

  // Every remaining leaf kind references something bound in the source pattern.
  pattern.ensureBoundInSourcePattern(argName);
  const std::string boundName = getBoundSymbol(argName).str();

  const bool passThrough = leaf.isUnspecified() || leaf.isOperandMatcher();
  if (passThrough)
    return boundName;

  if (leaf.isNativeCodeCall())
    return tgfmt(leaf.getNativeCodeTemplate(),
                 &rewriteCtx.withSelf(boundName));

  PrintFatalError(loc, "unhandled case when rewriting op");
}
|
|
|
|
|
2019-04-23 05:13:45 +08:00
|
|
|
// Expands the NativeCodeCall template of `tree`. Each DAG argument is rendered
// through handleOpArgument() and passed positionally to tgfmt().
//
// The tgfmt() call below has a fixed arity, so at most eight arguments are
// supported. Unused slots are passed as empty strings.
std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
  auto fmt = tree.getNativeCodeTemplate();

  const int numArgs = tree.getNumArgs();
  if (numArgs > 8)
    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
                             Twine(numArgs));

  // TODO(fengliuai): replace formatv arguments with the exact specified args.
  SmallVector<std::string, 8> rendered(8);
  for (int idx = 0; idx != numArgs; ++idx)
    rendered[idx] =
        handleOpArgument(tree.getArgAsLeaf(idx), tree.getArgName(idx));

  return tgfmt(fmt, &rewriteCtx, rendered[0], rendered[1], rendered[2],
               rendered[3], rendered[4], rendered[5], rendered[6],
               rendered[7]);
}
|
|
|
|
|
2019-04-04 20:44:58 +08:00
|
|
|
// Records the symbol bound to `node` (its op name) in `symbolResolver`, so
// that later result patterns can reference it. Nodes without a name are
// ignored. Binding the same symbol twice is a fatal error.
void PatternEmitter::addSymbol(DagNode node) {
  StringRef bound = node.getOpName();
  // Unbound ops in result patterns have an empty name; nothing to record.
  if (!bound.empty() && !symbolResolver.add(bound))
    PrintFatalError(loc, formatv("symbol '{0}' bound more than once", bound));
}
|
|
|
|
|
|
|
|
// Looks up `symbol` in `symbolResolver` and returns its substitution.
// An empty substitution means the symbol was never bound, which is a fatal
// error.
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
  auto substitution = symbolResolver.query(symbol);
  if (!substitution.empty())
    return substitution;
  PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
}
|
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
2019-03-09 05:57:09 +08:00
|
|
|
int depth) {
|
2019-02-09 22:36:23 +08:00
|
|
|
Operator &resultOp = tree.getDialectOp(opMap);
|
2019-03-09 05:56:53 +08:00
|
|
|
auto numOpArgs = resultOp.getNumArgs();
|
2019-01-26 02:09:15 +08:00
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
if (numOpArgs != tree.getNumArgs()) {
|
2019-02-02 07:40:22 +08:00
|
|
|
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
|
|
|
|
"{1} in pattern vs. {2} in definition",
|
2019-02-09 22:36:23 +08:00
|
|
|
resultOp.getOperationName(), tree.getNumArgs(),
|
|
|
|
numOpArgs));
|
|
|
|
}
|
|
|
|
|
2019-03-09 05:56:53 +08:00
|
|
|
if (resultOp.getNumResults() > 1) {
|
|
|
|
PrintFatalError(
|
|
|
|
loc, formatv("generating multiple-result op '{0}' is unsupported now",
|
|
|
|
resultOp.getOperationName()));
|
|
|
|
}
|
|
|
|
|
2019-02-09 22:45:55 +08:00
|
|
|
// A map to collect all nested DAG child nodes' names, with operand index as
|
2019-04-04 20:44:58 +08:00
|
|
|
// the key. This includes both bound and unbound child nodes. Bound child
|
|
|
|
// nodes will additionally be tracked in `symbolResolver` so they can be
|
|
|
|
// referenced by other patterns. Unbound child nodes will only be used once
|
|
|
|
// to build this op.
|
2019-02-09 22:45:55 +08:00
|
|
|
llvm::DenseMap<unsigned, std::string> childNodeNames;
|
|
|
|
|
|
|
|
// First go through all the child nodes who are nested DAG constructs to
|
|
|
|
// create ops for them, so that we can use the results in the current node.
|
|
|
|
// This happens in a recursive manner.
|
2019-05-04 10:48:57 +08:00
|
|
|
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
|
2019-02-09 22:45:55 +08:00
|
|
|
if (auto child = tree.getArgAsNestedDag(i)) {
|
2019-03-09 05:57:09 +08:00
|
|
|
childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
|
2019-04-04 20:44:58 +08:00
|
|
|
// Keep track of bound symbols at the middle-level DAG nodes
|
|
|
|
addSymbol(child);
|
2019-02-09 22:45:55 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-03-09 05:57:09 +08:00
|
|
|
// Use the specified name for this op if available. Generate one otherwise.
|
|
|
|
std::string resultValue = tree.getOpName();
|
|
|
|
if (resultValue.empty())
|
|
|
|
resultValue = getUniqueValueName(&resultOp);
|
2019-02-09 22:36:23 +08:00
|
|
|
|
2019-02-09 22:45:55 +08:00
|
|
|
// Then we build the new op corresponding to this DAG node.
|
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
// TODO: this is a hack to support various constant ops. We are assuming
|
|
|
|
// all of them have no operands and one attribute here. Figure out a better
|
|
|
|
// way to do this.
|
2019-02-15 02:54:50 +08:00
|
|
|
bool isConstOp =
|
|
|
|
resultOp.getNumOperands() == 0 && resultOp.getNumNativeAttributes() == 1;
|
|
|
|
|
|
|
|
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
|
|
|
|
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
|
2019-03-27 06:31:15 +08:00
|
|
|
bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
|
2019-02-15 02:54:50 +08:00
|
|
|
|
2019-03-27 06:31:15 +08:00
|
|
|
if (isConstOp || isSameValueType || isBroadcastable || useFirstAttr) {
|
2019-02-09 22:36:23 +08:00
|
|
|
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
|
|
|
|
resultOp.getQualCppClassName());
|
|
|
|
} else {
|
|
|
|
std::string resultType = formatv("op->getResult({0})", resultIndex).str();
|
|
|
|
|
|
|
|
os.indent(4) << formatv(
|
|
|
|
"auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue,
|
|
|
|
resultOp.getQualCppClassName(), resultType);
|
2018-12-29 04:02:08 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Create the builder call for the result.
|
|
|
|
// Add operands.
|
|
|
|
int i = 0;
|
2019-04-23 05:13:45 +08:00
|
|
|
for (int e = resultOp.getNumOperands(); i < e; ++i) {
|
|
|
|
const auto &operand = resultOp.getOperand(i);
|
|
|
|
|
2018-12-29 04:02:08 +08:00
|
|
|
// Start each operand on its own line.
|
|
|
|
(os << ",\n").indent(6);
|
|
|
|
|
2019-01-30 22:05:27 +08:00
|
|
|
if (!operand.name.empty())
|
|
|
|
os << "/*" << operand.name << "=*/";
|
2019-02-09 22:45:55 +08:00
|
|
|
|
2019-04-04 20:44:58 +08:00
|
|
|
if (tree.isNestedDagArg(i)) {
|
|
|
|
os << childNodeNames[i];
|
|
|
|
} else {
|
2019-04-23 05:13:45 +08:00
|
|
|
DagLeaf leaf = tree.getArgAsLeaf(i);
|
|
|
|
auto symbol = resolveSymbol(tree.getArgName(i));
|
|
|
|
if (leaf.isNativeCodeCall()) {
|
|
|
|
os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
|
|
|
|
} else {
|
|
|
|
os << symbol;
|
|
|
|
}
|
2019-04-04 20:44:58 +08:00
|
|
|
}
|
2018-12-29 04:02:08 +08:00
|
|
|
// TODO(jpienaar): verify types
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add attributes.
|
2019-02-09 22:36:23 +08:00
|
|
|
for (int e = tree.getNumArgs(); i != e; ++i) {
|
2018-12-29 04:02:08 +08:00
|
|
|
// Start each attribute on its own line.
|
|
|
|
(os << ",\n").indent(6);
|
2019-02-02 07:40:22 +08:00
|
|
|
// The argument in the op definition.
|
|
|
|
auto opArgName = resultOp.getArgName(i);
|
2019-03-13 04:55:50 +08:00
|
|
|
if (auto subTree = tree.getArgAsNestedDag(i)) {
|
2019-04-23 05:13:45 +08:00
|
|
|
if (!subTree.isNativeCodeCall())
|
|
|
|
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
|
|
|
"for creating attribute");
|
|
|
|
os << formatv("/*{0}=*/{1}", opArgName,
|
|
|
|
emitReplaceWithNativeCodeCall(subTree));
|
2019-02-02 07:40:22 +08:00
|
|
|
} else {
|
2019-03-13 04:55:50 +08:00
|
|
|
auto leaf = tree.getArgAsLeaf(i);
|
|
|
|
// The argument in the result DAG pattern.
|
|
|
|
auto patArgName = tree.getArgName(i);
|
2019-04-01 23:58:53 +08:00
|
|
|
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
|
2019-03-13 04:55:50 +08:00
|
|
|
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
|
|
|
auto argument = resultOp.getArg(i);
|
|
|
|
if (!argument.is<NamedAttribute *>())
|
|
|
|
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
|
|
|
if (!patArgName.empty())
|
|
|
|
os << "/*" << patArgName << "=*/";
|
|
|
|
} else {
|
|
|
|
os << "/*" << opArgName << "=*/";
|
|
|
|
}
|
|
|
|
os << handleOpArgument(leaf, patArgName);
|
2019-01-08 01:52:26 +08:00
|
|
|
}
|
2018-12-29 04:02:08 +08:00
|
|
|
}
|
2019-01-26 02:09:15 +08:00
|
|
|
os << "\n );\n";
|
2019-02-01 09:57:06 +08:00
|
|
|
|
2019-02-09 22:36:23 +08:00
|
|
|
return resultValue;
|
2019-02-01 09:57:06 +08:00
|
|
|
}
|
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// Convenience entry point: builds a PatternEmitter for pattern record `p` and
// emits the RewritePattern subclass named `rewriteName` into `os`. `mapper` is
// the record-to-Operator map, which the caller shares across all patterns.
void PatternEmitter::emit(StringRef rewriteName, Record *p,
                          RecordOperatorMap *mapper, raw_ostream &os) {
  PatternEmitter emitter(p, mapper, os);
  emitter.emit(rewriteName);
}
|
|
|
|
|
2018-12-12 19:09:11 +08:00
|
|
|
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
|
|
emitSourceFileHeader("Rewriters", os);
|
2019-04-11 02:37:53 +08:00
|
|
|
|
2018-12-12 19:09:11 +08:00
|
|
|
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
|
2019-04-11 02:37:53 +08:00
|
|
|
auto numPatterns = patterns.size();
|
2018-12-12 19:09:11 +08:00
|
|
|
|
2019-01-29 06:04:40 +08:00
|
|
|
// We put the map here because it can be shared among multiple patterns.
|
|
|
|
RecordOperatorMap recordOpMap;
|
|
|
|
|
2019-04-11 02:37:53 +08:00
|
|
|
std::vector<std::string> rewriterNames;
|
|
|
|
rewriterNames.reserve(numPatterns);
|
|
|
|
|
|
|
|
std::string baseRewriterName = "GeneratedConvert";
|
|
|
|
int rewriterIndex = 0;
|
|
|
|
|
2018-12-29 04:02:08 +08:00
|
|
|
for (Record *p : patterns) {
|
2019-04-11 02:37:53 +08:00
|
|
|
std::string name;
|
|
|
|
if (p->isAnonymous()) {
|
|
|
|
// If no name is provided, ensure unique rewriter names simply by
|
|
|
|
// appending unique suffix.
|
|
|
|
name = baseRewriterName + llvm::utostr(rewriterIndex++);
|
|
|
|
} else {
|
|
|
|
name = p->getName();
|
|
|
|
}
|
|
|
|
PatternEmitter::emit(name, p, &recordOpMap, os);
|
|
|
|
rewriterNames.push_back(std::move(name));
|
2018-12-12 19:09:11 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Emit function to add the generated matchers to the pattern list.
|
|
|
|
os << "void populateWithGenerated(MLIRContext *context, "
|
|
|
|
<< "OwningRewritePatternList *patterns) {\n";
|
2019-04-11 02:37:53 +08:00
|
|
|
for (const auto &name : rewriterNames) {
|
|
|
|
os << " patterns->push_back(llvm::make_unique<" << name
|
|
|
|
<< ">(context));\n";
|
2018-12-12 19:09:11 +08:00
|
|
|
}
|
|
|
|
os << "}\n";
|
|
|
|
}
|
|
|
|
|
Start doc generation pass.
Start doc generation pass that generates simple markdown output. The output is formatted simply[1] in markdown, but this allows seeing what info we have, where we can refine the op description (e.g., the inputs is probably redundant), what info is missing (e.g., the attributes could probably have a description).
The formatting of the description is still left up to whatever was in the op definition (which luckily, due to the uniformity in the .td file, turned out well but relying on the indentation there is fragile). The mechanism to autogenerate these post changes has not been added yet either. The output file could be run through a markdown formatter too to remove extra spaces.
[1]. This is not a proposal for the final style :) There could also be a discussion around a single doc vs. multiple docs (per dialect, per op), whether we want a TOC, whether operands/attributes should be headings or just formatted differently ...
PiperOrigin-RevId: 230354538
2019-01-23 01:31:04 +08:00
|
|
|
// Registers the `gen-rewriters` generator, which runs emitRewriters() over all
// records. The callback always returns false; per the usual TableGen backend
// convention, that presumably signals "no error" (confirm against
// GenRegistration).
static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &records, raw_ostream &os) -> bool {
                   emitRewriters(records, os);
                   return false;
                 });
|