[mlir] Add Python bindings for IntegerSet

This follows up on the introduction of C API for the same object and is similar
to AffineExpr and AffineMap.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D95437
This commit is contained in:
Alex Zinenko 2021-01-25 18:17:19 +01:00
parent 00773ef78a
commit b208e5bcd0
4 changed files with 390 additions and 18 deletions

View File

@ -26,12 +26,14 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"
#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) {
return affineMap;
}
/** Creates a capsule object encapsulating the raw C-API MlirIntegerSet.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the set in any way. */
static inline PyObject *
mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet),
MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL);
}
/** Extracts an MlirIntegerSet from a capsule as produced from
* mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
* a null set is returned (as checked via mlirIntegerSetIsNull). In such a
* case, the Python APIs will have already set an error. */
static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET);
MlirIntegerSet integerSet = {ptr};
return integerSet;
}
#ifdef __cplusplus
}
#endif

View File

@ -15,6 +15,7 @@
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Registration.h"
#include "llvm/ADT/SmallVector.h"
#include <pybind11/stl.h>
@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
rawAffineMap);
}
//------------------------------------------------------------------------------
// PyIntegerSet and utilities.
//------------------------------------------------------------------------------
class PyIntegerSetConstraint {
public:
PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
PyAffineExpr getExpr() {
return PyAffineExpr(set.getContext(),
mlirIntegerSetGetConstraint(set, pos));
}
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
static void bind(py::module &m) {
py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
}
private:
PyIntegerSet set;
intptr_t pos;
};
class PyIntegerSetConstraintList
: public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
public:
static constexpr const char *pyClassName = "IntegerSetConstraintList";
PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
step),
set(set) {}
intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
PyIntegerSetConstraint getElement(intptr_t pos) {
return PyIntegerSetConstraint(set, pos);
}
PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
intptr_t step) {
return PyIntegerSetConstraintList(set, startIndex, length, step);
}
private:
PyIntegerSet set;
};
bool PyIntegerSet::operator==(const PyIntegerSet &other) {
return mlirIntegerSetEqual(integerSet, other.integerSet);
}
py::object PyIntegerSet::getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonIntegerSetToCapsule(*this));
}
PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
if (mlirIntegerSetIsNull(rawIntegerSet))
throw py::error_already_set();
return PyIntegerSet(
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
rawIntegerSet);
}
/// Attempts to populate `result` with the content of `list` casted to the
/// appropriate type (Python and C types are provided as template arguments).
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
StringRef action) {
result.reserve(py::len(list));
for (py::handle item : list) {
try {
result.push_back(item.cast<PyType>());
} catch (py::cast_error &err) {
std::string msg = (llvm::Twine("Invalid expression when ") + action +
" (" + err.what() + ")")
.str();
throw py::cast_error(msg);
} catch (py::reference_cast_error &err) {
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
action + " (" + err.what() + ")")
.str();
throw py::cast_error(msg);
}
}
}
//------------------------------------------------------------------------------
// Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------
@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineExpr> affineExprs;
affineExprs.reserve(py::len(exprs));
for (py::handle expr : exprs) {
try {
affineExprs.push_back(expr.cast<PyAffineExpr>());
} catch (py::cast_error &err) {
std::string msg =
std::string("Invalid expression when attempting to create "
"an AffineMap (") +
err.what() + ")";
throw py::cast_error(msg);
} catch (py::reference_cast_error &err) {
std::string msg =
std::string("Invalid expression (None?) when attempting to "
"create an AffineMap (") +
err.what() + ")";
throw py::cast_error(msg);
}
}
pyListToVector<PyAffineExpr, MlirAffineExpr>(
exprs, affineExprs, "attempting to create an AffineMap");
MlirAffineMap map =
mlirAffineMapGet(context->get(), dimCount, symbolCount,
affineExprs.size(), affineExprs.data());
@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyAffineMapExprList(self);
});
PyAffineMapExprList::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyIntegerSet.
//----------------------------------------------------------------------------
py::class_<PyIntegerSet>(m, "IntegerSet")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyIntegerSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
.def("__eq__", [](PyIntegerSet &self,
PyIntegerSet &other) { return self == other; })
.def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
.def("__str__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
mlirIntegerSetPrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
})
.def("__repr__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append("IntegerSet(");
mlirIntegerSetPrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
})
.def_property_readonly(
"context",
[](PyIntegerSet &self) { return self.getContext().getObject(); })
.def(
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
kDumpDocstring)
.def_static(
"get",
[](intptr_t numDims, intptr_t numSymbols, py::list exprs,
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
throw py::value_error(
"Expected the number of constraints to match "
"that of equality flags");
if (exprs.empty())
throw py::value_error("Expected non-empty list of constraints");
// Copy over to a SmallVector because std::vector has a
// specialization for booleans that packs data and does not
// expose a `bool *`.
SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
SmallVector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr>(exprs, affineExprs,
"attempting to create an IntegerSet");
MlirIntegerSet set = mlirIntegerSetGet(
context->get(), numDims, numSymbols, exprs.size(),
affineExprs.data(), flags.data());
return PyIntegerSet(context->getRef(), set);
},
py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
py::arg("eq_flags"), py::arg("context") = py::none())
.def_static(
"get_empty",
[](intptr_t numDims, intptr_t numSymbols,
DefaultingPyMlirContext context) {
MlirIntegerSet set =
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
return PyIntegerSet(context->getRef(), set);
},
py::arg("num_dims"), py::arg("num_symbols"),
py::arg("context") = py::none())
.def("get_replaced",
[](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
intptr_t numResultDims, intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
throw py::value_error(
"Expected the number of dimension replacement expressions "
"to match that of dimensions");
if (static_cast<intptr_t>(symbolExprs.size()) !=
mlirIntegerSetGetNumSymbols(self))
throw py::value_error(
"Expected the number of symbol replacement expressions "
"to match that of symbols");
SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
pyListToVector<PyAffineExpr>(
dimExprs, dimAffineExprs,
"attempting to create an IntegerSet by replacing dimensions");
pyListToVector<PyAffineExpr>(
symbolExprs, symbolAffineExprs,
"attempting to create an IntegerSet by replacing symbols");
MlirIntegerSet set = mlirIntegerSetReplaceGet(
self, dimAffineExprs.data(), symbolAffineExprs.data(),
numResultDims, numResultSymbols);
return PyIntegerSet(self.getContext(), set);
})
.def_property_readonly("is_canonical_empty",
[](PyIntegerSet &self) {
return mlirIntegerSetIsCanonicalEmpty(self);
})
.def_property_readonly(
"n_dims",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
.def_property_readonly(
"n_symbols",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
.def_property_readonly(
"n_inputs",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
.def_property_readonly("n_equalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumEqualities(self);
})
.def_property_readonly("n_inequalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumInequalities(self);
})
.def_property_readonly("constraints", [](PyIntegerSet &self) {
return PyIntegerSetConstraintList(self);
});
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}

View File

@ -16,6 +16,7 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@ -726,6 +727,26 @@ private:
MlirAffineMap affineMap;
};
class PyIntegerSet : public BaseContextObject {
public:
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
bool operator==(const PyIntegerSet &other);
operator MlirIntegerSet() const { return integerSet; }
MlirIntegerSet get() const { return integerSet; }
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
pybind11::object getCapsule();
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
/// Note that PyIntegerSet instances may be uniqued, so the returned object
/// may be a pre-existing object. Integer sets are owned by the context.
static PyIntegerSet createFromCapsule(pybind11::object capsule);
private:
MlirIntegerSet integerSet;
};
void populateIRSubmodule(pybind11::module &m);
} // namespace python

View File

@ -0,0 +1,128 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
# CHECK-LABEL: TEST: testIntegerSetCapsule
def testIntegerSetCapsule():
with Context() as ctx:
is1 = IntegerSet.get_empty(1, 1, ctx)
capsule = is1._CAPIPtr
# CHECK: mlir.ir.IntegerSet._CAPIPtr
print(capsule)
is2 = IntegerSet._CAPICreate(capsule)
assert is1 == is2
assert is2.context is ctx
run(testIntegerSetCapsule)
# CHECK-LABEL: TEST: testIntegerSetGet
def testIntegerSetGet():
with Context():
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
s0 = AffineSymbolExpr.get(0)
c42 = AffineConstantExpr.get(42)
# CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
print(set0)
# CHECK: (d0)[s0] : (1 == 0)
set1 = IntegerSet.get_empty(1, 1)
print(set1)
# CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
print(set2)
try:
IntegerSet.get(2, 1, [], [])
except ValueError as e:
# CHECK: Expected non-empty list of constraints
print(e)
try:
IntegerSet.get(2, 1, [d0 - d1], [True, False])
except ValueError as e:
# CHECK: Expected the number of constraints to match that of equality flags
print(e)
try:
IntegerSet.get(2, 1, [0], [True])
except RuntimeError as e:
# CHECK: Invalid expression when attempting to create an IntegerSet
print(e)
try:
IntegerSet.get(2, 1, [None], [True])
except RuntimeError as e:
# CHECK: Invalid expression (None?) when attempting to create an IntegerSet
print(e)
try:
set0.get_replaced([d0], [s0], 1, 1)
except ValueError as e:
# CHECK: Expected the number of dimension replacement expressions to match that of dimensions
print(e)
try:
set0.get_replaced([d0, d1], [s0, s0], 1, 1)
except ValueError as e:
# CHECK: Expected the number of symbol replacement expressions to match that of symbols
print(e)
try:
set0.get_replaced([d0, 1], [s0], 1, 1)
except RuntimeError as e:
# CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
print(e)
try:
set0.get_replaced([d0, d1], [None], 1, 1)
except RuntimeError as e:
# CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
print(e)
run(testIntegerSetGet)
# CHECK-LABEL: TEST: testIntegerSetProperties
def testIntegerSetProperties():
with Context():
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
s0 = AffineSymbolExpr.get(0)
c42 = AffineConstantExpr.get(42)
set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
# CHECK: 2
print(set0.n_dims)
# CHECK: 1
print(set0.n_symbols)
# CHECK: 3
print(set0.n_inputs)
# CHECK: 1
print(set0.n_equalities)
# CHECK: 2
print(set0.n_inequalities)
# CHECK: 3
print(len(set0.constraints))
# CHECK-DAG: d0 - d1 == 0
# CHECK-DAG: s0 - 42 >= 0
# CHECK-DAG: -d0 + s0 >= 0
for cstr in set0.constraints:
print(cstr.expr, end='')
print(" == 0" if cstr.is_eq else " >= 0")
run(testIntegerSetProperties)