[mlir][python] Provide some methods and properties for API completeness

When writing the user-facing documentation, I noticed several inconsistencies
and asymmetries in the Python API we provide. Fix them by adding:

- the `owner` property to regions, similarly to blocks;
- the `isinstance` method to any class derived from `PyConcreteAttr`,
  `PyConcreteValue` and `PyConreteAffineExpr`, similar to `PyConcreteType` to
  enable `isa`-like calls without having to handle exceptions;
- a mechanism to create the first block in the region as we could only create
  blocks relative to other blocks, with is impossible in an empty region.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111556
This commit is contained in:
Alex Zinenko 2021-10-11 18:24:48 +02:00
parent e845ca2ff1
commit 78f2dae00d
9 changed files with 183 additions and 74 deletions

View File

@ -612,8 +612,22 @@ operations (unlike in C++ that supports detached regions).
Blocks can be created within a given region and inserted before or after another
block of the same region using `create_before()`, `create_after()` methods of
the `Block` class. They are not expected to exist outside of regions (unlike in
C++ that supports detached blocks).
the `Block` class, or the `create_at_start()` static method of the same class.
They are not expected to exist outside of regions (unlike in C++ that supports
detached blocks).
```python
from mlir.ir import Block, Context, Operation
with Context():
op = Operation.create("generic.op", regions=1)
# Create the first block in the region.
entry_block = Block.create_at_start(op.regions[0])
# Create further blocks.
other_block = entry_block.create_after()
```
Blocks can be used to create `InsertionPoint`s, which can point to the beginning
or the end of the block, or just before its terminator. It is common for

View File

@ -99,6 +99,9 @@ public:
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
cls.def(py::init<PyAffineExpr &>());
cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
return DerivedTy::isaFunction(otherAffineExpr);
});
DerivedTy::bindDerived(cls);
}

View File

@ -1548,6 +1548,9 @@ public:
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
return DerivedTy::isaFunction(otherValue);
});
DerivedTy::bindDerived(cls);
}
@ -2248,6 +2251,12 @@ void mlir::python::populateIRCore(py::module &m) {
return PyBlockList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of blocks.")
.def_property_readonly(
"owner",
[](PyRegion &self) {
return self.getParentOperation()->createOpView();
},
"Returns the operation owning this region.")
.def(
"__iter__",
[](PyRegion &self) {
@ -2291,6 +2300,23 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOperationList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of operations.")
.def_static(
"create_at_start",
[](PyRegion &parent, py::list pyArgTypes) {
parent.checkValid();
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
argTypes.push_back(pyArg.cast<PyType &>());
}
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
mlirRegionInsertOwnedBlock(parent, 0, block);
return PyBlock(parent.getParentOperation(), block);
},
py::arg("parent"), py::arg("pyArgTypes") = py::list(),
"Creates and returns a new Block at the beginning of the given "
"region (with given argument types).")
.def(
"create_before",
[](PyBlock &self, py::args pyArgTypes) {

View File

@ -533,6 +533,7 @@ public:
: parentOperation(std::move(parentOperation)), region(region) {
assert(!mlirRegionIsNull(region) && "python region cannot be null");
}
operator MlirRegion() const { return region; }
MlirRegion get() { return region; }
PyOperationRef &getParentOperation() { return parentOperation; }
@ -681,6 +682,9 @@ public:
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
pybind11::module_local());
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
});
DerivedTy::bindDerived(cls);
}
@ -764,6 +768,7 @@ class PyValue {
public:
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(parentOperation), value(value) {}
operator MlirValue() const { return value; }
MlirValue get() { return value; }
PyOperationRef &getParentOperation() { return parentOperation; }

View File

@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# CHECK-LABEL: TEST: testAffineExprCapsule
@run
def testAffineExprCapsule():
with Context() as ctx:
affine_expr = AffineExpr.get_constant(42)
@ -24,10 +26,9 @@ def testAffineExprCapsule():
assert affine_expr == affine_expr_2
assert affine_expr_2.context == ctx
run(testAffineExprCapsule)
# CHECK-LABEL: TEST: testAffineExprEq
@run
def testAffineExprEq():
with Context():
a1 = AffineExpr.get_constant(42)
@ -44,10 +45,9 @@ def testAffineExprEq():
# CHECK: False
print(a1 == "foo")
run(testAffineExprEq)
# CHECK-LABEL: TEST: testAffineExprContext
@run
def testAffineExprContext():
with Context():
a1 = AffineExpr.get_constant(42)
@ -61,6 +61,7 @@ run(testAffineExprContext)
# CHECK-LABEL: TEST: testAffineExprConstant
@run
def testAffineExprConstant():
with Context():
a1 = AffineExpr.get_constant(42)
@ -77,10 +78,9 @@ def testAffineExprConstant():
assert a1 == a2
run(testAffineExprConstant)
# CHECK-LABEL: TEST: testAffineExprDim
@run
def testAffineExprDim():
with Context():
d1 = AffineExpr.get_dim(1)
@ -100,10 +100,9 @@ def testAffineExprDim():
assert d1 == d11
assert d1 != d2
run(testAffineExprDim)
# CHECK-LABEL: TEST: testAffineExprSymbol
@run
def testAffineExprSymbol():
with Context():
s1 = AffineExpr.get_symbol(1)
@ -123,10 +122,9 @@ def testAffineExprSymbol():
assert s1 == s11
assert s1 != s2
run(testAffineExprSymbol)
# CHECK-LABEL: TEST: testAffineAddExpr
@run
def testAffineAddExpr():
with Context():
d1 = AffineDimExpr.get(1)
@ -143,10 +141,9 @@ def testAffineAddExpr():
assert d12.lhs == d1
assert d12.rhs == d2
run(testAffineAddExpr)
# CHECK-LABEL: TEST: testAffineMulExpr
@run
def testAffineMulExpr():
with Context():
d1 = AffineDimExpr.get(1)
@ -163,10 +160,9 @@ def testAffineMulExpr():
assert expr.lhs == d1
assert expr.rhs == c2
run(testAffineMulExpr)
# CHECK-LABEL: TEST: testAffineModExpr
@run
def testAffineModExpr():
with Context():
d1 = AffineDimExpr.get(1)
@ -183,10 +179,9 @@ def testAffineModExpr():
assert expr.lhs == d1
assert expr.rhs == c2
run(testAffineModExpr)
# CHECK-LABEL: TEST: testAffineFloorDivExpr
@run
def testAffineFloorDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
@ -198,10 +193,9 @@ def testAffineFloorDivExpr():
assert expr.lhs == d1
assert expr.rhs == c2
run(testAffineFloorDivExpr)
# CHECK-LABEL: TEST: testAffineCeilDivExpr
@run
def testAffineCeilDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
@ -213,10 +207,9 @@ def testAffineCeilDivExpr():
assert expr.lhs == d1
assert expr.rhs == c2
run(testAffineCeilDivExpr)
# CHECK-LABEL: TEST: testAffineExprSub
@run
def testAffineExprSub():
with Context():
d1 = AffineDimExpr.get(1)
@ -232,9 +225,8 @@ def testAffineExprSub():
# CHECK: -1
print(rhs.rhs)
run(testAffineExprSub)
# CHECK-LABEL: TEST: testClassHierarchy
@run
def testClassHierarchy():
with Context():
d1 = AffineDimExpr.get(1)
@ -272,4 +264,28 @@ def testClassHierarchy():
# CHECK: Cannot cast affine expression to AffineBinaryExpr
print(e)
run(testClassHierarchy)
# CHECK-LABEL: TEST: testIsInstance
@run
def testIsInstance():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
add = AffineAddExpr.get(d1, c2)
mul = AffineMulExpr.get(d1, c2)
# CHECK: True
print(AffineDimExpr.isinstance(d1))
# CHECK: False
print(AffineConstantExpr.isinstance(d1))
# CHECK: True
print(AffineConstantExpr.isinstance(c2))
# CHECK: False
print(AffineMulExpr.isinstance(c2))
# CHECK: True
print(AffineAddExpr.isinstance(add))
# CHECK: False
print(AffineMulExpr.isinstance(add))
# CHECK: True
print(AffineMulExpr.isinstance(mul))
# CHECK: False
print(AffineAddExpr.isinstance(mul))

View File

@ -89,6 +89,18 @@ def testAttrCast():
print("a1 == a2:", a1 == a2)
# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
with Context():
a1 = Attribute.parse("42")
a2 = Attribute.parse("[42]")
assert IntegerAttr.isinstance(a1)
assert not IntegerAttr.isinstance(a2)
assert not ArrayAttr.isinstance(a1)
assert ArrayAttr.isinstance(a2)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():

View File

@ -51,3 +51,22 @@ def testBlockCreation():
print(module.operation)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region
# CHECK-LABEL: TEST: testFirstBlockCreation
# CHECK: func @test(%{{.*}}: f32)
# CHECK: return
@run
def testFirstBlockCreation():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
func = builtin.FuncOp("test", ([f32], []))
entry_block = Block.create_at_start(func.operation.regions[0], [f32])
with InsertionPoint(entry_block):
std.ReturnOp([])
print(module)
assert module.operation.verify()
assert func.body.blocks[0] == entry_block

View File

@ -11,10 +11,12 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -69,11 +71,9 @@ def testTraverseOpRegionBlockIterators():
walk_operations("", op)
run(testTraverseOpRegionBlockIterators)
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -111,10 +111,30 @@ def testTraverseOpRegionBlockIndices():
walk_operations("", module.operation)
run(testTraverseOpRegionBlockIndices)
# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
builtin.module {
builtin.func @f() {
std.return
}
}
""", ctx)
assert module.operation.regions[0].owner == module.operation
assert module.operation.regions[0].blocks[0].owner == module.operation
func = module.body.operations[0]
assert func.operation.regions[0].owner == func
assert func.operation.regions[0].blocks[0].owner == func
# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
with Context() as ctx:
module = Module.parse(
@ -158,10 +178,8 @@ def testBlockArgumentList():
print("Type: ", t)
run(testBlockArgumentList)
# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@ -181,10 +199,10 @@ def testOperationOperands():
print(f"Operand {i}, type {operand.type}")
run(testOperationOperands)
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@ -239,10 +257,10 @@ def testOperationOperandsSlice():
print(operand)
run(testOperationOperandsSlice)
# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
with Context() as ctx, Location.unknown(ctx):
ctx.allow_unregistered_dialects = True
@ -271,10 +289,10 @@ def testOperationOperandsSet():
print(consumer.operands[0])
run(testOperationOperandsSet)
# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -295,10 +313,8 @@ def testDetachedOperation():
# TODO: Check successors once enough infra exists to do it properly.
run(testDetachedOperation)
# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -335,10 +351,8 @@ def testOperationInsertionPoint():
assert False, "expected insert of attached op to raise"
run(testOperationInsertionPoint)
# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -377,10 +391,8 @@ def testOperationWithRegion():
print(module)
run(testOperationWithRegion)
# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
ctx = Context()
module = Module.parse(
@ -407,10 +419,10 @@ def testOperationResultList():
print(f"Result type {t}")
run(testOperationResultList)
# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@ -458,10 +470,10 @@ def testOperationResultListSlice():
print(f"Result {res.result_number}, type {res.type}")
run(testOperationResultListSlice)
# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -506,10 +518,10 @@ def testOperationAttributes():
assert False, "expected IndexError on accessing an out-of-bounds attribute"
run(testOperationAttributes)
# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
ctx = Context()
module = Module.parse(
@ -553,10 +565,10 @@ def testOperationPrint():
use_local_scope=True)
run(testOperationPrint)
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@ -586,10 +598,8 @@ def testKnownOpView():
print(repr(custom))
run(testKnownOpView)
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@ -620,10 +630,8 @@ def testSingleResultProperty():
print(module.body.operations[2])
run(testSingleResultProperty)
# CHECK-LABEL: TEST: testPrintInvalidOperation
@run
def testPrintInvalidOperation():
ctx = Context()
with Location.unknown(ctx):
@ -639,10 +647,8 @@ def testPrintInvalidOperation():
print(f".verify = {module.operation.verify()}")
run(testPrintInvalidOperation)
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
ctx = Context()
with Location.unknown(ctx):
@ -670,10 +676,8 @@ def testCreateWithInvalidAttributes():
print(e)
run(testCreateWithInvalidAttributes)
# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -691,10 +695,8 @@ def testOperationName():
print(op.operation.name)
run(testOperationName)
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -706,10 +708,8 @@ def testCapsuleConversions():
assert m2 is m
run(testCapsuleConversions)
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -728,6 +728,3 @@ def testOperationErase():
# Ensure we can create another operation
Operation.create("custom.op2")
run(testOperationErase)

View File

@ -9,9 +9,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -24,10 +26,8 @@ def testCapsuleConversions():
assert value2 == value
run(testCapsuleConversions)
# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
ctx = Context()
ctx.allow_unregistered_dialects = True
@ -37,4 +37,21 @@ def testOpResultOwner():
assert op.result.owner == op
run(testOpResultOwner)
# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
func @foo(%arg0: f32) {
%0 = "some_dialect.some_op"() : () -> f64
return
}""", ctx)
func = module.body.operations[0]
assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
op = func.regions[0].blocks[0].operations[0]
assert not BlockArgument.isinstance(op.results[0])
assert OpResult.isinstance(op.results[0])