forked from OSchip/llvm-project
[mlir] Add asserts when changing various MLIRContext configurations
This helps to prevent tsan failures when users inadvertantly mutate the context in a non-safe way. Differential Revision: https://reviews.llvm.org/D112021
This commit is contained in:
parent
9d9eddd3dd
commit
0f304ef017
|
@ -212,6 +212,10 @@ public:
|
|||
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
|
||||
}
|
||||
|
||||
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
|
||||
/// contains all of the components of this registry.
|
||||
bool isSubsetOf(const DialectRegistry &rhs) const;
|
||||
|
||||
private:
|
||||
MapTy registry;
|
||||
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
|
||||
|
|
|
@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
|
|||
for (const auto &extension : extensions)
|
||||
applyExtension(*extension);
|
||||
}
|
||||
|
||||
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
|
||||
// Treat any extensions conservatively.
|
||||
if (!extensions.empty())
|
||||
return false;
|
||||
// Check that the current dialects fully overlap with the dialects in 'rhs'.
|
||||
return llvm::all_of(
|
||||
registry, [&](const auto &it) { return rhs.registry.count(it.first); });
|
||||
}
|
||||
|
|
|
@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) {
|
||||
if (registry.isSubsetOf(impl->dialectsRegistry))
|
||||
return;
|
||||
|
||||
assert(impl->multiThreadedExecutionContext == 0 &&
|
||||
"appending to the MLIRContext dialect registry while in a "
|
||||
"multi-threaded execution context");
|
||||
registry.appendTo(impl->dialectsRegistry);
|
||||
|
||||
// For the already loaded dialects, apply any possible extensions immediately.
|
||||
|
@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() {
|
|||
}
|
||||
|
||||
void MLIRContext::allowUnregisteredDialects(bool allowing) {
|
||||
assert(impl->multiThreadedExecutionContext == 0 &&
|
||||
"changing MLIRContext `allow-unregistered-dialects` configuration "
|
||||
"while in a multi-threaded execution context");
|
||||
impl->allowUnregisteredDialects = allowing;
|
||||
}
|
||||
|
||||
|
@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) {
|
|||
// --mlir-disable-threading
|
||||
if (isThreadingGloballyDisabled())
|
||||
return;
|
||||
assert(impl->multiThreadedExecutionContext == 0 &&
|
||||
"changing MLIRContext `disable-threading` configuration while "
|
||||
"in a multi-threaded execution context");
|
||||
|
||||
impl->threadingIsEnabled = !disable;
|
||||
|
||||
|
@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() {
|
|||
/// Set the flag specifying if we should attach the operation to diagnostics
|
||||
/// emitted via Operation::emit.
|
||||
void MLIRContext::printOpOnDiagnostic(bool enable) {
|
||||
assert(impl->multiThreadedExecutionContext == 0 &&
|
||||
"changing MLIRContext `print-op-on-diagnostic` configuration while in "
|
||||
"a multi-threaded execution context");
|
||||
impl->printOpOnDiagnostic = enable;
|
||||
}
|
||||
|
||||
|
@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
|
|||
/// Set the flag specifying if we should attach the current stacktrace when
|
||||
/// emitting diagnostics.
|
||||
void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
|
||||
assert(impl->multiThreadedExecutionContext == 0 &&
|
||||
"changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
|
||||
"while in a multi-threaded execution context");
|
||||
impl->printStackTraceOnDiagnostic = enable;
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() {
|
|||
ModuleOp module = this->getOperation();
|
||||
ModuleOp moduleVariant = module.clone();
|
||||
|
||||
PassManager passManager(module.getContext());
|
||||
OpPassManager passManager("builtin.module");
|
||||
if (failed(parsePassPipeline(optPass, passManager))) {
|
||||
module.emitError() << "\nfailed to parse pass pipeline";
|
||||
return signalPassFailure();
|
||||
|
@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() {
|
|||
return signalPassFailure();
|
||||
}
|
||||
|
||||
if (failed(passManager.run(moduleVariant))) {
|
||||
// Temporarily push the variant under the main module and execute the pipeline
|
||||
// on it.
|
||||
module.getBody()->push_back(moduleVariant);
|
||||
LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
|
||||
moduleVariant->remove();
|
||||
|
||||
if (failed(pipelineResult)) {
|
||||
module.emitError() << "\nfailed to run pass pipeline";
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion
|
|||
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
|
||||
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
OpPassManager pm(FuncOp::getOperationName());
|
||||
pm.addPass(createLoopInvariantCodeMotionPass());
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
do {
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
|
||||
PassManager pm(context);
|
||||
pm.addPass(createLoopInvariantCodeMotionPass());
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
|
||||
if (failed(res))
|
||||
if (failed(runPipeline(pm, getOperation())))
|
||||
this->signalPassFailure();
|
||||
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue