mlir-translate: support -verify-diagnostics

MLIR translation tools can emit diagnostics and we want to be able to check if
it is indeed the case in tests. Reuse the source manager error handlers
provided for mlir-opt to support the verification in mlir-translate. This
requires us to change the signature of the functions that are registered to
translate sources to MLIR: it now takes a source manager instead of a memory
buffer.

PiperOrigin-RevId: 279132972
This commit is contained in:
Alex Zinenko 2019-11-07 11:42:11 -08:00 committed by A. Unique TensorFlower
parent eb47d5ee66
commit 09e8e7107a
7 changed files with 112 additions and 62 deletions

View File

@ -27,6 +27,7 @@
namespace llvm {
class MemoryBuffer;
class SourceMgr;
class StringRef;
} // namespace llvm
@ -36,12 +37,19 @@ class MLIRContext;
class ModuleOp;
class OwningModuleRef;
/// Interface of the function that translates a source file held by the given
/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp
/// in the given context and return a pointer to it, or a nullptr in case of any
/// error.
using TranslateToMLIRFunction = std::function<OwningModuleRef(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>;
/// Interface of the function that translates the sources managed by `sourceMgr`
/// to MLIR. The source manager has at least one buffer. The implementation
/// should create a new MLIR ModuleOp in the given context and return a pointer
/// to it, or a nullptr in case of any error.
using TranslateSourceMgrToMLIRFunction =
std::function<OwningModuleRef(llvm::SourceMgr &sourceMgr, MLIRContext *)>;
/// Interface of the function that translates the given string to MLIR. The
/// implementation should create a new MLIR ModuleOp in the given context. If
/// source-related error reporting is required from within the function, use
/// TranslateSourceMgrToMLIRFunction instead.
using TranslateStringRefToMLIRFunction =
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
/// Interface of the function that translates MLIR to a different format and
/// outputs the result to a stream. It is allowed to modify the module.
@ -53,11 +61,10 @@ using TranslateFromMLIRFunction =
/// should be written to the given raw_ostream. The implementation should create
/// all MLIR constructs needed during the process inside the given context. This
/// can be used for round-tripping external formats through the MLIR system.
using TranslateFunction =
std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *)>;
using TranslateFunction = std::function<LogicalResult(
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
/// Use Translate[ToMLIR|FromMLIR|]Registration as a global initialiser that
/// Use Translate[ToMLIR|FromMLIR]Registration as a global initialiser that
/// registers a function and associates it with name. This requires that a
/// translation has not been registered to a given name.
///
@ -69,7 +76,9 @@ using TranslateFunction =
/// \{
struct TranslateToMLIRRegistration {
TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateToMLIRFunction &function);
const TranslateSourceMgrToMLIRFunction &function);
TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateStringRefToMLIRFunction &function);
};
struct TranslateFromMLIRRegistration {
@ -83,7 +92,8 @@ struct TranslateRegistration {
/// \}
/// Get a read-only reference to the translator registry.
const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getTranslationToMLIRRegistry();
const llvm::StringMap<TranslateFromMLIRFunction> &
getTranslationFromMLIRRegistry();
const llvm::StringMap<TranslateFunction> &getTranslationRegistry();

View File

@ -42,7 +42,7 @@ using namespace mlir;
// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
MLIRContext *context) {
Builder builder(context);
@ -70,9 +70,10 @@ OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
}
static TranslateToMLIRRegistration fromBinary(
"deserialize-spirv",
[](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) {
return deserializeModule(std::move(input), context);
"deserialize-spirv", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
return deserializeModule(
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
});
//===----------------------------------------------------------------------===//
@ -111,13 +112,9 @@ static TranslateFromMLIRRegistration
// Round-trip registration
//===----------------------------------------------------------------------===//
LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output, MLIRContext *context) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
// Parse the memory buffer as a MLIR module.
// Parse an MLIR module from the source manager.
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!srcModule)
return failure();
@ -151,7 +148,7 @@ LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
static TranslateRegistration
roundtrip("test-spirv-roundtrip",
[](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
return roundTripModule(std::move(input), output, context);
[](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
MLIRContext *context) {
return roundTripModule(sourceMgr, output, context);
});

View File

@ -55,15 +55,15 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
fileToFileRegistry.size());
for (const auto &kv : toMLIRRegistry) {
TranslateToMLIRFunction function = kv.second;
TranslateFunction wrapper =
[function](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
OwningModuleRef module = function(std::move(input), context);
if (!module)
return failure();
return printMLIROutput(*module, output);
};
TranslateSourceMgrToMLIRFunction function = kv.second;
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) {
OwningModuleRef module = function(sourceMgr, context);
if (!module)
return failure();
return printMLIROutput(*module, output);
};
wrapperStorage.emplace_back(std::move(wrapper));
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
@ -71,18 +71,14 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
for (const auto &kv : fromMLIRRegistry) {
TranslateFromMLIRFunction function = kv.second;
TranslateFunction wrapper =
[function](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
return function(module.get(), output);
};
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) {
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
return function(module.get(), output);
};
wrapperStorage.emplace_back(std::move(wrapper));
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());

View File

@ -571,15 +571,15 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
// LLVM dialect.
OwningModuleRef
translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
MLIRContext *context) {
OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
MLIRContext *context) {
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
assert(dialect && "Could not find LLVMDialect?");
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> llvmModule =
llvm::parseIR(*input, err, dialect->getLLVMContext(),
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
dialect->getLLVMContext(),
/*UpgradeDebugInfo=*/true,
/*DataLayoutString=*/"");
if (!llvmModule) {
@ -593,7 +593,7 @@ translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
}
static TranslateToMLIRRegistration
fromLLVM("import-llvm", [](std::unique_ptr<llvm::MemoryBuffer> input,
MLIRContext *context) {
return translateLLVMIRToModule(std::move(input), context);
});
fromLLVM("import-llvm",
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
return translateLLVMIRToModule(sourceMgr, context);
});

View File

@ -23,15 +23,16 @@
#include "mlir/IR/Module.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
// Get the mutable static map between registered "to MLIR" translations and the
// TranslateToMLIRFunctions that perform those translations.
static llvm::StringMap<TranslateToMLIRFunction> &
static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getMutableTranslationToMLIRRegistry() {
static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
translationToMLIRRegistry;
return translationToMLIRRegistry;
}
// Get the mutable static map between registered "from MLIR" translations and
@ -49,8 +50,10 @@ static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
return translationRegistry;
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateToMLIRFunction &function) {
// Puts `function` into the to-MLIR translation registry unless there is already
// a function registered for the same name.
static void registerTranslateToMLIRFunction(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
llvm::report_fatal_error(
@ -59,6 +62,24 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
translationToMLIRRegistry[name] = function;
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
registerTranslateToMLIRFunction(name, function);
}
// Wraps `function` with a lambda that extracts a StringRef from a source
// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateStringRefToMLIRFunction &function) {
auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
return function(buffer->getBuffer(), ctx);
};
registerTranslateToMLIRFunction(name, translationFunction);
}
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, const TranslateFromMLIRFunction &function) {
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
@ -84,7 +105,7 @@ TranslateRegistration::TranslateRegistration(
// Merely add the const qualifier to the mutable registry so that external users
// cannot modify it.
const llvm::StringMap<TranslateToMLIRFunction> &
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
mlir::getTranslationToMLIRRegistry() {
return getMutableTranslationToMLIRRegistry();
}

View File

@ -0,0 +1,6 @@
// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir %s
// expected-error @+1 {{unsupported module-level operation}}
func @foo() {
llvm.return
}

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
@ -46,6 +47,12 @@ static llvm::cl::opt<bool>
"process each chunk independently"),
llvm::cl::init(false));
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
@ -69,11 +76,24 @@ int main(int argc, char **argv) {
return 1;
}
/// Processes the memory buffer with a new MLIRContext.
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
MLIRContext context;
return (*translationRequested)(std::move(ownedBuffer), os, &context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}
// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
};
if (splitInputFile) {