forked from OSchip/llvm-project
Add LLVM_DEBUG in RewritersGen.cpp and Pattern.cpp
It's usually hard to understand what went wrong if mlir-tblgen crashes on some input. This CL adds a few useful LLVM_DEBUG statements so that we can use mlir-tblegn -debug to figure out the culprit for a crash. PiperOrigin-RevId: 275253532
This commit is contained in:
parent
bdc250c5a7
commit
1358df19ca
|
@ -104,6 +104,8 @@ public:
|
|||
// Precondition: isNativeCodeCall()
|
||||
StringRef getNativeCodeTemplate() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
|
||||
// also a subclass of the given `superclass`.
|
||||
|
@ -176,6 +178,8 @@ public:
|
|||
// Precondition: isNativeCodeCall()
|
||||
StringRef getNativeCodeTemplate() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
};
|
||||
|
|
|
@ -22,10 +22,13 @@
|
|||
|
||||
#include "mlir/TableGen/Pattern.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
#define DEBUG_TYPE "mlir-tblgen-pattern"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
using llvm::formatv;
|
||||
|
@ -92,6 +95,11 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void tblgen::DagLeaf::print(raw_ostream &os) const {
|
||||
if (def)
|
||||
def->print(os);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DagNode
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -159,6 +167,11 @@ bool tblgen::DagNode::isReplaceWithValue() const {
|
|||
return dagOpDef->getName() == "replaceWithValue";
|
||||
}
|
||||
|
||||
void tblgen::DagNode::print(raw_ostream &os) const {
|
||||
if (node)
|
||||
node->print(os);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolInfoMap
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -222,10 +235,13 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
|||
|
||||
std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
assert(index < 0);
|
||||
return formatv(fmt, name);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Operand: {
|
||||
assert(index < 0);
|
||||
|
@ -233,9 +249,13 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
|||
// If this operand is variadic, then return a range. Otherwise, return the
|
||||
// value itself.
|
||||
if (operand->isVariadic()) {
|
||||
return formatv(fmt, name);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
|
||||
return repl;
|
||||
}
|
||||
return formatv(fmt, formatv("(*{0}.begin())", name));
|
||||
auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Result: {
|
||||
// If `index` is greater than zero, then we are referencing a specific
|
||||
|
@ -244,7 +264,9 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
|||
std::string v = formatv("{0}.getODSResults({1})", name, index);
|
||||
if (!op->getResult(index).isVariadic())
|
||||
v = formatv("(*{0}.begin())", v);
|
||||
return formatv(fmt, v);
|
||||
auto repl = formatv(fmt, v);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||
return repl;
|
||||
}
|
||||
|
||||
// We are referencing all results of the multi-result op. A specific result
|
||||
|
@ -259,12 +281,16 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
|||
}
|
||||
values.push_back(formatv(fmt, v));
|
||||
}
|
||||
return llvm::join(values, separator);
|
||||
auto repl = llvm::join(values, separator);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Value: {
|
||||
assert(index < 0);
|
||||
assert(op == nullptr);
|
||||
return formatv(fmt, name);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
return repl;
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
|
@ -272,15 +298,20 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
|||
|
||||
std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Operand: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
return formatv(fmt, name);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Result: {
|
||||
if (index >= 0) {
|
||||
return formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
|
||||
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||
return repl;
|
||||
}
|
||||
|
||||
// We are referencing all results of the multi-result op. Each result should
|
||||
|
@ -292,12 +323,16 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
|||
values.push_back(
|
||||
formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
|
||||
}
|
||||
return llvm::join(values, separator);
|
||||
auto repl = llvm::join(values, separator);
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||
return repl;
|
||||
}
|
||||
case Kind::Value: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
assert(op == nullptr);
|
||||
return formatv(fmt, formatv("{{{0}}", name));
|
||||
auto repl = formatv(fmt, formatv("{{{0}}", name));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
return repl;
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
|
@ -402,15 +437,19 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
|
|||
|
||||
void tblgen::Pattern::collectSourcePatternBoundSymbols(
|
||||
tblgen::SymbolInfoMap &infoMap) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
|
||||
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
|
||||
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
|
||||
}
|
||||
|
||||
void tblgen::Pattern::collectResultPatternBoundSymbols(
|
||||
tblgen::SymbolInfoMap &infoMap) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
|
||||
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
|
||||
auto pattern = getResultPattern(i);
|
||||
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
|
||||
}
|
||||
|
||||
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
|
||||
|
@ -496,6 +535,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
|||
// The name attached to the DAG node's operator is for representing the
|
||||
// results generated from this op. It should be remembered as bound results.
|
||||
if (!treeName.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "found symbol bound to op result: " << treeName << '\n');
|
||||
if (!infoMap.bindOpResult(treeName, op))
|
||||
PrintFatalError(def.getLoc(),
|
||||
formatv("symbol '{0}' bound more than once", treeName));
|
||||
|
@ -510,6 +551,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
|||
// symbols are referenced in result patterns.
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
if (!treeArgName.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||
<< treeArgName << '\n');
|
||||
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
|
||||
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
|
||||
PrintFatalError(def.getLoc(), err);
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatAdapters.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
|
@ -42,6 +43,8 @@ using namespace llvm;
|
|||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
|
||||
|
||||
namespace llvm {
|
||||
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
|
||||
static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
|
||||
|
@ -191,6 +194,9 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
|||
// Helper function to match patterns.
|
||||
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
|
||||
<< op.getOperationName() << "' at depth " << depth
|
||||
<< '\n');
|
||||
|
||||
int indent = 4 + 2 * depth;
|
||||
os.indent(indent) << formatv(
|
||||
|
@ -249,6 +255,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
PrintFatalError(loc, "unhandled case when matching op");
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
|
||||
<< op.getOperationName() << "' at depth " << depth
|
||||
<< '\n');
|
||||
}
|
||||
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
|
@ -346,6 +355,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
|||
}
|
||||
|
||||
void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
|
||||
emitOpMatch(tree, 0);
|
||||
|
||||
for (auto &appliedConstraint : pattern.getConstraints()) {
|
||||
|
@ -383,13 +393,18 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
|||
names[1], names[2], names[3]));
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
|
||||
}
|
||||
|
||||
void PatternEmitter::collectOps(DagNode tree,
|
||||
llvm::SmallPtrSetImpl<const Operator *> &ops) {
|
||||
// Check if this tree is an operation.
|
||||
if (tree.isOperation())
|
||||
ops.insert(&tree.getDialectOp(opMap));
|
||||
if (tree.isOperation()) {
|
||||
const Operator &op = tree.getDialectOp(opMap);
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "found operation " << op.getOperationName() << '\n');
|
||||
ops.insert(&op);
|
||||
}
|
||||
|
||||
// Recurse the arguments of the tree.
|
||||
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
|
||||
|
@ -406,8 +421,11 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
|
||||
// Collect the set of result operations.
|
||||
llvm::SmallPtrSet<const Operator *, 4> resultOps;
|
||||
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
|
||||
LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
|
||||
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
|
||||
collectOps(pattern.getResultPattern(i), resultOps);
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
auto locs = pattern.getLocation();
|
||||
|
@ -437,6 +455,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
// Register all symbols bound in the source pattern.
|
||||
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
|
||||
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs() << "start creating local variables for capturing matches\n");
|
||||
os.indent(4) << "// Variables for capturing values and attributes used for "
|
||||
"creating ops\n";
|
||||
// Create local variables for storing the arguments and results bound
|
||||
|
@ -450,6 +470,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
// reused for fused loc.
|
||||
os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
|
||||
pattern.getSourcePattern().getNumOps());
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs() << "done creating local variables for capturing matches\n");
|
||||
|
||||
os.indent(4) << "// Match\n";
|
||||
os.indent(4) << "tblgen_ops[0] = op0;\n";
|
||||
|
@ -465,6 +487,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
}
|
||||
|
||||
void PatternEmitter::emitRewriteLogic() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
|
||||
const Operator &rootOp = pattern.getSourceRootOp();
|
||||
int numExpectedResults = rootOp.getNumResults();
|
||||
int numResultPatterns = pattern.getNumResultPatterns();
|
||||
|
@ -525,6 +548,7 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
}
|
||||
os.indent(4) << "\n";
|
||||
os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
|
||||
}
|
||||
|
||||
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
|
||||
|
@ -533,6 +557,10 @@ std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
|
|||
|
||||
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
|
||||
int resultIndex, int depth) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
|
||||
LLVM_DEBUG(resultTree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
if (resultTree.isNativeCodeCall()) {
|
||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
|
||||
symbolInfoMap.bindValue(symbol);
|
||||
|
@ -585,17 +613,27 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
|||
return handleConstantAttr(enumCase, val);
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
|
||||
auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
|
||||
<< "' (via symbol ref)\n");
|
||||
return argName;
|
||||
}
|
||||
if (leaf.isNativeCodeCall()) {
|
||||
return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
|
||||
auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
|
||||
LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
|
||||
<< "' (via NativeCodeCall)\n");
|
||||
return repl;
|
||||
}
|
||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
|
||||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
auto fmt = tree.getNativeCodeTemplate();
|
||||
// TODO(b/138794486): replace formatv arguments with the exact specified args.
|
||||
SmallVector<std::string, 8> attrs(8);
|
||||
|
@ -605,6 +643,8 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
|||
}
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
|
||||
<< " replacement: " << attrs[i] << "\n");
|
||||
}
|
||||
return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
|
||||
attrs[5], attrs[6], attrs[7]);
|
||||
|
@ -830,7 +870,11 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
} else {
|
||||
name = p->getName();
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "=== start generating pattern '" << name << "' ===\n");
|
||||
PatternEmitter(p, &recordOpMap, os).emit(name);
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "=== done generating pattern '" << name << "' ===\n");
|
||||
rewriterNames.push_back(std::move(name));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue