[mlir] Fix invalidated reference when loading dependent dialects

When a dialect is loaded with `getOrLoadDialect`, its constructor may recurse and call `getOrLoadDialect` on a dependent dialect, which may result in an insertion in the dialect map, invalidating the reference to the (previously null) dialect pointer.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D115846
This commit is contained in:
Mogball 2021-12-16 01:19:56 +00:00
parent d08a801b5f
commit ff459c1f67
1 changed files with 5 additions and 3 deletions

View File

@ -409,9 +409,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor) { function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl(); auto &impl = getImpl();
// Get the correct insertion position sorted by namespace. // Get the correct insertion position sorted by namespace.
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; auto dialectIt = impl.loadedDialects.find(dialectNamespace);
if (!dialect) { if (dialectIt == impl.loadedDialects.end()) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< "Load new dialect in Context " << dialectNamespace << "\n"); << "Load new dialect in Context " << dialectNamespace << "\n");
#ifndef NDEBUG #ifndef NDEBUG
@ -422,7 +422,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
"the PassManager): this can indicate a " "the PassManager): this can indicate a "
"missing `dependentDialects` in a pass for example."); "missing `dependentDialects` in a pass for example.");
#endif #endif
dialect = ctor(); std::unique_ptr<Dialect> &dialect =
impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
assert(dialect && "dialect ctor failed"); assert(dialect && "dialect ctor failed");
// Refresh all the identifiers dialect field, this catches cases where a // Refresh all the identifiers dialect field, this catches cases where a
@ -441,6 +442,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
} }
// Abort if dialect with namespace has already been registered. // Abort if dialect with namespace has already been registered.
std::unique_ptr<Dialect> &dialect = dialectIt->second;
if (dialect->getTypeID() != dialectID) if (dialect->getTypeID() != dialectID)
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered"); "' has already been registered");