forked from OSchip/llvm-project
[mlir-tblgen] Add DagNode StaticMatcher.
Some patterns may share the common DAG structures. Generate a static function to do the match logic to reduce the binary size. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D105797
This commit is contained in:
parent
4e7c0a37c9
commit
bb2506061b
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/TableGen/Argument.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
|
@ -198,6 +199,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class SymbolInfoMap;
|
||||
friend llvm::DenseMapInfo<DagNode>;
|
||||
const void *getAsOpaquePointer() const { return node; }
|
||||
|
||||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
|
@ -242,10 +244,17 @@ public:
|
|||
// Class for information regarding a symbol.
|
||||
class SymbolInfo {
|
||||
public:
|
||||
// Returns a type string of a variable.
|
||||
std::string getVarTypeStr(StringRef name) const;
|
||||
|
||||
// Returns a string for defining a variable named as `name` to store the
|
||||
// value bound by this symbol.
|
||||
std::string getVarDecl(StringRef name) const;
|
||||
|
||||
// Returns a string for defining an argument which passes the reference of
|
||||
// the variable.
|
||||
std::string getArgDecl(StringRef name) const;
|
||||
|
||||
// Returns a variable name for the symbol named as `name`.
|
||||
std::string getVarName(StringRef name) const;
|
||||
|
||||
|
@ -383,6 +392,7 @@ public:
|
|||
// with index `argIndex` for operator `op`.
|
||||
const_iterator findBoundSymbol(StringRef key, DagNode node,
|
||||
const Operator &op, int argIndex) const;
|
||||
const_iterator findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const;
|
||||
|
||||
// Returns the bounds of a range that includes all the elements which
|
||||
// bind to the `key`.
|
||||
|
@ -474,15 +484,15 @@ public:
|
|||
// pair).
|
||||
std::vector<IdentifierLine> getLocation() const;
|
||||
|
||||
private:
|
||||
// Helper function to verify variabld binding.
|
||||
void verifyBind(bool result, StringRef symbolName);
|
||||
|
||||
// Recursively collects all bound symbols inside the DAG tree rooted
|
||||
// at `tree` and updates the given `infoMap`.
|
||||
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
bool isSrcPattern);
|
||||
|
||||
private:
|
||||
// Helper function to verify variable binding.
|
||||
void verifyBind(bool result, StringRef symbolName);
|
||||
|
||||
// The TableGen definition of this pattern.
|
||||
const llvm::Record &def;
|
||||
|
||||
|
@ -495,4 +505,24 @@ private:
|
|||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::tblgen::DagNode> {
|
||||
static mlir::tblgen::DagNode getEmptyKey() {
|
||||
return mlir::tblgen::DagNode(
|
||||
llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey());
|
||||
}
|
||||
static mlir::tblgen::DagNode getTombstoneKey() {
|
||||
return mlir::tblgen::DagNode(
|
||||
llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey());
|
||||
}
|
||||
static unsigned getHashValue(mlir::tblgen::DagNode node) {
|
||||
return llvm::hash_value(node.getAsOpaquePointer());
|
||||
}
|
||||
static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) {
|
||||
return lhs.node == rhs.node;
|
||||
}
|
||||
};
|
||||
} // end namespace llvm
|
||||
|
||||
#endif // MLIR_TABLEGEN_PATTERN_H_
|
||||
|
|
|
@ -230,45 +230,50 @@ std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
|
|||
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||
std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
if (op) {
|
||||
auto type = op->getArg(getArgIndex())
|
||||
if (op)
|
||||
return op->getArg(getArgIndex())
|
||||
.get<NamedAttribute *>()
|
||||
->attr.getStorageType();
|
||||
return std::string(formatv("{0} {1};\n", type, name));
|
||||
}
|
||||
->attr.getStorageType()
|
||||
.str();
|
||||
// TODO(suderman): Use a more exact type when available.
|
||||
return std::string(formatv("Attribute {0};\n", name));
|
||||
return "Attribute";
|
||||
}
|
||||
case Kind::Operand: {
|
||||
// Use operand range for captured operands (to support potential variadic
|
||||
// operands).
|
||||
return std::string(
|
||||
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
|
||||
getVarName(name)));
|
||||
return "::mlir::Operation::operand_range";
|
||||
}
|
||||
case Kind::Value: {
|
||||
return std::string(formatv("::mlir::Value {0};\n", name));
|
||||
return "::mlir::Value";
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
// This is for the variable used in the source pattern. Each named value in
|
||||
// source pattern will only be bound to a Value. The others in the result
|
||||
// pattern may be associated with multiple Values as we will use `auto` to
|
||||
// do the type inference.
|
||||
return std::string(formatv(
|
||||
"::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
|
||||
return "::mlir::ValueRange";
|
||||
}
|
||||
case Kind::Result: {
|
||||
// Use the op itself for captured results.
|
||||
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
|
||||
return op->getQualCppClassName();
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||
std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
|
||||
return std::string(
|
||||
formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
|
||||
return std::string(
|
||||
formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
|
||||
}
|
||||
|
||||
std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
|
||||
|
@ -486,11 +491,14 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
|
|||
SymbolInfoMap::const_iterator
|
||||
SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
|
||||
int argIndex) const {
|
||||
return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
|
||||
}
|
||||
|
||||
SymbolInfoMap::const_iterator
|
||||
SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
|
||||
std::string name = getValuePackName(key).str();
|
||||
auto range = symbolInfoMap.equal_range(name);
|
||||
|
||||
const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
|
||||
|
||||
for (auto it = range.first; it != range.second; ++it)
|
||||
if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
|
||||
return it;
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
}
|
||||
class NS_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<Test_Dialect, mnemonic, traits>;
|
||||
|
||||
def AOp : NS_Op<"a_op", []> {
|
||||
let arguments = (ins
|
||||
AnyInteger:$any_integer
|
||||
);
|
||||
|
||||
let results = (outs AnyInteger);
|
||||
}
|
||||
|
||||
def BOp : NS_Op<"b_op", []> {
|
||||
let arguments = (ins
|
||||
AnyAttr: $any_attr,
|
||||
AnyInteger
|
||||
);
|
||||
|
||||
let results = (outs AnyInteger);
|
||||
}
|
||||
|
||||
def COp : NS_Op<"c_op", []> {
|
||||
let arguments = (ins
|
||||
AnyAttr: $any_attr,
|
||||
AnyInteger
|
||||
);
|
||||
|
||||
let results = (outs AnyInteger);
|
||||
}
|
||||
|
||||
// Test static matcher for duplicate DagNode
|
||||
// ---
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
|
||||
|
||||
// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
|
||||
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
|
||||
(AOp $int)>;
|
||||
|
||||
// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
|
||||
def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
|
||||
(COp $attr, $int)>;
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/TableGen/Pattern.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "mlir/TableGen/Type.h"
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -54,13 +55,20 @@ struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
class StaticMatcherHelper;
|
||||
|
||||
class PatternEmitter {
|
||||
public:
|
||||
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
|
||||
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
|
||||
StaticMatcherHelper &helper);
|
||||
|
||||
// Emits the mlir::RewritePattern struct named `rewriteName`.
|
||||
void emit(StringRef rewriteName);
|
||||
|
||||
// Emits the static function of DAG matcher.
|
||||
void emitStaticMatcher(DagNode tree, std::string funcName);
|
||||
|
||||
private:
|
||||
// Emits the code for matching ops.
|
||||
void emitMatchLogic(DagNode tree, StringRef opName);
|
||||
|
@ -75,6 +83,9 @@ private:
|
|||
// Emits C++ statements for matching the DAG structure.
|
||||
void emitMatch(DagNode tree, StringRef name, int depth);
|
||||
|
||||
// Emit C++ function call to static DAG matcher.
|
||||
void emitStaticMatchCall(DagNode tree, StringRef name);
|
||||
|
||||
// Emits C++ statements for matching using a native code call.
|
||||
void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
|
||||
|
||||
|
@ -216,6 +227,8 @@ private:
|
|||
// Map for all bound symbols' info.
|
||||
SymbolInfoMap symbolInfoMap;
|
||||
|
||||
StaticMatcherHelper &staticMatcherHelper;
|
||||
|
||||
// The next unused ID for newly created values.
|
||||
unsigned nextValueId;
|
||||
|
||||
|
@ -223,16 +236,79 @@ private:
|
|||
|
||||
// Format contexts containing placeholder substitutions.
|
||||
FmtContext fmtCtx;
|
||||
|
||||
// Number of op processed.
|
||||
int opCounter = 0;
|
||||
};
|
||||
|
||||
// Tracks DagNode's reference multiple times across patterns. Enables generating
|
||||
// static matcher functions for DagNode's referenced multiple times rather than
|
||||
// inlining them.
|
||||
class StaticMatcherHelper {
|
||||
public:
|
||||
StaticMatcherHelper(RecordOperatorMap &mapper);
|
||||
|
||||
// Determine if we should inline the match logic or delegate to a static
|
||||
// function.
|
||||
bool useStaticMatcher(DagNode node) {
|
||||
return refStats[node] > kStaticMatcherThreshold;
|
||||
}
|
||||
|
||||
// Get the name of the static DAG matcher function corresponding to the node.
|
||||
std::string getMatcherName(DagNode node) {
|
||||
assert(useStaticMatcher(node));
|
||||
return matcherNames[node];
|
||||
}
|
||||
|
||||
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
|
||||
// the duplicated DAGs.
|
||||
void addPattern(Record *record);
|
||||
|
||||
// Emit all static functions of DAG Matcher.
|
||||
void populateStaticMatchers(raw_ostream &os);
|
||||
|
||||
private:
|
||||
static constexpr unsigned kStaticMatcherThreshold = 1;
|
||||
|
||||
// Consider two patterns as down below,
|
||||
// DagNode_Root_A DagNode_Root_B
|
||||
// \ \
|
||||
// DagNode_C DagNode_C
|
||||
// \ \
|
||||
// DagNode_D DagNode_D
|
||||
//
|
||||
// DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
|
||||
// DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
|
||||
// multiple times so we'll have static matchers for both of them. When we're
|
||||
// emitting the match logic for DagNode_C, we will check if DagNode_D has the
|
||||
// static matcher generated. If so, then we'll generate a call to the
|
||||
// function, inline otherwise. In this case, inlining is not what we want. As
|
||||
// a result, generate the static matcher in topological order to ensure all
|
||||
// the dependent static matchers are generated and we can avoid accidentally
|
||||
// inlining.
|
||||
//
|
||||
// The topological order of all the DagNodes among all patterns.
|
||||
SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
|
||||
|
||||
RecordOperatorMap &opMap;
|
||||
|
||||
// Records of the static function name of each DagNode
|
||||
DenseMap<DagNode, std::string> matcherNames;
|
||||
|
||||
// After collecting all the DagNode in each pattern, `refStats` records the
|
||||
// number of users for each DagNode. We will generate the static matcher for a
|
||||
// DagNode while the number of users exceeds a certain threshold.
|
||||
DenseMap<DagNode, unsigned> refStats;
|
||||
|
||||
// Number of static matcher generated. This is used to generate a unique name
|
||||
// for each DagNode.
|
||||
int staticMatcherCounter = 0;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
||||
raw_ostream &os)
|
||||
raw_ostream &os, StaticMatcherHelper &helper)
|
||||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
|
||||
symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
|
||||
symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), nextValueId(0),
|
||||
os(os) {
|
||||
fmtCtx.withBuilder("rewriter");
|
||||
}
|
||||
|
||||
|
@ -246,6 +322,33 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
|||
return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
|
||||
}
|
||||
|
||||
void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
|
||||
os << formatv(
|
||||
"static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
|
||||
"::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
|
||||
"*, 4> &tblgen_ops",
|
||||
funcName);
|
||||
|
||||
// We pass the reference of the variables that need to be captured. Hence we
|
||||
// need to collect all the symbols in the tree first.
|
||||
pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
|
||||
symbolInfoMap.assignUniqueAlternativeNames();
|
||||
for (const auto &info : symbolInfoMap)
|
||||
os << formatv(", {0}", info.second.getArgDecl(info.first));
|
||||
|
||||
os << ") {\n";
|
||||
os.indent();
|
||||
os << "(void)tblgen_ops;\n";
|
||||
|
||||
// Note that a static matcher is considered at least one step from the match
|
||||
// entry.
|
||||
emitMatch(tree, "op0", /*depth=*/1);
|
||||
|
||||
os << "return ::mlir::success();\n";
|
||||
os.unindent();
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
|
||||
if (tree.isNativeCodeCall()) {
|
||||
|
@ -261,6 +364,36 @@ void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
|
|||
PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
|
||||
}
|
||||
|
||||
void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
|
||||
std::string funcName = staticMatcherHelper.getMatcherName(tree);
|
||||
os << formatv("if(failed({0}(rewriter, {1}, tblgen_ops", funcName, opName);
|
||||
|
||||
// TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
|
||||
// one pass.
|
||||
|
||||
// In general, bound symbol should have the unique name in the pattern but
|
||||
// for the operand, binding same symbol to multiple operands imply a
|
||||
// constraint at the same time. In this case, we will rename those operands
|
||||
// with different names. As a result, we need to collect all the symbolInfos
|
||||
// from the DagNode then get the updated name of the local variables from the
|
||||
// global symbolInfoMap.
|
||||
|
||||
// Collect all the bound symbols in the Dag
|
||||
SymbolInfoMap localSymbolMap(loc);
|
||||
pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
|
||||
|
||||
for (const auto &info : localSymbolMap) {
|
||||
auto name = info.first;
|
||||
auto symboInfo = info.second;
|
||||
auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
|
||||
os << formatv(", {0}", ret->second.getVarName(name));
|
||||
}
|
||||
|
||||
os << "))) {\n";
|
||||
os.scope().os << "return ::mlir::failure();\n";
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
int depth) {
|
||||
|
@ -268,6 +401,21 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
|||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
// The order of generating static matcher follows the topological order so
|
||||
// that for every dependent DagNode already have their static matcher
|
||||
// generated if needed. The reason we check if `getMatcherName(tree).empty()`
|
||||
// is when we are generating the static matcher for a DagNode itself. In this
|
||||
// case, we need to emit the function body rather than a function call.
|
||||
if (staticMatcherHelper.useStaticMatcher(tree) &&
|
||||
!staticMatcherHelper.getMatcherName(tree).empty()) {
|
||||
emitStaticMatchCall(tree, opName);
|
||||
|
||||
// NativeCodeCall will never be at depth 0 so that we don't need to catch
|
||||
// the root operation as emitOpMatch();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(suderman): iterate through arguments, determine their types, output
|
||||
// names.
|
||||
SmallVector<std::string, 8> capture;
|
||||
|
@ -356,7 +504,28 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
|||
<< op.getOperationName() << "' at depth " << depth
|
||||
<< '\n');
|
||||
|
||||
std::string castedName = formatv("castedOp{0}", depth);
|
||||
auto getCastedName = [depth]() -> std::string {
|
||||
return formatv("castedOp{0}", depth);
|
||||
};
|
||||
|
||||
// The order of generating static matcher follows the topological order so
|
||||
// that for every dependent DagNode already have their static matcher
|
||||
// generated if needed. The reason we check if `getMatcherName(tree).empty()`
|
||||
// is when we are generating the static matcher for a DagNode itself. In this
|
||||
// case, we need to emit the function body rather than a function call.
|
||||
if (staticMatcherHelper.useStaticMatcher(tree) &&
|
||||
!staticMatcherHelper.getMatcherName(tree).empty()) {
|
||||
emitStaticMatchCall(tree, opName);
|
||||
// In the codegen of rewriter, we suppose that castedOp0 will capture the
|
||||
// root operation. Manually add it if the root DagNode is a static matcher.
|
||||
if (depth == 0)
|
||||
os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
|
||||
"(void){2};\n",
|
||||
opName, op.getQualCppClassName(), getCastedName());
|
||||
return;
|
||||
}
|
||||
|
||||
std::string castedName = getCastedName();
|
||||
os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
|
||||
"(void){0};\n",
|
||||
castedName, opName, op.getQualCppClassName());
|
||||
|
@ -405,7 +574,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
|||
formatv("\"Operand {0} of {1} has null definingOp\"",
|
||||
nextOperand++, castedName));
|
||||
emitMatch(argTree, argName, depth + 1);
|
||||
os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
|
||||
os << formatv("tblgen_ops.push_back({0});\n", argName);
|
||||
os.unindent() << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
@ -704,13 +873,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
}
|
||||
// TODO: capture ops with consistent numbering so that it can be
|
||||
// reused for fused loc.
|
||||
os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
|
||||
pattern.getSourcePattern().getNumOps());
|
||||
os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "done creating local variables for capturing matches\n");
|
||||
|
||||
os << "// Match\n";
|
||||
os << "tblgen_ops[0] = op0;\n";
|
||||
os << "tblgen_ops.push_back(op0);\n";
|
||||
emitMatchLogic(sourceTree, "op0");
|
||||
|
||||
os << "\n// Rewrite\n";
|
||||
|
@ -1399,17 +1567,67 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
|||
}
|
||||
}
|
||||
|
||||
StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
|
||||
: opMap(mapper) {}
|
||||
|
||||
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
|
||||
// PatternEmitter will use the static matcher if there's one generated. To
|
||||
// ensure that all the dependent static matchers are generated before emitting
|
||||
// the matching logic of the DagNode, we use topological order to achieve it.
|
||||
for (auto &dagInfo : topologicalOrder) {
|
||||
DagNode node = dagInfo.first;
|
||||
if (!useStaticMatcher(node))
|
||||
continue;
|
||||
|
||||
std::string funcName =
|
||||
formatv("static_dag_matcher_{0}", staticMatcherCounter++);
|
||||
assert(matcherNames.find(node) == matcherNames.end());
|
||||
PatternEmitter(dagInfo.second, &opMap, os, *this)
|
||||
.emitStaticMatcher(node, funcName);
|
||||
matcherNames[node] = funcName;
|
||||
}
|
||||
}
|
||||
|
||||
void StaticMatcherHelper::addPattern(Record *record) {
|
||||
Pattern pat(record, &opMap);
|
||||
|
||||
// While generating the function body of the DAG matcher, it may depends on
|
||||
// other DAG matchers. To ensure the dependent matchers are ready, we compute
|
||||
// the topological order for all the DAGs and emit the DAG matchers in this
|
||||
// order.
|
||||
llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
|
||||
++refStats[node];
|
||||
|
||||
if (refStats[node] != 1)
|
||||
return;
|
||||
|
||||
for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
|
||||
if (DagNode sibling = node.getArgAsNestedDag(i))
|
||||
dfs(sibling);
|
||||
|
||||
topologicalOrder.push_back(std::make_pair(node, record));
|
||||
};
|
||||
|
||||
dfs(pat.getSourcePattern());
|
||||
}
|
||||
|
||||
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Rewriters", os);
|
||||
|
||||
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
|
||||
auto numPatterns = patterns.size();
|
||||
|
||||
// We put the map here because it can be shared among multiple patterns.
|
||||
RecordOperatorMap recordOpMap;
|
||||
|
||||
// Exam all the patterns and generate static matcher for the duplicated
|
||||
// DagNode.
|
||||
StaticMatcherHelper staticMatcher(recordOpMap);
|
||||
for (Record *p : patterns)
|
||||
staticMatcher.addPattern(p);
|
||||
staticMatcher.populateStaticMatchers(os);
|
||||
|
||||
std::vector<std::string> rewriterNames;
|
||||
rewriterNames.reserve(numPatterns);
|
||||
rewriterNames.reserve(patterns.size());
|
||||
|
||||
std::string baseRewriterName = "GeneratedConvert";
|
||||
int rewriterIndex = 0;
|
||||
|
@ -1425,7 +1643,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "=== start generating pattern '" << name << "' ===\n");
|
||||
PatternEmitter(p, &recordOpMap, os).emit(name);
|
||||
PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "=== done generating pattern '" << name << "' ===\n");
|
||||
rewriterNames.push_back(std::move(name));
|
||||
|
|
Loading…
Reference in New Issue