Make the --mlir-disable-threading command line option overrides the C++ API usage

This seems in-line with the intent and how we build tools around it.
Update the description for the flag accordingly.
Also use an injected thread pool in MLIROptMain, now we will create
threads up-front and reuse them across split buffers.

Differential Revision: https://reviews.llvm.org/D109802
This commit is contained in:
Mehdi Amini 2021-09-15 01:38:38 +00:00
parent 500d4c45ba
commit a32300a68f
3 changed files with 41 additions and 14 deletions

View File

@ -129,6 +129,8 @@ public:
bool isMultithreadingEnabled(); bool isMultithreadingEnabled();
/// Set the flag specifying if multi-threading is disabled by the context. /// Set the flag specifying if multi-threading is disabled by the context.
/// The command line debugging flag `--mlir-disable-threading` is overriding
/// this call and making it a no-op!
void disableMultithreading(bool disable = true); void disableMultithreading(bool disable = true);
void enableMultithreading(bool enable = true) { void enableMultithreading(bool enable = true) {
disableMultithreading(!enable); disableMultithreading(!enable);
@ -140,6 +142,9 @@ public:
/// decoupling the lifetime of the threads from the contexts. The thread pool /// decoupling the lifetime of the threads from the contexts. The thread pool
/// must outlive the context. Multi-threading will be enabled as part of this /// must outlive the context. Multi-threading will be enabled as part of this
/// method. /// method.
/// The command line debugging flag `--mlir-disable-threading` will still
/// prevent threading from being enabled and threading won't be enabled after
/// this call in this case.
void setThreadPool(llvm::ThreadPool &pool); void setThreadPool(llvm::ThreadPool &pool);
/// Return the thread pool used by this context. This method requires that /// Return the thread pool used by this context. This method requires that

View File

@ -57,7 +57,8 @@ namespace {
struct MLIRContextOptions { struct MLIRContextOptions {
llvm::cl::opt<bool> disableThreading{ llvm::cl::opt<bool> disableThreading{
"mlir-disable-threading", "mlir-disable-threading",
llvm::cl::desc("Disabling multi-threading within MLIR")}; llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
"further call to MLIRContext::enableMultiThreading()")};
llvm::cl::opt<bool> printOpOnDiagnostic{ llvm::cl::opt<bool> printOpOnDiagnostic{
"mlir-print-op-on-diagnostic", "mlir-print-op-on-diagnostic",
@ -74,6 +75,14 @@ struct MLIRContextOptions {
static llvm::ManagedStatic<MLIRContextOptions> clOptions; static llvm::ManagedStatic<MLIRContextOptions> clOptions;
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS != 0
return clOptions.isConstructed() && clOptions->disableThreading;
#else
return true;
#endif
}
/// Register a set of useful command-line options that can be used to configure /// Register a set of useful command-line options that can be used to configure
/// various flags within the MLIRContext. These flags are used when constructing /// various flags within the MLIRContext. These flags are used when constructing
/// an MLIR context for initialization. /// an MLIR context for initialization.
@ -362,10 +371,10 @@ MLIRContext::MLIRContext(Threading setting)
: MLIRContext(DialectRegistry(), setting) {} : MLIRContext(DialectRegistry(), setting) {}
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting) MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
: impl(new MLIRContextImpl(setting == Threading::ENABLED)) { : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
!isThreadingGloballyDisabled())) {
// Initialize values based on the command line flags if they were provided. // Initialize values based on the command line flags if they were provided.
if (clOptions.isConstructed()) { if (clOptions.isConstructed()) {
disableMultithreading(clOptions->disableThreading);
printOpOnDiagnostic(clOptions->printOpOnDiagnostic); printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
} }
@ -582,6 +591,11 @@ bool MLIRContext::isMultithreadingEnabled() {
/// Set the flag specifying if multi-threading is disabled by the context. /// Set the flag specifying if multi-threading is disabled by the context.
void MLIRContext::disableMultithreading(bool disable) { void MLIRContext::disableMultithreading(bool disable) {
// This API can be overridden by the global debugging flag
// --mlir-disable-threading
if (isThreadingGloballyDisabled())
return;
impl->threadingIsEnabled = !disable; impl->threadingIsEnabled = !disable;
// Update the threading mode for each of the uniquers. // Update the threading mode for each of the uniquers.

View File

@ -32,6 +32,7 @@
#include "llvm/Support/Regex.h" #include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h" #include "llvm/Support/StringSaver.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
using namespace mlir; using namespace mlir;
@ -93,19 +94,22 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
/// Parses the memory buffer. If successfully, run a series of passes against /// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result. /// it and print the result.
static LogicalResult processBuffer(raw_ostream &os, static LogicalResult
std::unique_ptr<MemoryBuffer> ownedBuffer, processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool allowUnregisteredDialects, bool preloadDialectsInContext,
bool preloadDialectsInContext,
const PassPipelineCLParser &passPipeline, const PassPipelineCLParser &passPipeline,
DialectRegistry &registry) { DialectRegistry &registry, llvm::ThreadPool &threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up. // Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr; SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
context.setThreadPool(threadPool);
// Parse the input file. // Parse the input file.
MLIRContext context(registry);
if (preloadDialectsInContext) if (preloadDialectsInContext)
context.loadAllAvailableDialects(); context.loadAllAvailableDialects();
context.allowUnregisteredDialects(allowUnregisteredDialects); context.allowUnregisteredDialects(allowUnregisteredDialects);
@ -143,20 +147,24 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
bool preloadDialectsInContext) { bool preloadDialectsInContext) {
// The split-input-file mode is a very specific mode that slices the file // The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently. // up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
// threads for each of the split.
llvm::ThreadPool threadPool;
if (splitInputFile) if (splitInputFile)
return splitAndProcessBuffer( return splitAndProcessBuffer(
std::move(buffer), std::move(buffer),
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) { [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects, verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline, preloadDialectsInContext, passPipeline, registry,
registry); threadPool);
}, },
outputStream); outputStream);
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics, return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects, verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline, registry); preloadDialectsInContext, passPipeline, registry,
threadPool);
} }
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,