forked from OSchip/llvm-project
[mlir] Fix invalidated reference when loading dependent dialects
When a dialect is loaded with `getOrLoadDialect`, its constructor may recurse and call `getOrLoadDialect` on a dependent dialect, which may result in an insertion in the dialect map, invalidating the reference to the (previously null) dialect pointer. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D115846
This commit is contained in:
parent
d08a801b5f
commit
ff459c1f67
|
@ -409,9 +409,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||||
function_ref<std::unique_ptr<Dialect>()> ctor) {
|
function_ref<std::unique_ptr<Dialect>()> ctor) {
|
||||||
auto &impl = getImpl();
|
auto &impl = getImpl();
|
||||||
// Get the correct insertion position sorted by namespace.
|
// Get the correct insertion position sorted by namespace.
|
||||||
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
|
auto dialectIt = impl.loadedDialects.find(dialectNamespace);
|
||||||
|
|
||||||
if (!dialect) {
|
if (dialectIt == impl.loadedDialects.end()) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
<< "Load new dialect in Context " << dialectNamespace << "\n");
|
<< "Load new dialect in Context " << dialectNamespace << "\n");
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
|
@ -422,7 +422,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||||
"the PassManager): this can indicate a "
|
"the PassManager): this can indicate a "
|
||||||
"missing `dependentDialects` in a pass for example.");
|
"missing `dependentDialects` in a pass for example.");
|
||||||
#endif
|
#endif
|
||||||
dialect = ctor();
|
std::unique_ptr<Dialect> &dialect =
|
||||||
|
impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
|
||||||
assert(dialect && "dialect ctor failed");
|
assert(dialect && "dialect ctor failed");
|
||||||
|
|
||||||
// Refresh all the identifiers dialect field, this catches cases where a
|
// Refresh all the identifiers dialect field, this catches cases where a
|
||||||
|
@ -441,6 +442,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Abort if dialect with namespace has already been registered.
|
// Abort if dialect with namespace has already been registered.
|
||||||
|
std::unique_ptr<Dialect> &dialect = dialectIt->second;
|
||||||
if (dialect->getTypeID() != dialectID)
|
if (dialect->getTypeID() != dialectID)
|
||||||
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
|
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
|
||||||
"' has already been registered");
|
"' has already been registered");
|
||||||
|
|
Loading…
Reference in New Issue