diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 1b75712e0831..f5eb9a37ef04 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -220,6 +220,12 @@ public: // Returns the benefit score of the pattern. int getBenefit() const; + using IdentifierLine = std::pair; + + // Returns the file location of the pattern (buffer identifier + line number + // pair). + std::vector getLocation() const; + private: // Recursively collects all bound arguments inside the DAG tree rooted // at `tree`. diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 285f1c9bd535..31bab8172e3b 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -229,6 +229,20 @@ int tblgen::Pattern::getBenefit() const { return initBenefit + dyn_cast(delta->getArg(0))->getValue(); } +std::vector +tblgen::Pattern::getLocation() const { + std::vector> result; + result.reserve(def.getLoc().size()); + for (auto loc : def.getLoc()) { + unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); + assert(buf && "invalid source location"); + result.emplace_back( + llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), + llvm::SrcMgr.getLineAndColumn(loc, buf).first); + } + return result; +} + void tblgen::Pattern::collectBoundArguments(DagNode tree) { auto &op = getDialectOp(tree); auto numOpArgs = op.getNumArgs(); diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td index bb5055af56f2..b5a6c60731c9 100644 --- a/mlir/test/mlir-tblgen/pattern.td +++ b/mlir/test/mlir-tblgen/pattern.td @@ -23,6 +23,8 @@ def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>; // Test rewrite rule naming // --- +// CHECK: Generated from: +// CHECK-NEXT: {{.*pattern.td.*}} // CHECK: struct MyRule : public RewritePattern def : Pat<(OpA $input, $attr), (OpB $input, $attr)>; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 57068e2bff25..b00be1c8c95c 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatAdapters.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -41,6 +42,15 @@ using namespace llvm; using namespace mlir; using namespace mlir::tblgen; +namespace llvm { +template <> struct format_provider { + static void format(const mlir::tblgen::Pattern::IdentifierLine &v, + raw_ostream &os, StringRef style) { + os << v.first << ":" << v.second; + } +}; +} // end namespace llvm + // Returns the bound symbol for the given op argument or op named `symbol`. // // Arguments and ops bound in the source pattern are grouped into a @@ -445,6 +455,9 @@ void PatternEmitter::emit(StringRef rewriteName) { loc, "replacing op with variadic results not supported right now"); // Emit RewritePattern for Pattern. + auto locs = pattern.getLocation(); + os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", + make_range(locs.rbegin(), locs.rend())); os << formatv(R"(struct {0} : public RewritePattern { {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", rewriteName, rootName, pattern.getBenefit())