Emit strong definition for TypeID storage in Op/Type/Attributes definition

By making an explicit template specialization for the TypeID provided by these classes,
the compiler will not emit an inline weak definition and rely on the linker to unique it.
Instead a single definition will be emitted in the C++ file alongside the implementation
for these classes. That will turn into a linker error what is now a hard-to-debug runtime
behavior where instances of the same class may be using a different TypeID inside of
different DSOs.

Differential Revision: https://reviews.llvm.org/D105903
This commit is contained in:
Mehdi Amini 2021-07-28 05:22:45 +00:00
parent 6cba96332b
commit 660a56956c
4 changed files with 99 additions and 43 deletions

View File

@ -137,6 +137,25 @@ TypeID TypeID::get() {
} // end namespace mlir
// Declare/define an explicit specialization for TypeID: this forces the
// compiler to emit a strong definition for a class and controls which
// translation unit and shared object will actually have it.
// This can be useful to turn to a link-time failure what would be in other
// circumstances a hard-to-catch runtime bug when a TypeID is hidden in two
// different shared libraries and instances of the same class only gets the same
// TypeID inside a given DSO.
#define DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME) \
template <> \
LLVM_EXTERNAL_VISIBILITY mlir::TypeID \
mlir::detail::TypeIDExported::get<CLASS_NAME>();
#define DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME) \
template <> \
LLVM_EXTERNAL_VISIBILITY mlir::TypeID \
mlir::detail::TypeIDExported::get<CLASS_NAME>() { \
static mlir::TypeID::Storage instance; \
return mlir::TypeID(&instance); \
}
namespace llvm {
template <> struct DenseMapInfo<mlir::TypeID> {
static mlir::TypeID getEmptyKey() {

View File

@ -440,16 +440,24 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
collectAllDefs(selectedDialect, defRecords, defs);
if (defs.empty())
return false;
{
NamespaceEmitter nsEmitter(os, defs.front().getDialect());
NamespaceEmitter nsEmitter(os, defs.front().getDialect());
// Declare all the def classes first (in case they reference each other).
for (const AttrOrTypeDef &def : defs)
os << " class " << def.getCppClassName() << ";\n";
// Declare all the def classes first (in case they reference each other).
// Emit the declarations.
for (const AttrOrTypeDef &def : defs)
emitDefDecl(def);
}
// Emit the TypeID explicit specializations to have a single definition for
// each of these.
for (const AttrOrTypeDef &def : defs)
os << " class " << def.getCppClassName() << ";\n";
if (!def.getDialect().getCppNamespace().empty())
os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
<< "::" << def.getCppClassName() << ")\n";
// Emit the declarations.
for (const AttrOrTypeDef &def : defs)
emitDefDecl(def);
return false;
}
@ -934,8 +942,13 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
emitParsePrintDispatch(defs);
for (const AttrOrTypeDef &def : defs)
for (const AttrOrTypeDef &def : defs) {
emitDefDef(def);
// Emit the TypeID explicit specializations to have a single symbol def.
if (!def.getDialect().getCppNamespace().empty())
os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
<< "::" << def.getCppClassName() << ")\n";
}
return false;
}

View File

@ -198,38 +198,44 @@ static void emitDialectDecl(Dialect &dialect,
}
// Emit all nested namespaces.
NamespaceEmitter nsEmitter(os, dialect);
{
NamespaceEmitter nsEmitter(os, dialect);
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
dependentDialectRegistrations);
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
dependentDialectRegistrations);
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.
if (!dialectAttrs.empty())
os << attrParserDecl;
if (!dialectTypes.empty())
os << typeParserDecl;
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.
if (!dialectAttrs.empty())
os << attrParserDecl;
if (!dialectTypes.empty())
os << typeParserDecl;
// Add the decls for the various features of the dialect.
if (dialect.hasCanonicalizer())
os << canonicalizerDecl;
if (dialect.hasConstantMaterializer())
os << constantMaterializerDecl;
if (dialect.hasOperationAttrVerify())
os << opAttrVerifierDecl;
if (dialect.hasRegionArgAttrVerify())
os << regionArgAttrVerifierDecl;
if (dialect.hasRegionResultAttrVerify())
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
// Add the decls for the various features of the dialect.
if (dialect.hasCanonicalizer())
os << canonicalizerDecl;
if (dialect.hasConstantMaterializer())
os << constantMaterializerDecl;
if (dialect.hasOperationAttrVerify())
os << opAttrVerifierDecl;
if (dialect.hasRegionArgAttrVerify())
os << regionArgAttrVerifierDecl;
if (dialect.hasRegionResultAttrVerify())
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
if (llvm::Optional<StringRef> extraDecl =
dialect.getExtraClassDeclaration())
os << *extraDecl;
// End the dialect decl.
os << "};\n";
// End the dialect decl.
os << "};\n";
}
if (!dialect.getCppNamespace().empty())
os << "DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
<< "::" << dialect.getCppClassName() << ")\n";
}
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
@ -263,6 +269,11 @@ static const char *const dialectDestructorStr = R"(
)";
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the TypeID explicit specializations to have a single symbol def.
if (!dialect.getCppNamespace().empty())
os << "DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
<< "::" << dialect.getCppClassName() << ")\n";
// Emit all nested namespaces.
NamespaceEmitter nsEmitter(os, dialect);

View File

@ -650,7 +650,6 @@ OpEmitter::OpEmitter(const Operator &op,
generateOpFormat(op, opClass);
genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(
const Operator &op, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
@ -2576,15 +2575,29 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
emitDecl);
for (auto *def : defs) {
Operator op(*def);
NamespaceEmitter emitter(os, op.getCppNamespace());
if (emitDecl) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
{
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(),
"declarations");
OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.
if (!op.getCppNamespace().empty())
os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
<< "::" << op.getCppClassName() << ")\n\n";
} else {
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter);
{
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.
if (!op.getCppNamespace().empty())
os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
<< "::" << op.getCppClassName() << ")\n\n";
}
}
}