Use the generic form when printing from the python bindings and the verifier fails

This reduces the chances of segfault. While it is a good practice to ensure
robust custom printers, it is unfortunately common to have them crash on
invalid input.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D92536
This commit is contained in:
Mehdi Amini 2020-12-03 17:45:28 +00:00
parent df8a79258f
commit 1c2159494d
4 changed files with 28 additions and 0 deletions

View File

@ -386,6 +386,9 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
/// Prints an operation to stderr.
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
/// Verify the operation and return true if it passes, false if it fails.
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//

View File

@ -809,6 +809,12 @@ void PyOperationBase::print(py::object fileObject, bool binary,
operation.checkValid();
if (fileObject.is_none())
fileObject = py::module::import("sys").attr("stdout");
if (!printGenericOpForm && !mlirOperationVerify(operation)) {
fileObject.attr("write")("// Verification failed, printing generic form\n");
printGenericOpForm = true;
}
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit)
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
using namespace mlir;
@ -339,6 +340,10 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
bool mlirOperationVerify(MlirOperation op) {
return succeeded(verify(unwrap(op)));
}
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//

View File

@ -537,3 +537,17 @@ def testSingleResultProperty():
print(module.body.operations[2])
run(testSingleResultProperty)
# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
ctx = Context()
with Location.unknown(ctx):
module = Operation.create("module", regions=1)
# This block does not have a terminator, it may crash the custom printer.
# Verify that we fallback to the generic printer for safety.
block = module.regions[0].blocks.append()
print(module)
# CHECK: // Verification failed, printing generic form
# CHECK: "module"() ( {
# CHECK: }) : () -> ()
run(testPrintInvalidOperation)