forked from OSchip/llvm-project
Plumb write_bytecode to the Python API
This adds a `write_bytecode` method to the Operation class. The method takes a file handle and writes the binary blob to it. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D133210
This commit is contained in:
parent
f33645301e
commit
89418ddcb5
|
@ -521,6 +521,11 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
|
|||
MlirStringCallback callback,
|
||||
void *userData);
|
||||
|
||||
/// Same as mlirOperationPrint but writing the bytecode format out.
|
||||
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
|
||||
MlirStringCallback callback,
|
||||
void *userData);
|
||||
|
||||
/// Prints an operation to stderr.
|
||||
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
|
||||
|
||||
|
|
|
@ -119,6 +119,13 @@ Returns:
|
|||
argument.
|
||||
)";
|
||||
|
||||
static const char kOperationPrintBytecodeDocstring[] =
|
||||
R"(Write the bytecode form of the operation to a file like object.
|
||||
|
||||
Args:
|
||||
file: The file like object to write to.
|
||||
)";
|
||||
|
||||
static const char kOperationStrDunderDocstring[] =
|
||||
R"(Gets the assembly form of the operation with default options.
|
||||
|
||||
|
@ -1022,6 +1029,14 @@ void PyOperationBase::print(py::object fileObject, bool binary,
|
|||
mlirOpPrintingFlagsDestroy(flags);
|
||||
}
|
||||
|
||||
void PyOperationBase::writeBytecode(py::object fileObject) {
|
||||
PyOperation &operation = getOperation();
|
||||
operation.checkValid();
|
||||
PyFileAccumulator accum(fileObject, /*binary=*/true);
|
||||
mlirOperationWriteBytecode(operation, accum.getCallback(),
|
||||
accum.getUserData());
|
||||
}
|
||||
|
||||
py::object PyOperationBase::getAsm(bool binary,
|
||||
llvm::Optional<int64_t> largeElementsLimit,
|
||||
bool enableDebugInfo, bool prettyDebugInfo,
|
||||
|
@ -2627,6 +2642,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
py::arg("print_generic_op_form") = false,
|
||||
py::arg("use_local_scope") = false,
|
||||
py::arg("assume_verified") = false, kOperationPrintDocstring)
|
||||
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
|
||||
kOperationPrintBytecodeDocstring)
|
||||
.def("get_asm", &PyOperationBase::getAsm,
|
||||
// Careful: Lots of arguments must match up with get_asm method.
|
||||
py::arg("binary") = false,
|
||||
|
|
|
@ -512,6 +512,9 @@ public:
|
|||
bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified);
|
||||
|
||||
// Implement the bound 'writeBytecode' method.
|
||||
void writeBytecode(pybind11::object fileObject);
|
||||
|
||||
/// Moves the operation before or after the other operation.
|
||||
void moveAfter(PyOperationBase &other);
|
||||
void moveBefore(PyOperationBase &other);
|
||||
|
|
|
@ -12,6 +12,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
|
|||
Support.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBytecodeWriter
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRSupport
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "mlir-c/Support.h"
|
||||
|
||||
#include "mlir/AsmParser/AsmParser.h"
|
||||
#include "mlir/Bytecode/BytecodeWriter.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/CAPI/Utils.h"
|
||||
|
@ -23,7 +24,6 @@
|
|||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Parser/Parser.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <cstddef>
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -485,6 +485,12 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
|
|||
unwrap(op)->print(stream, *unwrap(flags));
|
||||
}
|
||||
|
||||
void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
|
||||
void *userData) {
|
||||
detail::CallbackOstream stream(callback, userData);
|
||||
writeBytecodeToFile(unwrap(op), stream);
|
||||
}
|
||||
|
||||
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
|
||||
|
||||
bool mlirOperationVerify(MlirOperation op) {
|
||||
|
|
|
@ -566,6 +566,18 @@ def testOperationPrint():
|
|||
print(str_value.__class__)
|
||||
print(f.getvalue())
|
||||
|
||||
# Test roundtrip to bytecode.
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
|
||||
module_roundtrip = Module.parse(bytecode, ctx)
|
||||
f = io.StringIO()
|
||||
module_roundtrip.operation.print(file=f)
|
||||
roundtrip_value = f.getvalue()
|
||||
assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
|
||||
|
||||
|
||||
# Test print to binary file.
|
||||
f = io.BytesIO()
|
||||
# CHECK: <class 'bytes'>
|
||||
|
|
|
@ -398,6 +398,7 @@ mlir_c_api_cc_library(
|
|||
includes = ["include"],
|
||||
deps = [
|
||||
":AsmParser",
|
||||
":BytecodeWriter",
|
||||
":ConversionPassIncGen",
|
||||
":FuncDialect",
|
||||
":InferTypeOpInterface",
|
||||
|
|
Loading…
Reference in New Issue