Add LLVM_DEBUG in RewritersGen.cpp and Pattern.cpp

It's usually hard to understand what went wrong if mlir-tblgen
crashes on some input. This CL adds a few useful LLVM_DEBUG
statements so that we can use mlir-tblegn -debug to figure
out the culprit for a crash.

PiperOrigin-RevId: 275253532
This commit is contained in:
Lei Zhang 2019-10-17 07:25:50 -07:00 committed by A. Unique TensorFlower
parent bdc250c5a7
commit 1358df19ca
3 changed files with 105 additions and 14 deletions

View File

@ -104,6 +104,8 @@ public:
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
void print(raw_ostream &os) const;
private:
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
// also a subclass of the given `superclass`.
@ -176,6 +178,8 @@ public:
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
void print(raw_ostream &os) const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
};

View File

@ -22,10 +22,13 @@
#include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#define DEBUG_TYPE "mlir-tblgen-pattern"
using namespace mlir;
using llvm::formatv;
@ -92,6 +95,11 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
return false;
}
void tblgen::DagLeaf::print(raw_ostream &os) const {
if (def)
def->print(os);
}
//===----------------------------------------------------------------------===//
// DagNode
//===----------------------------------------------------------------------===//
@ -159,6 +167,11 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
void tblgen::DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
}
//===----------------------------------------------------------------------===//
// SymbolInfoMap
//===----------------------------------------------------------------------===//
@ -222,10 +235,13 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
switch (kind) {
case Kind::Attr: {
assert(index < 0);
return formatv(fmt, name);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
return repl;
}
case Kind::Operand: {
assert(index < 0);
@ -233,9 +249,13 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariadic()) {
return formatv(fmt, name);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return repl;
}
return formatv(fmt, formatv("(*{0}.begin())", name));
auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
return repl;
}
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
@ -244,7 +264,9 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
std::string v = formatv("{0}.getODSResults({1})", name, index);
if (!op->getResult(index).isVariadic())
v = formatv("(*{0}.begin())", v);
return formatv(fmt, v);
auto repl = formatv(fmt, v);
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
return repl;
}
// We are referencing all results of the multi-result op. A specific result
@ -259,12 +281,16 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
}
values.push_back(formatv(fmt, v));
}
return llvm::join(values, separator);
auto repl = llvm::join(values, separator);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
return repl;
}
case Kind::Value: {
assert(index < 0);
assert(op == nullptr);
return formatv(fmt, name);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return repl;
}
}
llvm_unreachable("unknown kind");
@ -272,15 +298,20 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
switch (kind) {
case Kind::Attr:
case Kind::Operand: {
assert(index < 0 && "only allowed for symbol bound to result");
return formatv(fmt, name);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
return repl;
}
case Kind::Result: {
if (index >= 0) {
return formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
return repl;
}
// We are referencing all results of the multi-result op. Each result should
@ -292,12 +323,16 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
values.push_back(
formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
}
return llvm::join(values, separator);
auto repl = llvm::join(values, separator);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
return repl;
}
case Kind::Value: {
assert(index < 0 && "only allowed for symbol bound to result");
assert(op == nullptr);
return formatv(fmt, formatv("{{{0}}", name));
auto repl = formatv(fmt, formatv("{{{0}}", name));
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return repl;
}
}
llvm_unreachable("unknown kind");
@ -402,15 +437,19 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
void tblgen::Pattern::collectSourcePatternBoundSymbols(
tblgen::SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
}
void tblgen::Pattern::collectResultPatternBoundSymbols(
tblgen::SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
auto pattern = getResultPattern(i);
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
}
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@ -496,6 +535,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
if (!treeName.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "found symbol bound to op result: " << treeName << '\n');
if (!infoMap.bindOpResult(treeName, op))
PrintFatalError(def.getLoc(),
formatv("symbol '{0}' bound more than once", treeName));
@ -510,6 +551,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
if (!treeArgName.empty()) {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
PrintFatalError(def.getLoc(), err);

View File

@ -30,6 +30,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
@ -42,6 +43,8 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
namespace llvm {
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
@ -191,6 +194,9 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
// Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
<< op.getOperationName() << "' at depth " << depth
<< '\n');
int indent = 4 + 2 * depth;
os.indent(indent) << formatv(
@ -249,6 +255,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
PrintFatalError(loc, "unhandled case when matching op");
}
}
LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
<< op.getOperationName() << "' at depth " << depth
<< '\n');
}
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
@ -346,6 +355,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
}
void PatternEmitter::emitMatchLogic(DagNode tree) {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
emitOpMatch(tree, 0);
for (auto &appliedConstraint : pattern.getConstraints()) {
@ -383,13 +393,18 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
names[1], names[2], names[3]));
}
}
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
void PatternEmitter::collectOps(DagNode tree,
llvm::SmallPtrSetImpl<const Operator *> &ops) {
// Check if this tree is an operation.
if (tree.isOperation())
ops.insert(&tree.getDialectOp(opMap));
if (tree.isOperation()) {
const Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs()
<< "found operation " << op.getOperationName() << '\n');
ops.insert(&op);
}
// Recurse the arguments of the tree.
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
@ -406,8 +421,11 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Collect the set of result operations.
llvm::SmallPtrSet<const Operator *, 4> resultOps;
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
collectOps(pattern.getResultPattern(i), resultOps);
}
LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
@ -437,6 +455,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Register all symbols bound in the source pattern.
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
LLVM_DEBUG(
llvm::dbgs() << "start creating local variables for capturing matches\n");
os.indent(4) << "// Variables for capturing values and attributes used for "
"creating ops\n";
// Create local variables for storing the arguments and results bound
@ -450,6 +470,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
// reused for fused loc.
os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
pattern.getSourcePattern().getNumOps());
LLVM_DEBUG(
llvm::dbgs() << "done creating local variables for capturing matches\n");
os.indent(4) << "// Match\n";
os.indent(4) << "tblgen_ops[0] = op0;\n";
@ -465,6 +487,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
}
void PatternEmitter::emitRewriteLogic() {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
const Operator &rootOp = pattern.getSourceRootOp();
int numExpectedResults = rootOp.getNumResults();
int numResultPatterns = pattern.getNumResultPatterns();
@ -525,6 +548,7 @@ void PatternEmitter::emitRewriteLogic() {
}
os.indent(4) << "\n";
os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
@ -533,6 +557,10 @@ std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
int resultIndex, int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
LLVM_DEBUG(resultTree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
if (resultTree.isNativeCodeCall()) {
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
symbolInfoMap.bindValue(symbol);
@ -585,17 +613,27 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
return handleConstantAttr(enumCase, val);
}
LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
<< "' (via symbol ref)\n");
return argName;
}
if (leaf.isNativeCodeCall()) {
return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
<< "' (via NativeCodeCall)\n");
return repl;
}
PrintFatalError(loc, "unhandled case when rewriting op");
}
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
auto fmt = tree.getNativeCodeTemplate();
// TODO(b/138794486): replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
@ -605,6 +643,8 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
}
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i
<< " replacement: " << attrs[i] << "\n");
}
return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
attrs[5], attrs[6], attrs[7]);
@ -830,7 +870,11 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
} else {
name = p->getName();
}
LLVM_DEBUG(llvm::dbgs()
<< "=== start generating pattern '" << name << "' ===\n");
PatternEmitter(p, &recordOpMap, os).emit(name);
LLVM_DEBUG(llvm::dbgs()
<< "=== done generating pattern '" << name << "' ===\n");
rewriterNames.push_back(std::move(name));
}