diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 66db2d403d35..fac9194ce6e4 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -49,6 +49,9 @@ public: // Returns the class name of the op. StringRef cppClassName(); + // Returns the class name of the op with namespace added. + std::string qualifiedCppClassName(); + // Operations attribute accessors. struct Attribute { llvm::StringInit *name; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 8e742a6721a0..c4021d3cae0d 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -44,6 +44,9 @@ StringRef Operator::getOperationName() const { } StringRef Operator::cppClassName() { return getSplitDefName().back(); } +std::string Operator::qualifiedCppClassName() { + return llvm::join(getSplitDefName(), "::"); +} StringRef Operator::getArgName(int index) const { DagInit *argumentValues = def.getValueAsDag("arguments"); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index e4ba02b41ae5..6a5b353d0f49 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -137,14 +137,9 @@ static void matchOp(DagInit *tree, int depth, raw_ostream &os) { if (depth != 0) { // Skip if there is no defining instruction (e.g., arguments to function). os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); - // TODO(jpienaar): This is bad, we should not be checking strings here, we - // should be matching using mOp (and helpers). Currently doing this to allow - // for TF ops that aren't registed. Fix it. os.indent(indent) << formatv( - "if (op{0}->getName().getStringRef() != \"{1}\")", - depth, op.getOperationName()) - << "\n"; - os.indent(indent + 2) << "return matchFailure();\n"; + "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, + op.qualifiedCppClassName()); } for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { auto arg = tree->getArg(i);