[mlir] Add asserts when changing various MLIRContext configurations

This helps to prevent tsan failures when users inadvertantly mutate the
context in a non-safe way.

Differential Revision: https://reviews.llvm.org/D112021
This commit is contained in:
River Riddle 2022-01-18 16:16:54 -08:00
parent 9d9eddd3dd
commit 0f304ef017
5 changed files with 44 additions and 8 deletions

View File

@ -212,6 +212,10 @@ public:
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
}
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
/// contains all of the components of this registry.
bool isSubsetOf(const DialectRegistry &rhs) const;
private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;

View File

@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
for (const auto &extension : extensions)
applyExtension(*extension);
}
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
// Treat any extensions conservatively.
if (!extensions.empty())
return false;
// Check that the current dialects fully overlap with the dialects in 'rhs'.
return llvm::all_of(
registry, [&](const auto &it) { return rhs.registry.count(it.first); });
}

View File

@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
//===----------------------------------------------------------------------===//
void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
if (registry.isSubsetOf(impl->dialectsRegistry))
return;
assert(impl->multiThreadedExecutionContext == 0 &&
"appending to the MLIRContext dialect registry while in a "
"multi-threaded execution context");
registry.appendTo(impl->dialectsRegistry);
// For the already loaded dialects, apply any possible extensions immediately.
@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() {
}
void MLIRContext::allowUnregisteredDialects(bool allowing) {
assert(impl->multiThreadedExecutionContext == 0 &&
"changing MLIRContext `allow-unregistered-dialects` configuration "
"while in a multi-threaded execution context");
impl->allowUnregisteredDialects = allowing;
}
@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) {
// --mlir-disable-threading
if (isThreadingGloballyDisabled())
return;
assert(impl->multiThreadedExecutionContext == 0 &&
"changing MLIRContext `disable-threading` configuration while "
"in a multi-threaded execution context");
impl->threadingIsEnabled = !disable;
@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() {
/// Set the flag specifying if we should attach the operation to diagnostics
/// emitted via Operation::emit.
void MLIRContext::printOpOnDiagnostic(bool enable) {
assert(impl->multiThreadedExecutionContext == 0 &&
"changing MLIRContext `print-op-on-diagnostic` configuration while in "
"a multi-threaded execution context");
impl->printOpOnDiagnostic = enable;
}
@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
/// Set the flag specifying if we should attach the current stacktrace when
/// emitting diagnostics.
void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
assert(impl->multiThreadedExecutionContext == 0 &&
"changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
"while in a multi-threaded execution context");
impl->printStackTraceOnDiagnostic = enable;
}

View File

@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() {
ModuleOp module = this->getOperation();
ModuleOp moduleVariant = module.clone();
PassManager passManager(module.getContext());
OpPassManager passManager("builtin.module");
if (failed(parsePassPipeline(optPass, passManager))) {
module.emitError() << "\nfailed to parse pass pipeline";
return signalPassFailure();
@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() {
return signalPassFailure();
}
if (failed(passManager.run(moduleVariant))) {
// Temporarily push the variant under the main module and execute the pipeline
// on it.
module.getBody()->push_back(moduleVariant);
LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
moduleVariant->remove();
if (failed(pipelineResult)) {
module.emitError() << "\nfailed to run pass pipeline";
return signalPassFailure();
}

View File

@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
OpPassManager pm(FuncOp::getOperationName());
pm.addPass(createLoopInvariantCodeMotionPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
do {
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
PassManager pm(context);
pm.addPass(createLoopInvariantCodeMotionPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
if (failed(res))
if (failed(runPipeline(pm, getOperation())))
this->signalPassFailure();
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
}