diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 1e848c2d1531..e145a58d0d27 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -145,6 +145,11 @@ createCustomDialectWrapper(const std::string &dialectNamespace, // Create the custom implementation. return (*dialectClass)(std::move(dialectDescriptor)); } + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -902,7 +907,8 @@ py::object PyOperation::create( // Apply unpacked/validated to the operation state. Beyond this // point, exceptions cannot be thrown or else the state will leak. - MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc); + MlirOperationState state = + mlirOperationStateGet(toMlirStringRef(name), location->loc); if (!mlirOperands.empty()) mlirOperationStateAddOperands(&state, mlirOperands.size(), mlirOperands.data()); @@ -917,7 +923,7 @@ py::object PyOperation::create( mlirNamedAttributes.reserve(mlirAttributes.size()); for (auto &it : mlirAttributes) mlirNamedAttributes.push_back( - mlirNamedAttributeGet(it.first.c_str(), it.second)); + mlirNamedAttributeGet(toMlirStringRef(it.first), it.second)); mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), mlirNamedAttributes.data()); } @@ -1076,7 +1082,7 @@ bool PyAttribute::operator==(const PyAttribute &other) { PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) : ownedName(new std::string(std::move(ownedName))) { - namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr); + namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr); } //------------------------------------------------------------------------------ @@ -1287,8 +1293,8 @@ public: PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} PyAttribute dunderGetItemNamed(const std::string &name) { - MlirAttribute attr = - mlirOperationGetAttributeByName(operation->get(), name.c_str()); + MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), + toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw SetPyError(PyExc_KeyError, "attempt to access a non-existent attribute"); @@ -1303,16 +1309,18 @@ public: } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); - return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name)); + return PyNamedAttribute(namedAttr.attribute, + std::string(namedAttr.name.data)); } void dunderSetItem(const std::string &name, PyAttribute attr) { - mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr); + mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), + attr.attr); } void dunderDelItem(const std::string &name) { - int removed = - mlirOperationRemoveAttributeByName(operation->get(), name.c_str()); + int removed = mlirOperationRemoveAttributeByName(operation->get(), + toMlirStringRef(name)); if (!removed) throw SetPyError(PyExc_KeyError, "attempt to delete a non-existent attribute"); @@ -1323,8 +1331,8 @@ public: } bool dunderContains(const std::string &name) { - return !mlirAttributeIsNull( - mlirOperationGetAttributeByName(operation->get(), name.c_str())); + return !mlirAttributeIsNull(mlirOperationGetAttributeByName( + operation->get(), toMlirStringRef(name))); } static void bind(py::module &m) { @@ -2599,9 +2607,10 @@ void mlir::python::populateIRSubmodule(py::module &m) { "file", [](std::string filename, int line, int col, DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationFileLineColGet( - context->get(), filename.c_str(), line, col)); + return PyLocation( + context->getRef(), + mlirLocationFileLineColGet( + context->get(), toMlirStringRef(filename), line, col)); }, py::arg("filename"), py::arg("line"), py::arg("col"), py::arg("context") = py::none(), kContextGetFileLocationDocstring) @@ -2625,8 +2634,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { .def_static( "parse", [](const std::string moduleAsm, DefaultingPyMlirContext context) { - MlirModule module = - mlirModuleCreateParse(context->get(), moduleAsm.c_str()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirModuleIsNull(module)) { @@ -2875,8 +2884,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { .def_static( "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { - MlirAttribute type = - mlirAttributeParseGet(context->get(), attrSpec.c_str()); + MlirAttribute type = mlirAttributeParseGet( + context->get(), toMlirStringRef(attrSpec)); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirAttributeIsNull(type)) { @@ -2940,7 +2949,7 @@ void mlir::python::populateIRSubmodule(py::module &m) { [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); - printAccum.parts.append(self.namedAttr.name); + printAccum.parts.append(self.namedAttr.name.data); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, printAccum.getCallback(), @@ -2951,7 +2960,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { .def_property_readonly( "name", [](PyNamedAttribute &self) { - return py::str(self.namedAttr.name, strlen(self.namedAttr.name)); + return py::str(self.namedAttr.name.data, + self.namedAttr.name.length); }, "The name of the NamedAttribute binding") .def_property_readonly( @@ -2983,7 +2993,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { .def_static( "parse", [](std::string typeSpec, DefaultingPyMlirContext context) { - MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str()); + MlirType type = + mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(type)) { diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 25cbba282129..4116e9f30b6b 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -16,7 +16,6 @@ #include #include - namespace mlir { namespace python { @@ -115,10 +114,11 @@ struct PyPrintAccumulator { void *getUserData() { return this; } MlirStringCallback getCallback() { - return [](const char *part, intptr_t size, void *userData) { + return [](MlirStringRef part, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); - pybind11::str pyPart(part, size); // Decodes as UTF-8 by default. + pybind11::str pyPart(part.data, + part.length); // Decodes as UTF-8 by default. printAccum->parts.append(std::move(pyPart)); }; } @@ -139,15 +139,16 @@ public: void *getUserData() { return this; } MlirStringCallback getCallback() { - return [](const char *part, intptr_t size, void *userData) { + return [](MlirStringRef part, void *userData) { pybind11::gil_scoped_acquire(); PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. - pybind11::bytes pyBytes(part, size); + pybind11::bytes pyBytes(part.data, part.length); accum->pyWriteFunction(pyBytes); } else { - pybind11::str pyStr(part, size); // Decodes as UTF-8 by default. + pybind11::str pyStr(part.data, + part.length); // Decodes as UTF-8 by default. accum->pyWriteFunction(pyStr); } }; @@ -165,13 +166,13 @@ struct PySinglePartStringAccumulator { void *getUserData() { return this; } MlirStringCallback getCallback() { - return [](const char *part, intptr_t size, void *userData) { + return [](MlirStringRef part, void *userData) { PySinglePartStringAccumulator *accum = static_cast(userData); assert(!accum->invoked && "PySinglePartStringAccumulator called back multiple times"); accum->invoked = true; - accum->value = pybind11::str(part, size); + accum->value = pybind11::str(part.data, part.length); }; }