forked from OSchip/llvm-project
Add initial python bindings for attributes.
* Generic mlir.ir.Attribute class. * First standard attribute (mlir.ir.StringAttr), following the same pattern as generic vs standard types. * NamedAttribute class. Differential Revision: https://reviews.llvm.org/D86250
This commit is contained in:
parent
1bc45b2fd8
commit
3137c29926
|
@ -336,6 +336,9 @@ void mlirTypeDump(MlirType type);
|
|||
/** Parses an attribute. The attribute is owned by the context. */
|
||||
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
|
||||
|
||||
/** Checks whether an attribute is null. */
|
||||
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
|
||||
|
||||
/** Checks if two attributes are equal. */
|
||||
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "IRModules.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
@ -76,8 +77,52 @@ struct PyPrintAccumulator {
|
|||
}
|
||||
};
|
||||
|
||||
/// Accumulates into a python string from a method that is expected to make
|
||||
/// one (no more, no less) call to the callback (asserts internally on
|
||||
/// violation).
|
||||
struct PySinglePartStringAccumulator {
|
||||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return [](const char *part, intptr_t size, void *userData) {
|
||||
PySinglePartStringAccumulator *accum =
|
||||
static_cast<PySinglePartStringAccumulator *>(userData);
|
||||
assert(!accum->invoked &&
|
||||
"PySinglePartStringAccumulator called back multiple times");
|
||||
accum->invoked = true;
|
||||
accum->value = py::str(part, size);
|
||||
};
|
||||
}
|
||||
|
||||
py::str takeValue() {
|
||||
assert(invoked && "PySinglePartStringAccumulator not called back");
|
||||
return std::move(value);
|
||||
}
|
||||
|
||||
private:
|
||||
py::str value;
|
||||
bool invoked = false;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyAttribute.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
bool PyAttribute::operator==(const PyAttribute &other) {
|
||||
return mlirAttributeEqual(attr, other.attr);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyNamedAttribute.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
|
||||
: ownedName(new std::string(std::move(ownedName))) {
|
||||
namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyType.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -86,6 +131,86 @@ bool PyType::operator==(const PyType &other) {
|
|||
return mlirTypeEqual(type, other.type);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Standard attribute subclasses.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
|
||||
/// CRTP base classes for Python attributes that subclass Attribute and should
|
||||
/// be castable from it (i.e. via something like StringAttr(attr)).
|
||||
template <typename T>
|
||||
class PyConcreteAttribute : public PyAttribute {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = py::class_<T, PyAttribute>;
|
||||
using IsAFunctionTy = int (*)(MlirAttribute);
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
|
||||
PyConcreteAttribute(PyAttribute &orig)
|
||||
: PyConcreteAttribute(castFrom(orig)) {}
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!T::isaFunction(orig.attr)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Cannot cast attribute to ") +
|
||||
T::pyClassName + " (from " + origRepr + ")");
|
||||
}
|
||||
return orig.attr;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, T::pyClassName);
|
||||
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
|
||||
static constexpr const char *pyClassName = "StringAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](PyMlirContext &context, std::string value) {
|
||||
MlirAttribute attr =
|
||||
mlirStringAttrGet(context.context, value.size(), &value[0]);
|
||||
return PyStringAttribute(attr);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
|
||||
c.def_static(
|
||||
"get_typed",
|
||||
[](PyType &type, std::string value) {
|
||||
MlirAttribute attr =
|
||||
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
|
||||
return PyStringAttribute(attr);
|
||||
},
|
||||
py::keep_alive<0, 1>(),
|
||||
"Gets a uniqued string attribute associated to a type");
|
||||
c.def_property_readonly(
|
||||
"value",
|
||||
[](PyStringAttribute &self) {
|
||||
PySinglePartStringAccumulator accum;
|
||||
mlirStringAttrGetValue(self.attr, accum.getCallback(),
|
||||
accum.getUserData());
|
||||
return accum.takeValue();
|
||||
},
|
||||
"Returns the value of the string attribute");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Standard type subclasses.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -118,9 +243,9 @@ public:
|
|||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto class_ = ClassTy(m, T::pyClassName);
|
||||
class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(class_);
|
||||
auto cls = ClassTy(m, T::pyClassName);
|
||||
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
|
@ -135,21 +260,21 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"signless",
|
||||
"get_signless",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signless integer type");
|
||||
c.def_static(
|
||||
"signed",
|
||||
"get_signed",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signed integer type");
|
||||
c.def_static(
|
||||
"unsigned",
|
||||
"get_unsigned",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
|
@ -195,6 +320,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
[](PyMlirContext &self, const std::string module) {
|
||||
auto moduleRef =
|
||||
mlirModuleCreateParse(self.context, module.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirModuleIsNull(moduleRef)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
|
@ -203,10 +330,27 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
return PyModule(moduleRef);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseDocstring)
|
||||
.def(
|
||||
"parse_attr",
|
||||
[](PyMlirContext &self, std::string attrSpec) {
|
||||
MlirAttribute type =
|
||||
mlirAttributeParseGet(self.context, attrSpec.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirAttributeIsNull(type)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Unable to parse attribute: '") +
|
||||
attrSpec + "'");
|
||||
}
|
||||
return PyAttribute(type);
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def(
|
||||
"parse_type",
|
||||
[](PyMlirContext &self, std::string typeSpec) {
|
||||
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(type)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Unable to parse type: '") +
|
||||
|
@ -235,6 +379,79 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
},
|
||||
kOperationStrDunderDocstring);
|
||||
|
||||
// Mapping of Type.
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def(
|
||||
"get_named",
|
||||
[](PyAttribute &self, std::string name) {
|
||||
return PyNamedAttribute(self.attr, std::move(name));
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Binds a name to the attribute")
|
||||
.def("__eq__",
|
||||
[](PyAttribute &self, py::object &other) {
|
||||
try {
|
||||
PyAttribute otherAttribute = other.cast<PyAttribute>();
|
||||
return self == otherAttribute;
|
||||
} catch (std::exception &e) {
|
||||
return false;
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
|
||||
kDumpDocstring)
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyAttribute &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirAttributePrint(self.attr, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
},
|
||||
kTypeStrDunderDocstring)
|
||||
.def("__repr__", [](PyAttribute &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, attribute values are generally considered useful and are
|
||||
// printed. This may need to be re-evaluated if debug dumps end up
|
||||
// being excessive.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Attribute(");
|
||||
mlirAttributePrint(self.attr, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
py::class_<PyNamedAttribute>(m, "NamedAttribute")
|
||||
.def("__repr__",
|
||||
[](PyNamedAttribute &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("NamedAttribute(");
|
||||
printAccum.parts.append(self.namedAttr.name);
|
||||
printAccum.parts.append("=");
|
||||
mlirAttributePrint(self.namedAttr.attribute,
|
||||
printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
})
|
||||
.def_property_readonly(
|
||||
"name",
|
||||
[](PyNamedAttribute &self) {
|
||||
return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
|
||||
},
|
||||
"The name of the NamedAttribute binding")
|
||||
.def_property_readonly(
|
||||
"attr",
|
||||
[](PyNamedAttribute &self) {
|
||||
return PyAttribute(self.namedAttr.attribute);
|
||||
},
|
||||
py::keep_alive<0, 1>(),
|
||||
"The underlying generic attribute of the NamedAttribute binding");
|
||||
|
||||
// Standard attribute bindings.
|
||||
PyStringAttribute::bind(m);
|
||||
|
||||
// Mapping of Type.
|
||||
py::class_<PyType>(m, "Type")
|
||||
.def("__eq__",
|
||||
|
|
|
@ -45,6 +45,39 @@ public:
|
|||
MlirModule module;
|
||||
};
|
||||
|
||||
/// Wrapper around the generic MlirAttribute.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyAttribute {
|
||||
public:
|
||||
PyAttribute(MlirAttribute attr) : attr(attr) {}
|
||||
bool operator==(const PyAttribute &other);
|
||||
|
||||
MlirAttribute attr;
|
||||
};
|
||||
|
||||
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
|
||||
/// TODO: Refactor this and the C-API to be based on an Identifier owned
|
||||
/// by the context so as to avoid ownership issues here.
|
||||
class PyNamedAttribute {
|
||||
public:
|
||||
/// Constructs a PyNamedAttr that retains an owned name. This should be
|
||||
/// used in any code that originates an MlirNamedAttribute from a python
|
||||
/// string.
|
||||
/// The lifetime of the PyNamedAttr must extend to the lifetime of the
|
||||
/// passed attribute.
|
||||
PyNamedAttribute(MlirAttribute attr, std::string ownedName);
|
||||
|
||||
MlirNamedAttribute namedAttr;
|
||||
|
||||
private:
|
||||
// Since the MlirNamedAttr contains an internal pointer to the actual
|
||||
// memory of the owned string, it must be heap allocated to remain valid.
|
||||
// Otherwise, strings that fit within the small object optimization threshold
|
||||
// will have their memory address change as the containing object is moved,
|
||||
// resulting in an invalid aliased pointer.
|
||||
std::unique_ptr<std::string> ownedName;
|
||||
};
|
||||
|
||||
/// Wrapper around the generic MlirType.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyType {
|
||||
|
|
|
@ -10,8 +10,8 @@
|
|||
|
||||
namespace py = pybind11;
|
||||
|
||||
pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
|
||||
llvm::Twine message) {
|
||||
pybind11::error_already_set
|
||||
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
|
||||
auto messageStr = message.str();
|
||||
PyErr_SetString(excClass, messageStr.c_str());
|
||||
return pybind11::error_already_set();
|
||||
|
|
|
@ -20,7 +20,8 @@ namespace python {
|
|||
// python runtime.
|
||||
// Correct usage:
|
||||
// throw SetPyError(PyExc_ValueError, "Foobar'd");
|
||||
pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
|
||||
pybind11::error_already_set SetPyError(PyObject *excClass,
|
||||
const llvm::Twine &message);
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import mlir
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParsePrint
|
||||
def testParsePrint():
|
||||
ctx = mlir.ir.Context()
|
||||
t = ctx.parse_attr('"hello"')
|
||||
# CHECK: "hello"
|
||||
print(str(t))
|
||||
# CHECK: Attribute("hello")
|
||||
print(repr(t))
|
||||
|
||||
run(testParsePrint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParseError
|
||||
# TODO: Hook the diagnostic manager to capture a more meaningful error
|
||||
# message.
|
||||
def testParseError():
|
||||
ctx = mlir.ir.Context()
|
||||
try:
|
||||
t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
|
||||
except ValueError as e:
|
||||
# CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
|
||||
print("testParseError:", e)
|
||||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testParseError)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEq
|
||||
def testAttrEq():
|
||||
ctx = mlir.ir.Context()
|
||||
a1 = ctx.parse_attr('"attr1"')
|
||||
a2 = ctx.parse_attr('"attr2"')
|
||||
a3 = ctx.parse_attr('"attr1"')
|
||||
# CHECK: a1 == a1: True
|
||||
print("a1 == a1:", a1 == a1)
|
||||
# CHECK: a1 == a2: False
|
||||
print("a1 == a2:", a1 == a2)
|
||||
# CHECK: a1 == a3: True
|
||||
print("a1 == a3:", a1 == a3)
|
||||
# CHECK: a1 == None: False
|
||||
print("a1 == None:", a1 == None)
|
||||
|
||||
run(testAttrEq)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
|
||||
def testAttrEqDoesNotRaise():
|
||||
ctx = mlir.ir.Context()
|
||||
a1 = ctx.parse_attr('"attr1"')
|
||||
not_an_attr = "foo"
|
||||
# CHECK: False
|
||||
print(a1 == not_an_attr)
|
||||
# CHECK: False
|
||||
print(a1 == None)
|
||||
# CHECK: True
|
||||
print(a1 != None)
|
||||
|
||||
run(testAttrEqDoesNotRaise)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStandardAttrCasts
|
||||
def testStandardAttrCasts():
|
||||
ctx = mlir.ir.Context()
|
||||
a1 = ctx.parse_attr('"attr1"')
|
||||
astr = mlir.ir.StringAttr(a1)
|
||||
aself = mlir.ir.StringAttr(astr)
|
||||
# CHECK: Attribute("attr1")
|
||||
print(repr(astr))
|
||||
try:
|
||||
tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
|
||||
except ValueError as e:
|
||||
# CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
|
||||
print("ValueError:", e)
|
||||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testStandardAttrCasts)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStringAttr
|
||||
def testStringAttr():
|
||||
ctx = mlir.ir.Context()
|
||||
sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
|
||||
# CHECK: sattr value: stringattr
|
||||
print("sattr value:", sattr.value)
|
||||
|
||||
# Test factory methods.
|
||||
# CHECK: default_get: "foobar"
|
||||
print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
|
||||
# CHECK: typed_get: "12345" : i32
|
||||
print("typed_get:", mlir.ir.StringAttr.get_typed(
|
||||
mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
|
||||
|
||||
run(testStringAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedAttr
|
||||
def testNamedAttr():
|
||||
ctx = mlir.ir.Context()
|
||||
a = ctx.parse_attr('"stringattr"')
|
||||
named = a.get_named("foobar") # Note: under the small object threshold
|
||||
# CHECK: attr: "stringattr"
|
||||
print("attr:", named.attr)
|
||||
# CHECK: name: foobar
|
||||
print("name:", named.name)
|
||||
# CHECK: named: NamedAttribute(foobar="stringattr")
|
||||
print("named:", named)
|
||||
|
||||
run(testNamedAttr)
|
|
@ -117,10 +117,10 @@ def testIntegerType():
|
|||
print("u32 unsigned:", u32.is_unsigned)
|
||||
|
||||
# CHECK: signless: i16
|
||||
print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
|
||||
print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16))
|
||||
# CHECK: signed: si8
|
||||
print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
|
||||
print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8))
|
||||
# CHECK: unsigned: ui64
|
||||
print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
|
||||
print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
|
||||
|
||||
run(testIntegerType)
|
||||
|
|
Loading…
Reference in New Issue