diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index c701f0787507..eaaf5b75230e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -791,16 +791,12 @@ class Attr : // instantiation. // TOOD(b/132458159): deduplicate the fields in attribute wrapper classes. Attr baseAttr = ?; - - // The fully-qualified C++ namespace where the generated class lives. - string cppNamespace = ""; } // An attribute of a specific dialect. class DialectAttr : Attr { Dialect dialect = d; - let cppNamespace = d.cppNamespace; } //===----------------------------------------------------------------------===// @@ -1119,6 +1115,16 @@ class EnumAttrInfo cases> { // underlying type is not explicitly specified. string underlyingType = ""; + // The C++ namespaces that the enum class definition and utility functions + // should be placed into. + // + // Normally you want to place the full namespace path here. If it is nested, + // use "::" as the delimiter, e.g., given "A::B", generated code will be + // placed in `namespace A { namespace B { ... } }`. To avoid placing in any + // namespace, use "". + // TODO: use dialect to provide the namespace. + string cppNamespace = ""; + // The name of the utility function that converts a value of the underlying // type to the corresponding symbol. It will have the following signature: // @@ -1457,8 +1463,7 @@ class StructFieldAttr { // useful when representing data that would normally be in a structure. class StructAttr attributes> : - DictionaryAttrBase()">, + DictionaryAttrBase()">, "DictionaryAttr with field(s): " # StrJoin.result # " (each field having its own constraints)"> { @@ -1466,16 +1471,14 @@ class StructAttr fields = attributes; } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index cfa3da2c417a..b32775318df8 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -219,11 +219,11 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { case Kind::Operand: { // Use operand range for captured operands (to support potential variadic // operands). - return std::string(formatv( - "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name)); + return std::string( + formatv("Operation::operand_range {0}(op0->getOperands());\n", name)); } case Kind::Value: { - return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name)); + return std::string(formatv("ArrayRef {0};\n", name)); } case Kind::Result: { // Use the op itself for captured results. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 3bcf02114555..5345bc598da9 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -487,7 +487,16 @@ void OpEmitter::genAttrGetters() { // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + Dialect attrDialect = attr.getDialect(); + // Does the current operation have a different namespace than the attribute? + bool differentNamespace = + attrDialect && opDialect && attrDialect != opDialect; + std::string returnType = differentNamespace + ? (llvm::Twine(attrDialect.getCppNamespace()) + + "::" + attr.getReturnType()) + .str() + : attr.getReturnType().str(); + auto *method = opClass.addMethodAndPrune(returnType, name); auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue()) { @@ -1991,8 +2000,8 @@ void OpEmitter::genOpAsmInterface() { opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); // Generate the right accessor for the number of results. - auto *method = opClass.addMethodAndPrune( - "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn"); + auto *method = opClass.addMethodAndPrune("void", "getAsmResultNames", + "OpAsmSetValueNameFn", "setNameFn"); auto &body = method->body(); for (int i = 0; i != numResults; ++i) { body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 495fbe1715e0..e16900227759 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -221,8 +221,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { int indent = 4 + 2 * depth; os.indent(indent) << formatv( - "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " - "(void)castedOp{0};\n", + "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", depth, op.getQualCppClassName()); // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { @@ -536,7 +535,7 @@ void PatternEmitter::emit(StringRef rewriteName) { os << "\n// Rewrite\n"; emitRewriteLogic(); - os << "return ::mlir::success();\n"; + os << "return success();\n"; } os << "};\n"; } @@ -1146,8 +1145,8 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { } // Emit function to add the generated matchers to the pattern list. - os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext " - "*context, ::mlir::OwningRewritePatternList *patterns) {\n"; + os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(MLIRContext " + "*context, OwningRewritePatternList *patterns) {\n"; for (const auto &name : rewriterNames) { os << " patterns->insert<" << name << ">(context);\n"; }