Add initial python bindings for attributes.

* Generic mlir.ir.Attribute class.
* First standard attribute (mlir.ir.StringAttr), following the same pattern as generic vs standard types.
* NamedAttribute class.

Differential Revision: https://reviews.llvm.org/D86250
This commit is contained in:
Stella Laurenzo 2020-08-19 15:33:02 -07:00
parent 1bc45b2fd8
commit 3137c29926
7 changed files with 385 additions and 12 deletions

View File

@ -336,6 +336,9 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
/** Checks whether an attribute is null. */
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
/** Checks if two attributes are equal. */
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);

View File

@ -9,6 +9,7 @@
#include "IRModules.h"
#include "PybindUtils.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
namespace py = pybind11;
@ -76,8 +77,52 @@ struct PyPrintAccumulator {
}
};
/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
struct PySinglePartStringAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
PySinglePartStringAccumulator *accum =
static_cast<PySinglePartStringAccumulator *>(userData);
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
accum->value = py::str(part, size);
};
}
py::str takeValue() {
assert(invoked && "PySinglePartStringAccumulator not called back");
return std::move(value);
}
private:
py::str value;
bool invoked = false;
};
} // namespace
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
bool PyAttribute::operator==(const PyAttribute &other) {
return mlirAttributeEqual(attr, other.attr);
}
//------------------------------------------------------------------------------
// PyNamedAttribute.
//------------------------------------------------------------------------------
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
: ownedName(new std::string(std::move(ownedName))) {
namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
}
//------------------------------------------------------------------------------
// PyType.
//------------------------------------------------------------------------------
@ -86,6 +131,86 @@ bool PyType::operator==(const PyType &other) {
return mlirTypeEqual(type, other.type);
}
//------------------------------------------------------------------------------
// Standard attribute subclasses.
//------------------------------------------------------------------------------
namespace {
/// CRTP base classes for Python attributes that subclass Attribute and should
/// be castable from it (i.e. via something like StringAttr(attr)).
template <typename T>
class PyConcreteAttribute : public PyAttribute {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
using ClassTy = py::class_<T, PyAttribute>;
using IsAFunctionTy = int (*)(MlirAttribute);
PyConcreteAttribute() = default;
PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
PyConcreteAttribute(PyAttribute &orig)
: PyConcreteAttribute(castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
if (!T::isaFunction(orig.attr)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
T::pyClassName + " (from " + origRepr + ")");
}
return orig.attr;
}
static void bind(py::module &m) {
auto cls = ClassTy(m, T::pyClassName);
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
T::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
static void bindDerived(ClassTy &m) {}
};
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
static constexpr const char *pyClassName = "StringAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyMlirContext &context, std::string value) {
MlirAttribute attr =
mlirStringAttrGet(context.context, value.size(), &value[0]);
return PyStringAttribute(attr);
},
py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
c.def_static(
"get_typed",
[](PyType &type, std::string value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
return PyStringAttribute(attr);
},
py::keep_alive<0, 1>(),
"Gets a uniqued string attribute associated to a type");
c.def_property_readonly(
"value",
[](PyStringAttribute &self) {
PySinglePartStringAccumulator accum;
mlirStringAttrGetValue(self.attr, accum.getCallback(),
accum.getUserData());
return accum.takeValue();
},
"Returns the value of the string attribute");
}
};
} // namespace
//------------------------------------------------------------------------------
// Standard type subclasses.
//------------------------------------------------------------------------------
@ -118,9 +243,9 @@ public:
}
static void bind(py::module &m) {
auto class_ = ClassTy(m, T::pyClassName);
class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
T::bindDerived(class_);
auto cls = ClassTy(m, T::pyClassName);
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
T::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@ -135,21 +260,21 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"signless",
"get_signless",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeGet(context.context, width);
return PyIntegerType(t);
},
py::keep_alive<0, 1>(), "Create a signless integer type");
c.def_static(
"signed",
"get_signed",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
return PyIntegerType(t);
},
py::keep_alive<0, 1>(), "Create a signed integer type");
c.def_static(
"unsigned",
"get_unsigned",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
return PyIntegerType(t);
@ -195,6 +320,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyMlirContext &self, const std::string module) {
auto moduleRef =
mlirModuleCreateParse(self.context, module.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(moduleRef)) {
throw SetPyError(
PyExc_ValueError,
@ -203,10 +330,27 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyModule(moduleRef);
},
py::keep_alive<0, 1>(), kContextParseDocstring)
.def(
"parse_attr",
[](PyMlirContext &self, std::string attrSpec) {
MlirAttribute type =
mlirAttributeParseGet(self.context, attrSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
return PyAttribute(type);
},
py::keep_alive<0, 1>())
.def(
"parse_type",
[](PyMlirContext &self, std::string typeSpec) {
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Unable to parse type: '") +
@ -235,6 +379,79 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
kOperationStrDunderDocstring);
// Mapping of Type.
py::class_<PyAttribute>(m, "Attribute")
.def(
"get_named",
[](PyAttribute &self, std::string name) {
return PyNamedAttribute(self.attr, std::move(name));
},
py::keep_alive<0, 1>(), "Binds a name to the attribute")
.def("__eq__",
[](PyAttribute &self, py::object &other) {
try {
PyAttribute otherAttribute = other.cast<PyAttribute>();
return self == otherAttribute;
} catch (std::exception &e) {
return false;
}
})
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
kDumpDocstring)
.def(
"__str__",
[](PyAttribute &self) {
PyPrintAccumulator printAccum;
mlirAttributePrint(self.attr, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
kTypeStrDunderDocstring)
.def("__repr__", [](PyAttribute &self) {
// Generally, assembly formats are not printed for __repr__ because
// this can cause exceptionally long debug output and exceptions.
// However, attribute values are generally considered useful and are
// printed. This may need to be re-evaluated if debug dumps end up
// being excessive.
PyPrintAccumulator printAccum;
printAccum.parts.append("Attribute(");
mlirAttributePrint(self.attr, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
py::class_<PyNamedAttribute>(m, "NamedAttribute")
.def("__repr__",
[](PyNamedAttribute &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(self.namedAttr.name);
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
})
.def_property_readonly(
"name",
[](PyNamedAttribute &self) {
return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
},
"The name of the NamedAttribute binding")
.def_property_readonly(
"attr",
[](PyNamedAttribute &self) {
return PyAttribute(self.namedAttr.attribute);
},
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");
// Standard attribute bindings.
PyStringAttribute::bind(m);
// Mapping of Type.
py::class_<PyType>(m, "Type")
.def("__eq__",

View File

@ -45,6 +45,39 @@ public:
MlirModule module;
};
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
class PyAttribute {
public:
PyAttribute(MlirAttribute attr) : attr(attr) {}
bool operator==(const PyAttribute &other);
MlirAttribute attr;
};
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
class PyNamedAttribute {
public:
/// Constructs a PyNamedAttr that retains an owned name. This should be
/// used in any code that originates an MlirNamedAttribute from a python
/// string.
/// The lifetime of the PyNamedAttr must extend to the lifetime of the
/// passed attribute.
PyNamedAttribute(MlirAttribute attr, std::string ownedName);
MlirNamedAttribute namedAttr;
private:
// Since the MlirNamedAttr contains an internal pointer to the actual
// memory of the owned string, it must be heap allocated to remain valid.
// Otherwise, strings that fit within the small object optimization threshold
// will have their memory address change as the containing object is moved,
// resulting in an invalid aliased pointer.
std::unique_ptr<std::string> ownedName;
};
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class PyType {

View File

@ -10,8 +10,8 @@
namespace py = pybind11;
pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
llvm::Twine message) {
pybind11::error_already_set
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
auto messageStr = message.str();
PyErr_SetString(excClass, messageStr.c_str());
return pybind11::error_already_set();

View File

@ -20,7 +20,8 @@ namespace python {
// python runtime.
// Correct usage:
// throw SetPyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
pybind11::error_already_set SetPyError(PyObject *excClass,
const llvm::Twine &message);
} // namespace python
} // namespace mlir

View File

@ -0,0 +1,119 @@
# RUN: %PYTHON %s | FileCheck %s
import mlir
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_attr('"hello"')
# CHECK: "hello"
print(str(t))
# CHECK: Attribute("hello")
print(repr(t))
run(testParsePrint)
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
def testParseError():
ctx = mlir.ir.Context()
try:
t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
except ValueError as e:
# CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
print("testParseError:", e)
else:
print("Exception not produced")
run(testParseError)
# CHECK-LABEL: TEST: testAttrEq
def testAttrEq():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
a2 = ctx.parse_attr('"attr2"')
a3 = ctx.parse_attr('"attr1"')
# CHECK: a1 == a1: True
print("a1 == a1:", a1 == a1)
# CHECK: a1 == a2: False
print("a1 == a2:", a1 == a2)
# CHECK: a1 == a3: True
print("a1 == a3:", a1 == a3)
# CHECK: a1 == None: False
print("a1 == None:", a1 == None)
run(testAttrEq)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
def testAttrEqDoesNotRaise():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
not_an_attr = "foo"
# CHECK: False
print(a1 == not_an_attr)
# CHECK: False
print(a1 == None)
# CHECK: True
print(a1 != None)
run(testAttrEqDoesNotRaise)
# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
astr = mlir.ir.StringAttr(a1)
aself = mlir.ir.StringAttr(astr)
# CHECK: Attribute("attr1")
print(repr(astr))
try:
tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
except ValueError as e:
# CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
print("ValueError:", e)
else:
print("Exception not produced")
run(testStandardAttrCasts)
# CHECK-LABEL: TEST: testStringAttr
def testStringAttr():
ctx = mlir.ir.Context()
sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
# CHECK: sattr value: stringattr
print("sattr value:", sattr.value)
# Test factory methods.
# CHECK: default_get: "foobar"
print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
# CHECK: typed_get: "12345" : i32
print("typed_get:", mlir.ir.StringAttr.get_typed(
mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
run(testStringAttr)
# CHECK-LABEL: TEST: testNamedAttr
def testNamedAttr():
ctx = mlir.ir.Context()
a = ctx.parse_attr('"stringattr"')
named = a.get_named("foobar") # Note: under the small object threshold
# CHECK: attr: "stringattr"
print("attr:", named.attr)
# CHECK: name: foobar
print("name:", named.name)
# CHECK: named: NamedAttribute(foobar="stringattr")
print("named:", named)
run(testNamedAttr)

View File

@ -117,10 +117,10 @@ def testIntegerType():
print("u32 unsigned:", u32.is_unsigned)
# CHECK: signless: i16
print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16))
# CHECK: signed: si8
print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8))
# CHECK: unsigned: ui64
print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
run(testIntegerType)