forked from OSchip/llvm-project
Add pattern file location to generated code to trace origin of pattern.
-- PiperOrigin-RevId: 249734666
This commit is contained in:
parent
3ccbc0bcec
commit
4165885a90
|
@ -220,6 +220,12 @@ public:
|
|||
// Returns the benefit score of the pattern.
|
||||
int getBenefit() const;
|
||||
|
||||
using IdentifierLine = std::pair<StringRef, unsigned>;
|
||||
|
||||
// Returns the file location of the pattern (buffer identifier + line number
|
||||
// pair).
|
||||
std::vector<IdentifierLine> getLocation() const;
|
||||
|
||||
private:
|
||||
// Recursively collects all bound arguments inside the DAG tree rooted
|
||||
// at `tree`.
|
||||
|
|
|
@ -229,6 +229,20 @@ int tblgen::Pattern::getBenefit() const {
|
|||
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||
}
|
||||
|
||||
std::vector<tblgen::Pattern::IdentifierLine>
|
||||
tblgen::Pattern::getLocation() const {
|
||||
std::vector<std::pair<StringRef, unsigned>> result;
|
||||
result.reserve(def.getLoc().size());
|
||||
for (auto loc : def.getLoc()) {
|
||||
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
|
||||
assert(buf && "invalid source location");
|
||||
result.emplace_back(
|
||||
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
|
||||
llvm::SrcMgr.getLineAndColumn(loc, buf).first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void tblgen::Pattern::collectBoundArguments(DagNode tree) {
|
||||
auto &op = getDialectOp(tree);
|
||||
auto numOpArgs = op.getNumArgs();
|
||||
|
|
|
@ -23,6 +23,8 @@ def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
|
|||
// Test rewrite rule naming
|
||||
// ---
|
||||
|
||||
// CHECK: Generated from:
|
||||
// CHECK-NEXT: {{.*pattern.td.*}}
|
||||
// CHECK: struct MyRule : public RewritePattern
|
||||
|
||||
def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatAdapters.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
@ -41,6 +42,15 @@ using namespace llvm;
|
|||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
namespace llvm {
|
||||
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
|
||||
static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
|
||||
raw_ostream &os, StringRef style) {
|
||||
os << v.first << ":" << v.second;
|
||||
}
|
||||
};
|
||||
} // end namespace llvm
|
||||
|
||||
// Returns the bound symbol for the given op argument or op named `symbol`.
|
||||
//
|
||||
// Arguments and ops bound in the source pattern are grouped into a
|
||||
|
@ -445,6 +455,9 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
loc, "replacing op with variadic results not supported right now");
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
auto locs = pattern.getLocation();
|
||||
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
|
||||
make_range(locs.rbegin(), locs.rend()));
|
||||
os << formatv(R"(struct {0} : public RewritePattern {
|
||||
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
|
||||
rewriteName, rootName, pattern.getBenefit())
|
||||
|
|
Loading…
Reference in New Issue