[mlir][Python] Sync Python bindings with C API MlirStringRef modification.

MLIR C API use the `MlirStringRef` instead of `const char *` for the string type now. This patch sync the Python bindings with the C API modification.

Differential Revision: https://reviews.llvm.org/D92007
This commit is contained in:
zhanghb97 2020-11-24 18:35:22 +00:00 committed by Stella Laurenzo
parent 1e821217cb
commit 5f0c1e3806
2 changed files with 41 additions and 29 deletions

View File

@ -145,6 +145,11 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
// Create the custom implementation.
return (*dialectClass)(std::move(dialectDescriptor));
}
static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@ -902,7 +907,8 @@ py::object PyOperation::create(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc);
MlirOperationState state =
mlirOperationStateGet(toMlirStringRef(name), location->loc);
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
@ -917,7 +923,7 @@ py::object PyOperation::create(
mlirNamedAttributes.reserve(mlirAttributes.size());
for (auto &it : mlirAttributes)
mlirNamedAttributes.push_back(
mlirNamedAttributeGet(it.first.c_str(), it.second));
mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
mlirNamedAttributes.data());
}
@ -1076,7 +1082,7 @@ bool PyAttribute::operator==(const PyAttribute &other) {
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
: ownedName(new std::string(std::move(ownedName))) {
namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
}
//------------------------------------------------------------------------------
@ -1287,8 +1293,8 @@ public:
PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
PyAttribute dunderGetItemNamed(const std::string &name) {
MlirAttribute attr =
mlirOperationGetAttributeByName(operation->get(), name.c_str());
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_KeyError,
"attempt to access a non-existent attribute");
@ -1303,16 +1309,18 @@ public:
}
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
return PyNamedAttribute(namedAttr.attribute,
std::string(namedAttr.name.data));
}
void dunderSetItem(const std::string &name, PyAttribute attr) {
mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr);
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
attr.attr);
}
void dunderDelItem(const std::string &name) {
int removed =
mlirOperationRemoveAttributeByName(operation->get(), name.c_str());
int removed = mlirOperationRemoveAttributeByName(operation->get(),
toMlirStringRef(name));
if (!removed)
throw SetPyError(PyExc_KeyError,
"attempt to delete a non-existent attribute");
@ -1323,8 +1331,8 @@ public:
}
bool dunderContains(const std::string &name) {
return !mlirAttributeIsNull(
mlirOperationGetAttributeByName(operation->get(), name.c_str()));
return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
operation->get(), toMlirStringRef(name)));
}
static void bind(py::module &m) {
@ -2599,9 +2607,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"file",
[](std::string filename, int line, int col,
DefaultingPyMlirContext context) {
return PyLocation(context->getRef(),
mlirLocationFileLineColGet(
context->get(), filename.c_str(), line, col));
return PyLocation(
context->getRef(),
mlirLocationFileLineColGet(
context->get(), toMlirStringRef(filename), line, col));
},
py::arg("filename"), py::arg("line"), py::arg("col"),
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
@ -2625,8 +2634,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](const std::string moduleAsm, DefaultingPyMlirContext context) {
MlirModule module =
mlirModuleCreateParse(context->get(), moduleAsm.c_str());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(module)) {
@ -2875,8 +2884,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
MlirAttribute type =
mlirAttributeParseGet(context->get(), attrSpec.c_str());
MlirAttribute type = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
@ -2940,7 +2949,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyNamedAttribute &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(self.namedAttr.name);
printAccum.parts.append(self.namedAttr.name.data);
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),
@ -2951,7 +2960,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(
"name",
[](PyNamedAttribute &self) {
return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
return py::str(self.namedAttr.name.data,
self.namedAttr.name.length);
},
"The name of the NamedAttribute binding")
.def_property_readonly(
@ -2983,7 +2993,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {

View File

@ -16,7 +16,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace mlir {
namespace python {
@ -115,10 +114,11 @@ struct PyPrintAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
return [](MlirStringRef part, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
pybind11::str pyPart(part.data,
part.length); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
}
@ -139,15 +139,16 @@ public:
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
return [](MlirStringRef part, void *userData) {
pybind11::gil_scoped_acquire();
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
if (accum->binary) {
// Note: Still has to copy and not avoidable with this API.
pybind11::bytes pyBytes(part, size);
pybind11::bytes pyBytes(part.data, part.length);
accum->pyWriteFunction(pyBytes);
} else {
pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
pybind11::str pyStr(part.data,
part.length); // Decodes as UTF-8 by default.
accum->pyWriteFunction(pyStr);
}
};
@ -165,13 +166,13 @@ struct PySinglePartStringAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
return [](MlirStringRef part, void *userData) {
PySinglePartStringAccumulator *accum =
static_cast<PySinglePartStringAccumulator *>(userData);
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
accum->value = pybind11::str(part, size);
accum->value = pybind11::str(part.data, part.length);
};
}