forked from OSchip/llvm-project
[mlir][Python] Sync Python bindings with C API MlirStringRef modification.
MLIR C API use the `MlirStringRef` instead of `const char *` for the string type now. This patch sync the Python bindings with the C API modification. Differential Revision: https://reviews.llvm.org/D92007
This commit is contained in:
parent
1e821217cb
commit
5f0c1e3806
|
@ -145,6 +145,11 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
|
|||
// Create the custom implementation.
|
||||
return (*dialectClass)(std::move(dialectDescriptor));
|
||||
}
|
||||
|
||||
static MlirStringRef toMlirStringRef(const std::string &s) {
|
||||
return mlirStringRefCreate(s.data(), s.size());
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Collections.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -902,7 +907,8 @@ py::object PyOperation::create(
|
|||
|
||||
// Apply unpacked/validated to the operation state. Beyond this
|
||||
// point, exceptions cannot be thrown or else the state will leak.
|
||||
MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc);
|
||||
MlirOperationState state =
|
||||
mlirOperationStateGet(toMlirStringRef(name), location->loc);
|
||||
if (!mlirOperands.empty())
|
||||
mlirOperationStateAddOperands(&state, mlirOperands.size(),
|
||||
mlirOperands.data());
|
||||
|
@ -917,7 +923,7 @@ py::object PyOperation::create(
|
|||
mlirNamedAttributes.reserve(mlirAttributes.size());
|
||||
for (auto &it : mlirAttributes)
|
||||
mlirNamedAttributes.push_back(
|
||||
mlirNamedAttributeGet(it.first.c_str(), it.second));
|
||||
mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
|
||||
mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
|
||||
mlirNamedAttributes.data());
|
||||
}
|
||||
|
@ -1076,7 +1082,7 @@ bool PyAttribute::operator==(const PyAttribute &other) {
|
|||
|
||||
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
|
||||
: ownedName(new std::string(std::move(ownedName))) {
|
||||
namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
|
||||
namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -1287,8 +1293,8 @@ public:
|
|||
PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
|
||||
|
||||
PyAttribute dunderGetItemNamed(const std::string &name) {
|
||||
MlirAttribute attr =
|
||||
mlirOperationGetAttributeByName(operation->get(), name.c_str());
|
||||
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
|
||||
toMlirStringRef(name));
|
||||
if (mlirAttributeIsNull(attr)) {
|
||||
throw SetPyError(PyExc_KeyError,
|
||||
"attempt to access a non-existent attribute");
|
||||
|
@ -1303,16 +1309,18 @@ public:
|
|||
}
|
||||
MlirNamedAttribute namedAttr =
|
||||
mlirOperationGetAttribute(operation->get(), index);
|
||||
return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
|
||||
return PyNamedAttribute(namedAttr.attribute,
|
||||
std::string(namedAttr.name.data));
|
||||
}
|
||||
|
||||
void dunderSetItem(const std::string &name, PyAttribute attr) {
|
||||
mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr);
|
||||
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
|
||||
attr.attr);
|
||||
}
|
||||
|
||||
void dunderDelItem(const std::string &name) {
|
||||
int removed =
|
||||
mlirOperationRemoveAttributeByName(operation->get(), name.c_str());
|
||||
int removed = mlirOperationRemoveAttributeByName(operation->get(),
|
||||
toMlirStringRef(name));
|
||||
if (!removed)
|
||||
throw SetPyError(PyExc_KeyError,
|
||||
"attempt to delete a non-existent attribute");
|
||||
|
@ -1323,8 +1331,8 @@ public:
|
|||
}
|
||||
|
||||
bool dunderContains(const std::string &name) {
|
||||
return !mlirAttributeIsNull(
|
||||
mlirOperationGetAttributeByName(operation->get(), name.c_str()));
|
||||
return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
|
||||
operation->get(), toMlirStringRef(name)));
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
|
@ -2599,9 +2607,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
"file",
|
||||
[](std::string filename, int line, int col,
|
||||
DefaultingPyMlirContext context) {
|
||||
return PyLocation(context->getRef(),
|
||||
mlirLocationFileLineColGet(
|
||||
context->get(), filename.c_str(), line, col));
|
||||
return PyLocation(
|
||||
context->getRef(),
|
||||
mlirLocationFileLineColGet(
|
||||
context->get(), toMlirStringRef(filename), line, col));
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("col"),
|
||||
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
|
||||
|
@ -2625,8 +2634,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_static(
|
||||
"parse",
|
||||
[](const std::string moduleAsm, DefaultingPyMlirContext context) {
|
||||
MlirModule module =
|
||||
mlirModuleCreateParse(context->get(), moduleAsm.c_str());
|
||||
MlirModule module = mlirModuleCreateParse(
|
||||
context->get(), toMlirStringRef(moduleAsm));
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirModuleIsNull(module)) {
|
||||
|
@ -2875,8 +2884,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_static(
|
||||
"parse",
|
||||
[](std::string attrSpec, DefaultingPyMlirContext context) {
|
||||
MlirAttribute type =
|
||||
mlirAttributeParseGet(context->get(), attrSpec.c_str());
|
||||
MlirAttribute type = mlirAttributeParseGet(
|
||||
context->get(), toMlirStringRef(attrSpec));
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirAttributeIsNull(type)) {
|
||||
|
@ -2940,7 +2949,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
[](PyNamedAttribute &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("NamedAttribute(");
|
||||
printAccum.parts.append(self.namedAttr.name);
|
||||
printAccum.parts.append(self.namedAttr.name.data);
|
||||
printAccum.parts.append("=");
|
||||
mlirAttributePrint(self.namedAttr.attribute,
|
||||
printAccum.getCallback(),
|
||||
|
@ -2951,7 +2960,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_property_readonly(
|
||||
"name",
|
||||
[](PyNamedAttribute &self) {
|
||||
return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
|
||||
return py::str(self.namedAttr.name.data,
|
||||
self.namedAttr.name.length);
|
||||
},
|
||||
"The name of the NamedAttribute binding")
|
||||
.def_property_readonly(
|
||||
|
@ -2983,7 +2993,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_static(
|
||||
"parse",
|
||||
[](std::string typeSpec, DefaultingPyMlirContext context) {
|
||||
MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str());
|
||||
MlirType type =
|
||||
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(type)) {
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
|
@ -115,10 +114,11 @@ struct PyPrintAccumulator {
|
|||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return [](const char *part, intptr_t size, void *userData) {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
PyPrintAccumulator *printAccum =
|
||||
static_cast<PyPrintAccumulator *>(userData);
|
||||
pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
|
||||
pybind11::str pyPart(part.data,
|
||||
part.length); // Decodes as UTF-8 by default.
|
||||
printAccum->parts.append(std::move(pyPart));
|
||||
};
|
||||
}
|
||||
|
@ -139,15 +139,16 @@ public:
|
|||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return [](const char *part, intptr_t size, void *userData) {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
pybind11::gil_scoped_acquire();
|
||||
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
|
||||
if (accum->binary) {
|
||||
// Note: Still has to copy and not avoidable with this API.
|
||||
pybind11::bytes pyBytes(part, size);
|
||||
pybind11::bytes pyBytes(part.data, part.length);
|
||||
accum->pyWriteFunction(pyBytes);
|
||||
} else {
|
||||
pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
|
||||
pybind11::str pyStr(part.data,
|
||||
part.length); // Decodes as UTF-8 by default.
|
||||
accum->pyWriteFunction(pyStr);
|
||||
}
|
||||
};
|
||||
|
@ -165,13 +166,13 @@ struct PySinglePartStringAccumulator {
|
|||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return [](const char *part, intptr_t size, void *userData) {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
PySinglePartStringAccumulator *accum =
|
||||
static_cast<PySinglePartStringAccumulator *>(userData);
|
||||
assert(!accum->invoked &&
|
||||
"PySinglePartStringAccumulator called back multiple times");
|
||||
accum->invoked = true;
|
||||
accum->value = pybind11::str(part, size);
|
||||
accum->value = pybind11::str(part.data, part.length);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue