Refactor DiagnosticEngine to support multiple registered diagnostic handlers.

This fixes a problem with current save-restore pattern of diagnostics handlers, as there may be a thread race between when the previous handler is destroyed. For example, this occurs when using multiple ParallelDiagnosticHandlers asynchronously:

Handler A
Handler B | - LifeTime - |    Restore A here.
Handler C | --- LifeTime ---| Restore B after it has been destroyed.

The new design allows for multiple handlers to be registered in a stack like fashion. Handlers can return success() to signal that they have fully processed a diagnostic, or failure to propagate otherwise.

PiperOrigin-RevId: 270720625
This commit is contained in:
River Riddle 2019-09-23 11:24:28 -07:00 committed by A. Unique TensorFlower
parent 3eade43046
commit c61991ef01
4 changed files with 162 additions and 113 deletions

View File

@ -72,13 +72,33 @@ represents an unspecified source location.
The `DiagnosticEngine` acts as the main interface for diagnostics in MLIR. It
manages the registration of diagnostic handlers, as well as the core API for
diagnostic emission. It can be interfaced with via an `MLIRContext` instance.
diagnostic emission. Handlers generally take the form of
`LogicalResult(Diagnostic &)`. If the result is `success`, it signals that the
diagnostic has been fully processed and consumed. If `failure`, it signals that
the diagnostic should be propagated to any previously registered handlers. It
can be interfaced with via an `MLIRContext` instance.
```c++
DiagnosticEngine engine = ctx->getDiagEngine();
engine.setHandler([](Diagnostic diag) {
// Handle the reported diagnostic.
/// Handle the reported diagnostic.
// Return success to signal that the diagnostic has either been fully processed,
// or failure if the diagnostic should be propagated to the previous handlers.
DiagnosticEngine::HandlerID id = engine.registerHandler(
[](Diagnostic &diag) -> LogicalResult {
bool should_propage_diagnostic = ...;
return failure(should_propage_diagnostic);
});
// We can also elide the return value completely, in which the engine assumes
// that all diagnostics are consumed(i.e. a success() result).
DiagnosticEngine::HandlerID id = engine.registerHandler([](Diagnostic &diag) {
return;
});
// Unregister this handler when we are done.
engine.eraseHandler(id);
```
### Constructing a Diagnostic
@ -179,21 +199,22 @@ provides several common diagnostic handlers for immediate use.
### Scoped Diagnostic Handler
This diagnostic handler is a simple RAII class that saves and restores the
current diagnostic handler registered to a given context. This class can be
either be used directly, or in conjunction with a derived diagnostic handler.
This diagnostic handler is a simple RAII class that registers and unregisters a
given diagnostic handler. This class can be either be used directly, or in
conjunction with a derived diagnostic handler.
```c++
// Construct the handler directly.
MLIRContext context;
ScopedDiagnosticHandler scopedHandler(&context, [](Diagnostic diag) {
ScopedDiagnosticHandler scopedHandler(&context, [](Diagnostic &diag) {
...
});
// Use this handler in conjunction with another.
class MyDerivedHandler : public ScopedDiagnosticHandler {
MyDerivedHandler(MLIRContext *ctx) : ScopedDiagnosticHandler(ctx) {
ctx->getDiagEngine().setHandler([&](Diagnostic diag) {
// Set the handler that should be RAII managed.
setHandler([&](Diagnostic diag) {
...
});
}

View File

@ -409,8 +409,8 @@ class DiagnosticEngine {
public:
~DiagnosticEngine();
// Diagnostic handler registration and use. MLIR supports the ability for the
// IR to carry arbitrary metadata about operation location information. If a
// Diagnostic handler registration and use. MLIR supports the ability for the
// IR to carry arbitrary metadata about operation location information. If a
// problem is detected by the compiler, it can invoke the emitError /
// emitWarning / emitRemark method on an Operation and have it get reported
// through this interface.
@ -419,14 +419,36 @@ public:
// schema for their location information. If they don't, then warnings and
// notes will be dropped and errors will be emitted to errs.
using HandlerTy = std::function<void(Diagnostic)>;
/// The handler type for MLIR diagnostics. This function takes a diagnostic as
/// input, and returns success if the handler has fully processed this
/// diagnostic. Returns failure otherwise.
using HandlerTy = std::function<LogicalResult(Diagnostic &)>;
/// Set the diagnostic handler for this engine. Note that this replaces any
/// existing handler.
void setHandler(const HandlerTy &handler);
/// A handle to a specific registered handler object.
using HandlerID = uint64_t;
/// Return the current diagnostic handler, or null if none is present.
HandlerTy getHandler();
/// Register a new handler for diagnostics to the engine. Diagnostics are
/// process by handlers in stack-like order, meaning that the last added
/// handlers will process diagnostics first. This function returns a unique
/// identifier for the registered handler, which can be used to unregister
/// this handler at a later time.
HandlerID registerHandler(const HandlerTy &handler);
/// Set the diagnostic handler with a function that returns void. This is a
/// convient wrapper for handlers that always completely process the given
/// diagnostic.
template <typename FuncTy, typename RetT = decltype(std::declval<FuncTy>()(
std::declval<Diagnostic &>()))>
std::enable_if_t<std::is_same<RetT, void>::value, HandlerID>
registerHandler(FuncTy &&handler) {
return registerHandler([=](Diagnostic &diag) {
handler(diag);
return success();
});
}
/// Erase the registered diagnostic handler with the given identifier.
void eraseHandler(HandlerID id);
/// Create a new inflight diagnostic with the given location and severity.
InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity) {
@ -447,36 +469,6 @@ private:
std::unique_ptr<detail::DiagnosticEngineImpl> impl;
};
//===----------------------------------------------------------------------===//
// ScopedDiagnosticHandler
//===----------------------------------------------------------------------===//
/// This diagnostic handler is a simple RAII class that saves and restores the
/// current diagnostic handler registered to a given context. This class can
/// be either be used directly, or in conjunction with a derived diagnostic
/// handler.
class ScopedDiagnosticHandler {
public:
ScopedDiagnosticHandler(MLIRContext *ctx);
ScopedDiagnosticHandler(MLIRContext *ctx,
const DiagnosticEngine::HandlerTy &handler);
~ScopedDiagnosticHandler();
/// Propagate a diagnostic to the existing diagnostic handler.
void propagateDiagnostic(Diagnostic diag) {
if (existingHandler)
existingHandler(std::move(diag));
}
private:
/// The existing diagnostic handler registered with the context at the time of
/// construction.
DiagnosticEngine::HandlerTy existingHandler;
/// The context to register the handler back to.
MLIRContext *ctx;
};
/// Utility method to emit an error message using this location.
InFlightDiagnostic emitError(Location loc);
InFlightDiagnostic emitError(Location loc, const Twine &message);
@ -489,6 +481,40 @@ InFlightDiagnostic emitWarning(Location loc, const Twine &message);
InFlightDiagnostic emitRemark(Location loc);
InFlightDiagnostic emitRemark(Location loc, const Twine &message);
//===----------------------------------------------------------------------===//
// ScopedDiagnosticHandler
//===----------------------------------------------------------------------===//
/// This diagnostic handler is a simple RAII class that registers and erases a
/// diagnostic handler on a given context. This class can be either be used
/// directly, or in conjunction with a derived diagnostic handler.
class ScopedDiagnosticHandler {
public:
explicit ScopedDiagnosticHandler(MLIRContext *ctx) : handlerID(0), ctx(ctx) {}
template <typename FuncTy>
ScopedDiagnosticHandler(MLIRContext *ctx, FuncTy &&handler)
: handlerID(0), ctx(ctx) {
setHandler(std::forward<FuncTy>(handler));
}
~ScopedDiagnosticHandler();
protected:
/// Set the handler to manage via RAII.
template <typename FuncTy> void setHandler(FuncTy &&handler) {
auto &diagEngine = ctx->getDiagEngine();
if (handlerID)
diagEngine.eraseHandler(handlerID);
handlerID = diagEngine.registerHandler(std::forward<FuncTy>(handler));
}
private:
/// The unique id for the scoped handler.
DiagnosticEngine::HandlerID handlerID;
/// The context to erase the handler from.
MLIRContext *ctx;
};
//===----------------------------------------------------------------------===//
// SourceMgrDiagnosticHandler
//===----------------------------------------------------------------------===//

View File

@ -22,6 +22,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Mutex.h"
@ -205,9 +206,14 @@ struct DiagnosticEngineImpl {
/// A mutex to ensure that diagnostics emission is thread-safe.
llvm::sys::SmartMutex<true> mutex;
/// This is the handler to use to report diagnostics, or null if not
/// registered.
DiagnosticEngine::HandlerTy handler;
/// These are the handlers used to report diagnostics.
llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
2>
handlers;
/// This is a unique identifier counter for diagnostic handlers in the
/// context. This id starts at 1 to allow for 0 to be used as a sentinel.
DiagnosticEngine::HandlerID uniqueHandlerId = 1;
};
} // namespace detail
} // namespace mlir
@ -217,9 +223,12 @@ struct DiagnosticEngineImpl {
void DiagnosticEngineImpl::emit(Diagnostic diag) {
llvm::sys::SmartScopedLock<true> lock(mutex);
// If we had a handler registered, emit the diagnostic using it.
if (handler)
return handler(std::move(diag));
// Try to process the given diagnostic on one of the registered handlers.
// Handlers are walked in reverse order, so that the most recent handler is
// processed first.
for (auto &handlerIt : llvm::reverse(handlers))
if (succeeded(handlerIt.second(diag)))
return;
// Otherwise, if this is an error we emit it to stderr.
if (diag.getSeverity() != DiagnosticSeverity::Error)
@ -242,18 +251,20 @@ void DiagnosticEngineImpl::emit(Diagnostic diag) {
DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
DiagnosticEngine::~DiagnosticEngine() {}
/// Set the diagnostic handler for this engine. The handler is passed
/// location information if present (nullptr if not) along with a message and
/// a severity that indicates whether this is an error, warning, etc. Note
/// that this replaces any existing handler.
void DiagnosticEngine::setHandler(const HandlerTy &handler) {
impl->handler = handler;
/// Register a new handler for diagnostics to the engine. This function returns
/// a unique identifier for the registered handler, which can be used to
/// unregister this handler at a later time.
auto DiagnosticEngine::registerHandler(const HandlerTy &handler) -> HandlerID {
llvm::sys::SmartScopedLock<true> lock(impl->mutex);
auto uniqueID = impl->uniqueHandlerId++;
impl->handlers.insert({uniqueID, handler});
return uniqueID;
}
/// Return the current diagnostic handler, or null if none is present.
auto DiagnosticEngine::getHandler() -> HandlerTy {
/// Erase the registered diagnostic handler with the given identifier.
void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
llvm::sys::SmartScopedLock<true> lock(impl->mutex);
return impl->handler;
impl->handlers.erase(handlerID);
}
/// Emit a diagnostic using the registered issue handler if present, or with
@ -303,15 +314,9 @@ InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
// ScopedDiagnosticHandler
//===----------------------------------------------------------------------===//
ScopedDiagnosticHandler::ScopedDiagnosticHandler(MLIRContext *ctx)
: existingHandler(ctx->getDiagEngine().getHandler()), ctx(ctx) {}
ScopedDiagnosticHandler::ScopedDiagnosticHandler(
MLIRContext *ctx, const DiagnosticEngine::HandlerTy &handler)
: ScopedDiagnosticHandler(ctx) {
ctx->getDiagEngine().setHandler(handler);
}
ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
ctx->getDiagEngine().setHandler(existingHandler);
if (handlerID)
ctx->getDiagEngine().eraseHandler(handlerID);
}
//===----------------------------------------------------------------------===//
@ -384,9 +389,7 @@ SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
llvm::raw_ostream &os)
: ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
impl(new SourceMgrDiagnosticHandlerImpl()) {
// Register a simple diagnostic handler.
ctx->getDiagEngine().setHandler(
[this](Diagnostic diag) { emitDiagnostic(diag); });
setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
}
SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
@ -636,7 +639,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
(void)impl->computeExpectedDiags(mgr.getMemoryBuffer(i + 1));
// Register a handler to verfy the diagnostics.
ctx->getDiagEngine().setHandler([&](Diagnostic diag) {
setHandler([&](Diagnostic &diag) {
// Process the main diagnostics.
process(diag);
@ -753,25 +756,25 @@ struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
Diagnostic diag;
};
ParallelDiagnosticHandlerImpl(MLIRContext *ctx)
: prevHandler(ctx->getDiagEngine().getHandler()), context(ctx) {
ctx->getDiagEngine().setHandler([this](Diagnostic diag) {
ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : handlerID(0), context(ctx) {
handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
uint64_t tid = llvm::get_threadid();
llvm::sys::SmartScopedLock<true> lock(mutex);
// If this diagnostic was not emitted on a thread we track, then forward
// it to the previous handler.
// If this thread is not tracked, then return failure to let another
// handler process this diagnostic.
if (!threadToOrderID.count(tid))
prevHandler(std::move(diag));
else
// Otherwise, append a new diagnostic.
diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
return failure();
// Append a new diagnostic.
diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
return success();
});
}
~ParallelDiagnosticHandlerImpl() override {
// Restore the previous diagnostic handler.
context->getDiagEngine().setHandler(prevHandler);
// Erase this handler from the context.
context->getDiagEngine().eraseHandler(handlerID);
// Early exit if there are no diagnostics, this is the common case.
if (diagnostics.empty())
@ -784,7 +787,7 @@ struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
}
/// Utility method to emit any held diagnostics.
void emitDiagnostics(std::function<void(Diagnostic)> emitFn) {
void emitDiagnostics(std::function<void(Diagnostic)> emitFn) const {
// Stable sort all of the diagnostics that were emitted. This creates a
// deterministic ordering for the diagnostics based upon which order id they
// were emitted for.
@ -816,35 +819,31 @@ struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
return;
os << "In-Flight Diagnostics:\n";
const_cast<ParallelDiagnosticHandlerImpl *>(this)->emitDiagnostics(
[&](Diagnostic diag) {
os.indent(4);
emitDiagnostics([&](Diagnostic diag) {
os.indent(4);
// Print each diagnostic with the format:
// "<location>: <kind>: <msg>"
if (!diag.getLocation().isa<UnknownLoc>())
os << diag.getLocation() << ": ";
switch (diag.getSeverity()) {
case DiagnosticSeverity::Error:
os << "error: ";
break;
case DiagnosticSeverity::Warning:
os << "warning: ";
break;
case DiagnosticSeverity::Note:
os << "note: ";
break;
case DiagnosticSeverity::Remark:
os << "remark: ";
break;
}
os << diag << '\n';
});
// Print each diagnostic with the format:
// "<location>: <kind>: <msg>"
if (!diag.getLocation().isa<UnknownLoc>())
os << diag.getLocation() << ": ";
switch (diag.getSeverity()) {
case DiagnosticSeverity::Error:
os << "error: ";
break;
case DiagnosticSeverity::Warning:
os << "warning: ";
break;
case DiagnosticSeverity::Note:
os << "note: ";
break;
case DiagnosticSeverity::Remark:
os << "remark: ";
break;
}
os << diag << '\n';
});
}
/// The previous context diagnostic handler.
DiagnosticEngine::HandlerTy prevHandler;
/// A smart mutex to lock access to the internal state.
llvm::sys::SmartMutex<true> mutex;
@ -852,7 +851,10 @@ struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
DenseMap<uint64_t, size_t> threadToOrderID;
/// An unordered list of diagnostics that were emitted.
std::vector<ThreadDiagnostic> diagnostics;
mutable std::vector<ThreadDiagnostic> diagnostics;
/// The unique id for the parallel handler.
DiagnosticEngine::HandlerID handlerID;
/// The context to emit the diagnostics to.
MLIRContext *context;

View File

@ -45,7 +45,7 @@ protected:
DeserializationTest() {
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
context.getDiagEngine().setHandler([&](Diagnostic diag) {
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
diagnostic.reset(new Diagnostic(std::move(diag)));
});
}