forked from OSchip/llvm-project
[mlir] Add Python bindings for IntegerSet
This follows up on the introduction of C API for the same object and is similar to AffineExpr and AffineMap. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D95437
This commit is contained in:
parent
00773ef78a
commit
b208e5bcd0
|
@ -26,12 +26,14 @@
|
|||
#include "mlir-c/AffineExpr.h"
|
||||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "mlir-c/Pass.h"
|
||||
|
||||
#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
|
||||
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
|
||||
|
@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) {
|
|||
return affineMap;
|
||||
}
|
||||
|
||||
/** Creates a capsule object encapsulating the raw C-API MlirIntegerSet.
|
||||
* The returned capsule does not extend or affect ownership of any Python
|
||||
* objects that reference the set in any way. */
|
||||
static inline PyObject *
|
||||
mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) {
|
||||
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet),
|
||||
MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL);
|
||||
}
|
||||
|
||||
/** Extracts an MlirIntegerSet from a capsule as produced from
|
||||
* mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
|
||||
* a null set is returned (as checked via mlirIntegerSetIsNull). In such a
|
||||
* case, the Python APIs will have already set an error. */
|
||||
static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
|
||||
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET);
|
||||
MlirIntegerSet integerSet = {ptr};
|
||||
return integerSet;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
|
|||
rawAffineMap);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyIntegerSet and utilities.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class PyIntegerSetConstraint {
|
||||
public:
|
||||
PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
|
||||
|
||||
PyAffineExpr getExpr() {
|
||||
return PyAffineExpr(set.getContext(),
|
||||
mlirIntegerSetGetConstraint(set, pos));
|
||||
}
|
||||
|
||||
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
|
||||
.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
|
||||
.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
|
||||
}
|
||||
|
||||
private:
|
||||
PyIntegerSet set;
|
||||
intptr_t pos;
|
||||
};
|
||||
|
||||
class PyIntegerSetConstraintList
|
||||
: public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "IntegerSetConstraintList";
|
||||
|
||||
PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
|
||||
intptr_t length = -1, intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
|
||||
step),
|
||||
set(set) {}
|
||||
|
||||
intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
|
||||
|
||||
PyIntegerSetConstraint getElement(intptr_t pos) {
|
||||
return PyIntegerSetConstraint(set, pos);
|
||||
}
|
||||
|
||||
PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
|
||||
intptr_t step) {
|
||||
return PyIntegerSetConstraintList(set, startIndex, length, step);
|
||||
}
|
||||
|
||||
private:
|
||||
PyIntegerSet set;
|
||||
};
|
||||
|
||||
bool PyIntegerSet::operator==(const PyIntegerSet &other) {
|
||||
return mlirIntegerSetEqual(integerSet, other.integerSet);
|
||||
}
|
||||
|
||||
py::object PyIntegerSet::getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
mlirPythonIntegerSetToCapsule(*this));
|
||||
}
|
||||
|
||||
PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
|
||||
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
|
||||
if (mlirIntegerSetIsNull(rawIntegerSet))
|
||||
throw py::error_already_set();
|
||||
return PyIntegerSet(
|
||||
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
|
||||
rawIntegerSet);
|
||||
}
|
||||
|
||||
/// Attempts to populate `result` with the content of `list` casted to the
|
||||
/// appropriate type (Python and C types are provided as template arguments).
|
||||
/// Throws errors in case of failure, using "action" to describe what the caller
|
||||
/// was attempting to do.
|
||||
template <typename PyType, typename CType>
|
||||
static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
|
||||
StringRef action) {
|
||||
result.reserve(py::len(list));
|
||||
for (py::handle item : list) {
|
||||
try {
|
||||
result.push_back(item.cast<PyType>());
|
||||
} catch (py::cast_error &err) {
|
||||
std::string msg = (llvm::Twine("Invalid expression when ") + action +
|
||||
" (" + err.what() + ")")
|
||||
.str();
|
||||
throw py::cast_error(msg);
|
||||
} catch (py::reference_cast_error &err) {
|
||||
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
|
||||
action + " (" + err.what() + ")")
|
||||
.str();
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Populates the pybind11 IR submodule.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
[](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
|
||||
DefaultingPyMlirContext context) {
|
||||
SmallVector<MlirAffineExpr> affineExprs;
|
||||
affineExprs.reserve(py::len(exprs));
|
||||
for (py::handle expr : exprs) {
|
||||
try {
|
||||
affineExprs.push_back(expr.cast<PyAffineExpr>());
|
||||
} catch (py::cast_error &err) {
|
||||
std::string msg =
|
||||
std::string("Invalid expression when attempting to create "
|
||||
"an AffineMap (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
} catch (py::reference_cast_error &err) {
|
||||
std::string msg =
|
||||
std::string("Invalid expression (None?) when attempting to "
|
||||
"create an AffineMap (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
}
|
||||
pyListToVector<PyAffineExpr, MlirAffineExpr>(
|
||||
exprs, affineExprs, "attempting to create an AffineMap");
|
||||
MlirAffineMap map =
|
||||
mlirAffineMapGet(context->get(), dimCount, symbolCount,
|
||||
affineExprs.size(), affineExprs.data());
|
||||
|
@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
return PyAffineMapExprList(self);
|
||||
});
|
||||
PyAffineMapExprList::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyIntegerSet.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyIntegerSet>(m, "IntegerSet")
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyIntegerSet::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
|
||||
.def("__eq__", [](PyIntegerSet &self,
|
||||
PyIntegerSet &other) { return self == other; })
|
||||
.def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
|
||||
.def("__str__",
|
||||
[](PyIntegerSet &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirIntegerSetPrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
})
|
||||
.def("__repr__",
|
||||
[](PyIntegerSet &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("IntegerSet(");
|
||||
mlirIntegerSetPrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
})
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
[](PyIntegerSet &self) { return self.getContext().getObject(); })
|
||||
.def(
|
||||
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
|
||||
kDumpDocstring)
|
||||
.def_static(
|
||||
"get",
|
||||
[](intptr_t numDims, intptr_t numSymbols, py::list exprs,
|
||||
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
|
||||
if (exprs.size() != eqFlags.size())
|
||||
throw py::value_error(
|
||||
"Expected the number of constraints to match "
|
||||
"that of equality flags");
|
||||
if (exprs.empty())
|
||||
throw py::value_error("Expected non-empty list of constraints");
|
||||
|
||||
// Copy over to a SmallVector because std::vector has a
|
||||
// specialization for booleans that packs data and does not
|
||||
// expose a `bool *`.
|
||||
SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
|
||||
|
||||
SmallVector<MlirAffineExpr> affineExprs;
|
||||
pyListToVector<PyAffineExpr>(exprs, affineExprs,
|
||||
"attempting to create an IntegerSet");
|
||||
MlirIntegerSet set = mlirIntegerSetGet(
|
||||
context->get(), numDims, numSymbols, exprs.size(),
|
||||
affineExprs.data(), flags.data());
|
||||
return PyIntegerSet(context->getRef(), set);
|
||||
},
|
||||
py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
|
||||
py::arg("eq_flags"), py::arg("context") = py::none())
|
||||
.def_static(
|
||||
"get_empty",
|
||||
[](intptr_t numDims, intptr_t numSymbols,
|
||||
DefaultingPyMlirContext context) {
|
||||
MlirIntegerSet set =
|
||||
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
|
||||
return PyIntegerSet(context->getRef(), set);
|
||||
},
|
||||
py::arg("num_dims"), py::arg("num_symbols"),
|
||||
py::arg("context") = py::none())
|
||||
.def("get_replaced",
|
||||
[](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
|
||||
intptr_t numResultDims, intptr_t numResultSymbols) {
|
||||
if (static_cast<intptr_t>(dimExprs.size()) !=
|
||||
mlirIntegerSetGetNumDims(self))
|
||||
throw py::value_error(
|
||||
"Expected the number of dimension replacement expressions "
|
||||
"to match that of dimensions");
|
||||
if (static_cast<intptr_t>(symbolExprs.size()) !=
|
||||
mlirIntegerSetGetNumSymbols(self))
|
||||
throw py::value_error(
|
||||
"Expected the number of symbol replacement expressions "
|
||||
"to match that of symbols");
|
||||
|
||||
SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
|
||||
pyListToVector<PyAffineExpr>(
|
||||
dimExprs, dimAffineExprs,
|
||||
"attempting to create an IntegerSet by replacing dimensions");
|
||||
pyListToVector<PyAffineExpr>(
|
||||
symbolExprs, symbolAffineExprs,
|
||||
"attempting to create an IntegerSet by replacing symbols");
|
||||
MlirIntegerSet set = mlirIntegerSetReplaceGet(
|
||||
self, dimAffineExprs.data(), symbolAffineExprs.data(),
|
||||
numResultDims, numResultSymbols);
|
||||
return PyIntegerSet(self.getContext(), set);
|
||||
})
|
||||
.def_property_readonly("is_canonical_empty",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetIsCanonicalEmpty(self);
|
||||
})
|
||||
.def_property_readonly(
|
||||
"n_dims",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
|
||||
.def_property_readonly(
|
||||
"n_symbols",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
|
||||
.def_property_readonly(
|
||||
"n_inputs",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
|
||||
.def_property_readonly("n_equalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumEqualities(self);
|
||||
})
|
||||
.def_property_readonly("n_inequalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumInequalities(self);
|
||||
})
|
||||
.def_property_readonly("constraints", [](PyIntegerSet &self) {
|
||||
return PyIntegerSetConstraintList(self);
|
||||
});
|
||||
PyIntegerSetConstraint::bind(m);
|
||||
PyIntegerSetConstraintList::bind(m);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir-c/AffineExpr.h"
|
||||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -726,6 +727,26 @@ private:
|
|||
MlirAffineMap affineMap;
|
||||
};
|
||||
|
||||
class PyIntegerSet : public BaseContextObject {
|
||||
public:
|
||||
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
|
||||
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
|
||||
bool operator==(const PyIntegerSet &other);
|
||||
operator MlirIntegerSet() const { return integerSet; }
|
||||
MlirIntegerSet get() const { return integerSet; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
|
||||
pybind11::object getCapsule();
|
||||
|
||||
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
|
||||
/// Note that PyIntegerSet instances may be uniqued, so the returned object
|
||||
/// may be a pre-existing object. Integer sets are owned by the context.
|
||||
static PyIntegerSet createFromCapsule(pybind11::object capsule);
|
||||
|
||||
private:
|
||||
MlirIntegerSet integerSet;
|
||||
};
|
||||
|
||||
void populateIRSubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
from mlir.ir import *
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIntegerSetCapsule
|
||||
def testIntegerSetCapsule():
|
||||
with Context() as ctx:
|
||||
is1 = IntegerSet.get_empty(1, 1, ctx)
|
||||
capsule = is1._CAPIPtr
|
||||
# CHECK: mlir.ir.IntegerSet._CAPIPtr
|
||||
print(capsule)
|
||||
is2 = IntegerSet._CAPICreate(capsule)
|
||||
assert is1 == is2
|
||||
assert is2.context is ctx
|
||||
|
||||
run(testIntegerSetCapsule)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIntegerSetGet
|
||||
def testIntegerSetGet():
|
||||
with Context():
|
||||
d0 = AffineDimExpr.get(0)
|
||||
d1 = AffineDimExpr.get(1)
|
||||
s0 = AffineSymbolExpr.get(0)
|
||||
c42 = AffineConstantExpr.get(42)
|
||||
|
||||
# CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
|
||||
set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
|
||||
print(set0)
|
||||
|
||||
# CHECK: (d0)[s0] : (1 == 0)
|
||||
set1 = IntegerSet.get_empty(1, 1)
|
||||
print(set1)
|
||||
|
||||
# CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
|
||||
set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
|
||||
print(set2)
|
||||
|
||||
try:
|
||||
IntegerSet.get(2, 1, [], [])
|
||||
except ValueError as e:
|
||||
# CHECK: Expected non-empty list of constraints
|
||||
print(e)
|
||||
|
||||
try:
|
||||
IntegerSet.get(2, 1, [d0 - d1], [True, False])
|
||||
except ValueError as e:
|
||||
# CHECK: Expected the number of constraints to match that of equality flags
|
||||
print(e)
|
||||
|
||||
try:
|
||||
IntegerSet.get(2, 1, [0], [True])
|
||||
except RuntimeError as e:
|
||||
# CHECK: Invalid expression when attempting to create an IntegerSet
|
||||
print(e)
|
||||
|
||||
try:
|
||||
IntegerSet.get(2, 1, [None], [True])
|
||||
except RuntimeError as e:
|
||||
# CHECK: Invalid expression (None?) when attempting to create an IntegerSet
|
||||
print(e)
|
||||
|
||||
try:
|
||||
set0.get_replaced([d0], [s0], 1, 1)
|
||||
except ValueError as e:
|
||||
# CHECK: Expected the number of dimension replacement expressions to match that of dimensions
|
||||
print(e)
|
||||
|
||||
try:
|
||||
set0.get_replaced([d0, d1], [s0, s0], 1, 1)
|
||||
except ValueError as e:
|
||||
# CHECK: Expected the number of symbol replacement expressions to match that of symbols
|
||||
print(e)
|
||||
|
||||
try:
|
||||
set0.get_replaced([d0, 1], [s0], 1, 1)
|
||||
except RuntimeError as e:
|
||||
# CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
|
||||
print(e)
|
||||
|
||||
try:
|
||||
set0.get_replaced([d0, d1], [None], 1, 1)
|
||||
except RuntimeError as e:
|
||||
# CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
|
||||
print(e)
|
||||
|
||||
run(testIntegerSetGet)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIntegerSetProperties
|
||||
def testIntegerSetProperties():
|
||||
with Context():
|
||||
d0 = AffineDimExpr.get(0)
|
||||
d1 = AffineDimExpr.get(1)
|
||||
s0 = AffineSymbolExpr.get(0)
|
||||
c42 = AffineConstantExpr.get(42)
|
||||
|
||||
set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
|
||||
# CHECK: 2
|
||||
print(set0.n_dims)
|
||||
# CHECK: 1
|
||||
print(set0.n_symbols)
|
||||
# CHECK: 3
|
||||
print(set0.n_inputs)
|
||||
# CHECK: 1
|
||||
print(set0.n_equalities)
|
||||
# CHECK: 2
|
||||
print(set0.n_inequalities)
|
||||
|
||||
# CHECK: 3
|
||||
print(len(set0.constraints))
|
||||
|
||||
# CHECK-DAG: d0 - d1 == 0
|
||||
# CHECK-DAG: s0 - 42 >= 0
|
||||
# CHECK-DAG: -d0 + s0 >= 0
|
||||
for cstr in set0.constraints:
|
||||
print(cstr.expr, end='')
|
||||
print(" == 0" if cstr.is_eq else " >= 0")
|
||||
|
||||
run(testIntegerSetProperties)
|
Loading…
Reference in New Issue