Add pattern file location to generated code to trace origin of pattern.

--

PiperOrigin-RevId: 249734666
This commit is contained in:
Jacques Pienaar 2019-05-23 16:08:52 -07:00 committed by Mehdi Amini
parent 3ccbc0bcec
commit 4165885a90
4 changed files with 35 additions and 0 deletions

View File

@ -220,6 +220,12 @@ public:
// Returns the benefit score of the pattern.
int getBenefit() const;
using IdentifierLine = std::pair<StringRef, unsigned>;
// Returns the file location of the pattern (buffer identifier + line number
// pair).
std::vector<IdentifierLine> getLocation() const;
private:
// Recursively collects all bound arguments inside the DAG tree rooted
// at `tree`.

View File

@ -229,6 +229,20 @@ int tblgen::Pattern::getBenefit() const {
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
}
std::vector<tblgen::Pattern::IdentifierLine>
tblgen::Pattern::getLocation() const {
std::vector<std::pair<StringRef, unsigned>> result;
result.reserve(def.getLoc().size());
for (auto loc : def.getLoc()) {
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
assert(buf && "invalid source location");
result.emplace_back(
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
llvm::SrcMgr.getLineAndColumn(loc, buf).first);
}
return result;
}
void tblgen::Pattern::collectBoundArguments(DagNode tree) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();

View File

@ -23,6 +23,8 @@ def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
// Test rewrite rule naming
// ---
// CHECK: Generated from:
// CHECK-NEXT: {{.*pattern.td.*}}
// CHECK: struct MyRule : public RewritePattern
def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;

View File

@ -30,6 +30,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
@ -41,6 +42,15 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
namespace llvm {
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
raw_ostream &os, StringRef style) {
os << v.first << ":" << v.second;
}
};
} // end namespace llvm
// Returns the bound symbol for the given op argument or op named `symbol`.
//
// Arguments and ops bound in the source pattern are grouped into a
@ -445,6 +455,9 @@ void PatternEmitter::emit(StringRef rewriteName) {
loc, "replacing op with variadic results not supported right now");
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
rewriteName, rootName, pattern.getBenefit())