Plumb write_bytecode to the Python API

This adds a `write_bytecode` method to the Operation class.
The method takes a file handle and writes the binary blob to it.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D133210
This commit is contained in:
Mehdi Amini 2022-09-05 11:54:19 +00:00
parent f33645301e
commit 89418ddcb5
7 changed files with 46 additions and 1 deletions

View File

@ -521,6 +521,11 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
MlirStringCallback callback,
void *userData);
/// Same as mlirOperationPrint but writing the bytecode format out.
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
MlirStringCallback callback,
void *userData);
/// Prints an operation to stderr.
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);

View File

@ -119,6 +119,13 @@ Returns:
argument.
)";
static const char kOperationPrintBytecodeDocstring[] =
R"(Write the bytecode form of the operation to a file like object.
Args:
file: The file like object to write to.
)";
static const char kOperationStrDunderDocstring[] =
R"(Gets the assembly form of the operation with default options.
@ -1022,6 +1029,14 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsDestroy(flags);
}
void PyOperationBase::writeBytecode(py::object fileObject) {
PyOperation &operation = getOperation();
operation.checkValid();
PyFileAccumulator accum(fileObject, /*binary=*/true);
mlirOperationWriteBytecode(operation, accum.getCallback(),
accum.getUserData());
}
py::object PyOperationBase::getAsm(bool binary,
llvm::Optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@ -2627,6 +2642,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
kOperationPrintBytecodeDocstring)
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
py::arg("binary") = false,

View File

@ -512,6 +512,9 @@ public:
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified);
// Implement the bound 'writeBytecode' method.
void writeBytecode(pybind11::object fileObject);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);

View File

@ -12,6 +12,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
Support.cpp
LINK_LIBS PUBLIC
MLIRBytecodeWriter
MLIRIR
MLIRParser
MLIRSupport

View File

@ -10,6 +10,7 @@
#include "mlir-c/Support.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
@ -23,7 +24,6 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/Debug.h"
#include <cstddef>
using namespace mlir;
@ -485,6 +485,12 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
unwrap(op)->print(stream, *unwrap(flags));
}
void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
writeBytecodeToFile(unwrap(op), stream);
}
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
bool mlirOperationVerify(MlirOperation op) {

View File

@ -566,6 +566,18 @@ def testOperationPrint():
print(str_value.__class__)
print(f.getvalue())
# Test roundtrip to bytecode.
bytecode_stream = io.BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
module_roundtrip = Module.parse(bytecode, ctx)
f = io.StringIO()
module_roundtrip.operation.print(file=f)
roundtrip_value = f.getvalue()
assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
# Test print to binary file.
f = io.BytesIO()
# CHECK: <class 'bytes'>

View File

@ -398,6 +398,7 @@ mlir_c_api_cc_library(
includes = ["include"],
deps = [
":AsmParser",
":BytecodeWriter",
":ConversionPassIncGen",
":FuncDialect",
":InferTypeOpInterface",