diff --git a/mlir/include/mlir/Support/TypeID.h b/mlir/include/mlir/Support/TypeID.h index 2e4d8b9e8377..d717c77ee3a4 100644 --- a/mlir/include/mlir/Support/TypeID.h +++ b/mlir/include/mlir/Support/TypeID.h @@ -137,6 +137,25 @@ TypeID TypeID::get() { } // end namespace mlir +// Declare/define an explicit specialization for TypeID: this forces the +// compiler to emit a strong definition for a class and controls which +// translation unit and shared object will actually have it. +// This can be useful to turn to a link-time failure what would be in other +// circumstances a hard-to-catch runtime bug when a TypeID is hidden in two +// different shared libraries and instances of the same class only gets the same +// TypeID inside a given DSO. +#define DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME) \ + template <> \ + LLVM_EXTERNAL_VISIBILITY mlir::TypeID \ + mlir::detail::TypeIDExported::get(); +#define DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME) \ + template <> \ + LLVM_EXTERNAL_VISIBILITY mlir::TypeID \ + mlir::detail::TypeIDExported::get() { \ + static mlir::TypeID::Storage instance; \ + return mlir::TypeID(&instance); \ + } + namespace llvm { template <> struct DenseMapInfo { static mlir::TypeID getEmptyKey() { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index c8910eab69ec..5b1b80321237 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -440,16 +440,24 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) { collectAllDefs(selectedDialect, defRecords, defs); if (defs.empty()) return false; + { + NamespaceEmitter nsEmitter(os, defs.front().getDialect()); - NamespaceEmitter nsEmitter(os, defs.front().getDialect()); + // Declare all the def classes first (in case they reference each other). + for (const AttrOrTypeDef &def : defs) + os << " class " << def.getCppClassName() << ";\n"; - // Declare all the def classes first (in case they reference each other). + // Emit the declarations. + for (const AttrOrTypeDef &def : defs) + emitDefDecl(def); + } + // Emit the TypeID explicit specializations to have a single definition for + // each of these. for (const AttrOrTypeDef &def : defs) - os << " class " << def.getCppClassName() << ";\n"; + if (!def.getDialect().getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() + << "::" << def.getCppClassName() << ")\n"; - // Emit the declarations. - for (const AttrOrTypeDef &def : defs) - emitDefDecl(def); return false; } @@ -934,8 +942,13 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); emitParsePrintDispatch(defs); - for (const AttrOrTypeDef &def : defs) + for (const AttrOrTypeDef &def : defs) { emitDefDef(def); + // Emit the TypeID explicit specializations to have a single symbol def. + if (!def.getDialect().getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() + << "::" << def.getCppClassName() << ")\n"; + } return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 2ebabc5dd171..2e5b98380538 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -198,38 +198,44 @@ static void emitDialectDecl(Dialect &dialect, } // Emit all nested namespaces. - NamespaceEmitter nsEmitter(os, dialect); + { + NamespaceEmitter nsEmitter(os, dialect); - // Emit the start of the decl. - std::string cppName = dialect.getCppClassName(); - os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), - dependentDialectRegistrations); + // Emit the start of the decl. + std::string cppName = dialect.getCppClassName(); + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), + dependentDialectRegistrations); - // Check for any attributes/types registered to this dialect. If there are, - // add the hooks for parsing/printing. - if (!dialectAttrs.empty()) - os << attrParserDecl; - if (!dialectTypes.empty()) - os << typeParserDecl; + // Check for any attributes/types registered to this dialect. If there are, + // add the hooks for parsing/printing. + if (!dialectAttrs.empty()) + os << attrParserDecl; + if (!dialectTypes.empty()) + os << typeParserDecl; - // Add the decls for the various features of the dialect. - if (dialect.hasCanonicalizer()) - os << canonicalizerDecl; - if (dialect.hasConstantMaterializer()) - os << constantMaterializerDecl; - if (dialect.hasOperationAttrVerify()) - os << opAttrVerifierDecl; - if (dialect.hasRegionArgAttrVerify()) - os << regionArgAttrVerifierDecl; - if (dialect.hasRegionResultAttrVerify()) - os << regionResultAttrVerifierDecl; - if (dialect.hasOperationInterfaceFallback()) - os << operationInterfaceFallbackDecl; - if (llvm::Optional extraDecl = dialect.getExtraClassDeclaration()) - os << *extraDecl; + // Add the decls for the various features of the dialect. + if (dialect.hasCanonicalizer()) + os << canonicalizerDecl; + if (dialect.hasConstantMaterializer()) + os << constantMaterializerDecl; + if (dialect.hasOperationAttrVerify()) + os << opAttrVerifierDecl; + if (dialect.hasRegionArgAttrVerify()) + os << regionArgAttrVerifierDecl; + if (dialect.hasRegionResultAttrVerify()) + os << regionResultAttrVerifierDecl; + if (dialect.hasOperationInterfaceFallback()) + os << operationInterfaceFallbackDecl; + if (llvm::Optional extraDecl = + dialect.getExtraClassDeclaration()) + os << *extraDecl; - // End the dialect decl. - os << "};\n"; + // End the dialect decl. + os << "};\n"; + } + if (!dialect.getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() + << "::" << dialect.getCppClassName() << ")\n"; } static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, @@ -263,6 +269,11 @@ static const char *const dialectDestructorStr = R"( )"; static void emitDialectDef(Dialect &dialect, raw_ostream &os) { + // Emit the TypeID explicit specializations to have a single symbol def. + if (!dialect.getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() + << "::" << dialect.getCppClassName() << ")\n"; + // Emit all nested namespaces. NamespaceEmitter nsEmitter(os, dialect); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 2bc9ea465f5e..1c630e03d3e3 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -650,7 +650,6 @@ OpEmitter::OpEmitter(const Operator &op, generateOpFormat(op, opClass); genSideEffectInterfaceMethods(); } - void OpEmitter::emitDecl( const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { @@ -2576,15 +2575,29 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, emitDecl); for (auto *def : defs) { Operator op(*def); - NamespaceEmitter emitter(os, op.getCppNamespace()); if (emitDecl) { - os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - OpOperandAdaptorEmitter::emitDecl(op, os); - OpEmitter::emitDecl(op, os, staticVerifierEmitter); + { + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << formatv(opCommentHeader, op.getQualCppClassName(), + "declarations"); + OpOperandAdaptorEmitter::emitDecl(op, os); + OpEmitter::emitDecl(op, os, staticVerifierEmitter); + } + // Emit the TypeID explicit specialization to have a single definition. + if (!op.getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() + << "::" << op.getCppClassName() << ")\n\n"; } else { - os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - OpOperandAdaptorEmitter::emitDef(op, os); - OpEmitter::emitDef(op, os, staticVerifierEmitter); + { + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); + OpOperandAdaptorEmitter::emitDef(op, os); + OpEmitter::emitDef(op, os, staticVerifierEmitter); + } + // Emit the TypeID explicit specialization to have a single definition. + if (!op.getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() + << "::" << op.getCppClassName() << ")\n\n"; } } }